Skip to content

Commit c5dccc5

Browse files
authored
Fix code safety check (#1105)
* fix * linting * fix * fix * linting
1 parent c092e08 commit c5dccc5

File tree

18 files changed

+229
-37
lines changed

18 files changed

+229
-37
lines changed

examples/hunyuan_dit/hydit/data_loader/csv2arrow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def parse_data(data):
2020

2121
with open(img_path, "rb") as fp:
2222
image = fp.read()
23-
md5 = hashlib.md5(image).hexdigest()
23+
md5 = hashlib.md5(image, usedforsecurity=False).hexdigest()
2424

2525
with Image.open(img_path) as f:
2626
width, height = f.size

examples/opensora_pku/opensora/models/causalvideovae/eval/fvd/styleganv/fvd.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/opensora/models/causalvideovae/eval/fvd/styleganv/fvd.py
33
import math
44
import os
5+
import shlex
6+
import subprocess
57

68
from mindspore import context, export, load, mint, nn, ops
79

@@ -23,7 +25,7 @@ def load_i3d_pretrained(bs=1):
2325
if not os.path.exists(mindir_filepath):
2426
if not os.path.exists(filepath):
2527
print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
26-
os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
28+
subprocess.run(shlex.split(f"wget {i3D_WEIGHTS_URL} -O {filepath}"), shell=False)
2729
if not os.path.exists(onnx_filepath):
2830
# convert torch jit model to onnx model
2931
model = torch.jit.load(filepath).eval()

examples/opensora_pku/opensora/models/causalvideovae/eval/fvd/videogpt/fvd.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import math
55
import os
6+
import shlex
7+
import subprocess
68

79
import numpy as np
810

@@ -26,7 +28,7 @@ def load_i3d_pretrained():
2628
if not os.path.exists(ms_filepath):
2729
if not os.path.exists(filepath):
2830
print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
29-
os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
31+
subprocess.run(shlex.split(f"wget {i3D_WEIGHTS_URL} -O {filepath}"), shell=False)
3032
# convert torch ckpt to mindspore ckpt
3133
state_dict = torch.load_state_dict(torch.load(filepath))
3234
raise ValueError("Not converted")

examples/stable_diffusion_v2/ldm/data/sync_data.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,42 @@ def get_rank_id():
3232
return int(global_rank_id)
3333

3434

35+
def is_safe_member(member, target_dir):
    """Return True if a tar member may be extracted into ``target_dir``.

    A member is rejected when its name is absolute, when any path component
    is ``..``, when it is a hard or symbolic link, or when its resolved
    location falls outside ``target_dir``.

    Args:
        member: Object exposing ``name``, ``islnk()`` and ``issym()``
            (e.g. ``tarfile.TarInfo``).
        target_dir: Extraction directory. ``None`` is treated as the current
            directory, which is where ``extractall`` extracts by default.

    Returns:
        bool: True when the member is considered safe to extract.
    """
    if target_dir is None:
        target_dir = "."

    name = member.name
    if name.startswith("/") or os.path.isabs(name):
        return False

    # Only a ".." *component* climbs out of the target; a plain substring test
    # would also reject harmless names such as "file..txt".
    if ".." in name.replace("\\", "/").split("/"):
        return False

    if member.islnk() or member.issym():
        return False

    abs_target_dir = os.path.abspath(target_dir)
    abs_member_path = os.path.abspath(os.path.join(abs_target_dir, name))
    try:
        # commonpath compares whole components, so "/a/bc" is not treated as
        # living inside "/a/b" the way a string-prefix check would.
        return os.path.commonpath([abs_target_dir, abs_member_path]) == abs_target_dir
    except ValueError:
        # Raised e.g. for paths on different drives: definitely not inside.
        return False
50+
51+
52+
def safe_members(tar, target_dir):
    """Yield the members of ``tar`` that pass ``is_safe_member`` for ``target_dir``.

    Members that fail the check are skipped and reported on stdout.
    """
    for candidate in tar.getmembers():
        if not is_safe_member(candidate, target_dir):
            print(f"Discarding unsafe member: {candidate.name}")
            continue
        yield candidate
58+
59+
3560
def extract_tar(file_path):
3661
try:
3762
with tarfile.open(file_path, "r") as archive:
3863
if "/" not in archive.getnames()[1]:
3964
subfolder_path = file_path[:-4]
4065
os.makedirs(subfolder_path, exist_ok=True)
41-
archive.extractall(subfolder_path)
66+
archive.extractall(subfolder_path, members=safe_members(archive, subfolder_path))
4267
else:
43-
archive.extractall(os.path.dirname(file_path))
68+
archive.extractall(
69+
os.path.dirname(file_path), members=safe_members(archive, os.path.dirname(file_path))
70+
)
4471
os.remove(file_path)
4572
_logger.info("finish extract: " + file_path)
4673
return True

examples/stable_diffusion_v2/tests/st/test_train_infer_dummy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import shlex
23
import subprocess
34
import sys
45

@@ -19,7 +20,7 @@ def test_train_infer(use_ema, finetuning):
1920
epochs = 1
2021
image_size = 512
2122
if os.path.exists(output_path):
22-
os.system(f"rm {output_path} -rf")
23+
subprocess.run(shlex.split(f"rm {output_path} -rf"), shell=False)
2324
if finetuning == "Vanilla":
2425
use_lora = False
2526
elif finetuning == "LoRA":
@@ -65,7 +66,7 @@ def test_train_infer_DreamBooth(use_ema):
6566
epochs = 1
6667
image_size = 512
6768
if os.path.exists(output_path):
68-
os.system(f"rm {output_path} -rf")
69+
subprocess.run(shlex.split(f"rm {output_path} -rf"), shell=False)
6970
cmd = (
7071
f"python train_dreambooth.py --mode=0 --instance_data_dir={data_path} --instance_prompt='{instance_prompt}' "
7172
f"--model_config={model_config_file} --class_data_dir={data_path} --class_prompt='{instance_prompt}' "

examples/stable_diffusion_v2/tools/eval/fid/utils.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import os
77
import pathlib
88
import ssl
9+
import subprocess
910
import tarfile
11+
import time
1012
import urllib
1113
import urllib.error
1214
import urllib.request
@@ -85,6 +87,39 @@ def detect_file_type(filename: str): # pylint: disable=inconsistent-return-stat
8587
return suffix, None, suffix
8688

8789

90+
def download_weights(url, dest):
    """Run the external ``pget`` command to fetch ``url`` into ``dest``.

    Prints the source, the destination and the elapsed wall-clock time.
    ``subprocess.check_call`` raises ``CalledProcessError`` if the command
    exits with a non-zero status.
    """
    started_at = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    pget_cmd = ["pget", "-x", url, dest]
    subprocess.check_call(pget_cmd, close_fds=False)
    elapsed = time.time() - started_at
    print("downloading took: ", elapsed)
96+
97+
98+
def is_safe_member(member, target_dir):
    """Return True if a tar member may be extracted into ``target_dir``.

    A member is rejected when its name is absolute, when any path component
    is ``..``, when it is a hard or symbolic link, or when its resolved
    location falls outside ``target_dir``.

    Args:
        member: Object exposing ``name``, ``islnk()`` and ``issym()``
            (e.g. ``tarfile.TarInfo``).
        target_dir: Extraction directory. ``None`` is treated as the current
            directory, which is where ``extractall`` extracts by default.

    Returns:
        bool: True when the member is considered safe to extract.
    """
    if target_dir is None:
        target_dir = "."

    name = member.name
    if name.startswith("/") or os.path.isabs(name):
        return False

    # Only a ".." *component* climbs out of the target; a plain substring test
    # would also reject harmless names such as "file..txt".
    if ".." in name.replace("\\", "/").split("/"):
        return False

    if member.islnk() or member.issym():
        return False

    abs_target_dir = os.path.abspath(target_dir)
    abs_member_path = os.path.abspath(os.path.join(abs_target_dir, name))
    try:
        # commonpath compares whole components, so "/a/bc" is not treated as
        # living inside "/a/b" the way a string-prefix check would.
        return os.path.commonpath([abs_target_dir, abs_member_path]) == abs_target_dir
    except ValueError:
        # Raised e.g. for paths on different drives: definitely not inside.
        return False
113+
114+
115+
def safe_members(tar, target_dir):
    """Yield the members of ``tar`` that pass ``is_safe_member`` for ``target_dir``.

    Members that fail the check are skipped and reported on stdout.
    """
    for candidate in tar.getmembers():
        if not is_safe_member(candidate, target_dir):
            print(f"Discarding unsafe member: {candidate.name}")
            continue
        yield candidate
121+
122+
88123
class Download:
89124
"""Base utility class for downloading."""
90125

@@ -96,7 +131,7 @@ class Download:
96131
@staticmethod
97132
def calculate_md5(file_path: str, chunk_size: int = 1024 * 1024) -> str:
98133
"""Calculate md5 value."""
99-
md5 = hashlib.md5()
134+
md5 = hashlib.md5(usedforsecurity=False)
100135
with open(file_path, "rb") as fp:
101136
for chunk in iter(lambda: fp.read(chunk_size), b""):
102137
md5.update(chunk)
@@ -111,15 +146,15 @@ def extract_tar(from_path: str, to_path: Optional[str] = None, compression: Opti
111146
"""Extract tar format file."""
112147

113148
with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
114-
tar.extractall(to_path)
149+
tar.extract_all(tar, members=safe_members(tar, to_path))
115150

116151
@staticmethod
117152
def extract_zip(from_path: str, to_path: Optional[str] = None, compression: Optional[str] = None) -> None:
118153
"""Extract zip format file."""
119154

120155
compression_mode = zipfile.ZIP_BZIP2 if compression else zipfile.ZIP_STORED
121156
with zipfile.ZipFile(from_path, "r", compression=compression_mode) as zip_file:
122-
zip_file.extractall(to_path)
157+
zipfile.extract_all(zip_file, members=safe_members(zip_file, to_path))
123158

124159
def extract_archive(self, from_path: str, to_path: str = None) -> str:
125160
"""Extract and archive from path to path."""

examples/stable_diffusion_v2/utils/download.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""Utility of downloading"""
2+
23
import bz2
34
import gzip
45
import hashlib
56
import logging
67
import os
78
import ssl
9+
import subprocess
810
import tarfile
11+
import time
912
import urllib
1013
import urllib.error
1114
import urllib.request
@@ -32,6 +35,39 @@ def set_default_download_root(path):
3235
_DEFAULT_DOWNLOAD_ROOT = path
3336

3437

38+
def download_weights(url, dest):
    """Run the external ``pget`` command to fetch ``url`` into ``dest``.

    Prints the source, the destination and the elapsed wall-clock time.
    ``subprocess.check_call`` raises ``CalledProcessError`` if the command
    exits with a non-zero status.
    """
    started_at = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    pget_cmd = ["pget", "-x", url, dest]
    subprocess.check_call(pget_cmd, close_fds=False)
    elapsed = time.time() - started_at
    print("downloading took: ", elapsed)
44+
45+
46+
def is_safe_member(member, target_dir):
    """Return True if a tar member may be extracted into ``target_dir``.

    A member is rejected when its name is absolute, when any path component
    is ``..``, when it is a hard or symbolic link, or when its resolved
    location falls outside ``target_dir``.

    Args:
        member: Object exposing ``name``, ``islnk()`` and ``issym()``
            (e.g. ``tarfile.TarInfo``).
        target_dir: Extraction directory. ``None`` is treated as the current
            directory, which is where ``extractall`` extracts by default.

    Returns:
        bool: True when the member is considered safe to extract.
    """
    if target_dir is None:
        target_dir = "."

    name = member.name
    if name.startswith("/") or os.path.isabs(name):
        return False

    # Only a ".." *component* climbs out of the target; a plain substring test
    # would also reject harmless names such as "file..txt".
    if ".." in name.replace("\\", "/").split("/"):
        return False

    if member.islnk() or member.issym():
        return False

    abs_target_dir = os.path.abspath(target_dir)
    abs_member_path = os.path.abspath(os.path.join(abs_target_dir, name))
    try:
        # commonpath compares whole components, so "/a/bc" is not treated as
        # living inside "/a/b" the way a string-prefix check would.
        return os.path.commonpath([abs_target_dir, abs_member_path]) == abs_target_dir
    except ValueError:
        # Raised e.g. for paths on different drives: definitely not inside.
        return False
61+
62+
63+
def safe_members(tar, target_dir):
    """Yield the members of ``tar`` that pass ``is_safe_member`` for ``target_dir``.

    Members that fail the check are skipped and reported on stdout.
    """
    for candidate in tar.getmembers():
        if not is_safe_member(candidate, target_dir):
            print(f"Discarding unsafe member: {candidate.name}")
            continue
        yield candidate
69+
70+
3571
class DownLoad:
3672
"""Base utility class for downloading."""
3773

@@ -43,7 +79,7 @@ class DownLoad:
4379
@staticmethod
4480
def calculate_md5(file_path: str, chunk_size: int = 1024 * 1024) -> str:
4581
"""Calculate md5 value."""
46-
md5 = hashlib.md5()
82+
md5 = hashlib.md5(usedforsecurity=False)
4783
with open(file_path, "rb") as fp:
4884
for chunk in iter(lambda: fp.read(chunk_size), b""):
4985
md5.update(chunk)
@@ -58,15 +94,15 @@ def extract_tar(from_path: str, to_path: Optional[str] = None, compression: Opti
5894
"""Extract tar format file."""
5995

6096
with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
61-
tar.extractall(to_path)
97+
tar.extract_all(tar, members=safe_members(tar, to_path))
6298

6399
@staticmethod
64100
def extract_zip(from_path: str, to_path: Optional[str] = None, compression: Optional[str] = None) -> None:
65101
"""Extract zip format file."""
66102

67103
compression_mode = zipfile.ZIP_BZIP2 if compression else zipfile.ZIP_STORED
68104
with zipfile.ZipFile(from_path, "r", compression=compression_mode) as zip_file:
69-
zip_file.extractall(to_path)
105+
zipfile.extract_all(zip_file, members=safe_members(zip_file, to_path))
70106

71107
def extract_archive(self, from_path: str, to_path: str = None) -> str:
72108
"""Extract and archive from path to path."""

examples/stable_diffusion_xl/tools/rank_table_generation/hccl_tools.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
import json
1717
import os
1818
import re
19+
import shlex
1920
import socket
21+
import subprocess
2022
import sys
2123
from argparse import ArgumentParser
2224
from typing import Any, Dict
@@ -133,8 +135,11 @@ def main():
133135
device_ips: Dict[Any, Any] = {}
134136
try:
135137
for device_id in device_num_list:
136-
ret = os.popen("hccn_tool -i %d -ip -g" % device_id).readlines()
137-
device_ips[str(device_id)] = ret[0].split(":")[1].replace("\n", "")
138+
cmd = "hccn_tool -i %d -ip -g" % device_id
139+
result = subprocess.run(shlex.split(cmd), capture_output=True, text=True, shell=False)
140+
output = result.stdout.strip()
141+
device_ips[str(device_id)] = output.split(":")[1].replace("\n", "")
142+
138143
except IndexError:
139144
print("Failed to call hccn_tool, try to read /etc/hccn.conf instead")
140145
try:

examples/t2v_turbo/utils/download.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,31 @@ def download_weights(url, dest):
4343
print("downloading took: ", time.time() - start)
4444

4545

46+
def is_safe_member(member, target_dir):
    """Return True if a tar member may be extracted into ``target_dir``.

    A member is rejected when its name is absolute, when any path component
    is ``..``, when it is a hard or symbolic link, or when its resolved
    location falls outside ``target_dir``.

    Args:
        member: Object exposing ``name``, ``islnk()`` and ``issym()``
            (e.g. ``tarfile.TarInfo``).
        target_dir: Extraction directory. ``None`` is treated as the current
            directory, which is where ``extractall`` extracts by default.

    Returns:
        bool: True when the member is considered safe to extract.
    """
    if target_dir is None:
        target_dir = "."

    name = member.name
    if name.startswith("/") or os.path.isabs(name):
        return False

    # Only a ".." *component* climbs out of the target; a plain substring test
    # would also reject harmless names such as "file..txt".
    if ".." in name.replace("\\", "/").split("/"):
        return False

    if member.islnk() or member.issym():
        return False

    abs_target_dir = os.path.abspath(target_dir)
    abs_member_path = os.path.abspath(os.path.join(abs_target_dir, name))
    try:
        # commonpath compares whole components, so "/a/bc" is not treated as
        # living inside "/a/b" the way a string-prefix check would.
        return os.path.commonpath([abs_target_dir, abs_member_path]) == abs_target_dir
    except ValueError:
        # Raised e.g. for paths on different drives: definitely not inside.
        return False
61+
62+
63+
def safe_members(tar, target_dir):
    """Yield the members of ``tar`` that pass ``is_safe_member`` for ``target_dir``.

    Members that fail the check are skipped and reported on stdout.
    """
    for candidate in tar.getmembers():
        if not is_safe_member(candidate, target_dir):
            print(f"Discarding unsafe member: {candidate.name}")
            continue
        yield candidate
69+
70+
4671
class DownLoad:
4772
"""Base utility class for downloading."""
4873

@@ -54,7 +79,7 @@ class DownLoad:
5479
@staticmethod
5580
def calculate_md5(file_path: str, chunk_size: int = 1024 * 1024) -> str:
5681
"""Calculate md5 value."""
57-
md5 = hashlib.md5()
82+
md5 = hashlib.md5(usedforsecurity=False)
5883
with open(file_path, "rb") as fp:
5984
for chunk in iter(lambda: fp.read(chunk_size), b""):
6085
md5.update(chunk)
@@ -69,15 +94,15 @@ def extract_tar(from_path: str, to_path: Optional[str] = None, compression: Opti
6994
"""Extract tar format file."""
7095

7196
with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
72-
tar.extractall(to_path)
97+
tar.extract_all(tar, members=safe_members(tar, to_path))
7398

7499
@staticmethod
75100
def extract_zip(from_path: str, to_path: Optional[str] = None, compression: Optional[str] = None) -> None:
76101
"""Extract zip format file."""
77102

78103
compression_mode = zipfile.ZIP_BZIP2 if compression else zipfile.ZIP_STORED
79104
with zipfile.ZipFile(from_path, "r", compression=compression_mode) as zip_file:
80-
zip_file.extractall(to_path)
105+
zipfile.extract_all(zip_file, members=safe_members(zip_file, to_path))
81106

82107
def extract_archive(self, from_path: str, to_path: str = None) -> str:
83108
"""Extract and archive from path to path."""

examples/venhancer/inference_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import logging
66
import os
7+
import shlex
8+
import shutil
79
import subprocess
810
import tempfile
911

@@ -109,8 +111,10 @@ def save_video(video, save_dir, file_name, fps=16.0):
109111
tmp_path = os.path.join(save_dir, "tmp.mp4")
110112
cmd = f"ffmpeg -y -f image2 -framerate {fps} -i {temp_dir}/%06d.png \
111113
-crf 17 -pix_fmt yuv420p {tmp_path}"
112-
status, output = subprocess.getstatusoutput(cmd)
114+
result = subprocess.run(shlex.split(cmd), capture_output=True, text=True, shell=False)
115+
status = result.returncode
116+
output = result.stdout
113117
if status != 0:
114118
logger.error(f"Save Video Error with {output}")
115-
os.system(f"rm -rf {temp_dir}")
119+
shutil.rmtree(temp_dir)
116120
os.rename(tmp_path, output_path)

0 commit comments

Comments
 (0)