
Commit a3da6e8

Authored by yiliu30 and Copilot
Replace all pickle load with safe load (#2252)
* replace all pickle load
  Signed-off-by: yiliu30 <yi4.liu@intel.com>
* Update neural_compressor/utils/utility.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* add docstring
  Signed-off-by: yiliu30 <yi4.liu@intel.com>
* Update utility.py
---------
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 1cc864e commit a3da6e8
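
For context on why this change matters, here is a minimal standalone sketch (not part of the patch) of how unrestricted pickle.load can execute an arbitrary callable during deserialization; this is exactly the behavior the allowlist-based unpickler added below refuses.

# Minimal sketch (not part of this patch): why unrestricted pickle.load is risky.
# Any object's __reduce__ can make the loader invoke an arbitrary callable.
import io
import pickle


class Payload:
    def __reduce__(self):
        # On unpickling, pickle will call print(...); a real attacker would
        # reference something far more harmful than print.
        return (print, ("side effect executed during pickle.load",))


blob = pickle.dumps(Payload())
pickle.load(io.BytesIO(blob))  # the callable runs here, before any type check

# The _SafeUnpickler introduced in this commit rejects such a payload:
# builtins.print is not in its allowlist, so find_class raises UnpicklingError.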

File tree

1 file changed: +82 −62 lines changed


neural_compressor/utils/utility.py

Lines changed: 82 additions & 62 deletions
@@ -85,6 +85,84 @@ def version1_lte_version2(version1, version2):
     return parse_version(version1) < parse_version(version2) or parse_version(version1) == parse_version(version2)


+class _SafeUnpickler(pickle.Unpickler):
+    def find_class(self, module, name):
+        """Find a class in a module.
+
+        Args:
+            module (str): The module name.
+            name (str): The class name.
+
+        Returns:
+            The class if it is safe to unpickle, otherwise raises UnpicklingError.
+        """
+        # Allowed built-in types
+        allowed_builtins = {
+            "dict",
+            "list",
+            "tuple",
+            "set",
+            "frozenset",
+            "str",
+            "bytes",
+            "int",
+            "float",
+            "complex",
+            "bool",
+            "NoneType",
+            "slice",
+            "type",
+            "object",
+            "bytearray",
+            "ellipsis",
+            "filter",
+            "map",
+            "range",
+            "reversed",
+            "zip",
+        }
+        if module == "builtins" and name in allowed_builtins:
+            return getattr(builtins, name)
+
+        # Allow collections.OrderedDict
+        if module == "collections" and name == "OrderedDict":
+            return OrderedDict
+
+        # Allow specific neural_compressor classes
+        if module.startswith("neural_compressor"):
+            # Validate class name exists in module
+            mod_path = module.replace(".__", " ")  # Handle submodules
+            for part in mod_path.split():
+                try:
+                    __import__(part)
+                except ImportError:
+                    continue
+            mod = sys.modules.get(module)
+            if mod and hasattr(mod, name):
+                return getattr(mod, name)
+
+        # Allow classes from trusted numerical/ML libraries
+        allowed_classes = ["numpy", "torch", "tensorflow", "onnx", "onnxruntime"]
+        for allowed_class in allowed_classes:
+            if module.startswith(allowed_class):
+                try:
+                    mod = importlib.import_module(module)
+                    return getattr(mod, name)
+                except (ImportError, AttributeError):
+                    continue
+
+        # Block all other classes
+        raise pickle.UnpicklingError(f"Unsafe class: {module}.{name}")
+
+
+def _safe_pickle_load(fp):
+    """Load a pickle file safely."""
+    try:
+        return _SafeUnpickler(fp).load()
+    except Exception as e:
+        raise pickle.UnpicklingError(f"Failed to unpickle file: {e}")
+
+
 class LazyImport(object):
     """Lazy import python module till use."""

@@ -398,66 +476,8 @@ def get_tuning_history(history_path):
     Args:
         history_path: The tuning history path, which need users to assign
     """
-
-    class SafeUnpickler(pickle.Unpickler):
-        def find_class(self, module, name):
-            # Allowed built-in types
-            allowed_builtins = {
-                "dict",
-                "list",
-                "tuple",
-                "set",
-                "frozenset",
-                "str",
-                "bytes",
-                "int",
-                "float",
-                "complex",
-                "bool",
-                "NoneType",
-                "slice",
-                "type",
-                "object",
-                "bytearray",
-                "ellipsis",
-                "filter",
-                "map",
-                "range",
-                "reversed",
-                "zip",
-            }
-            if module == "builtins" and name in allowed_builtins:
-                return getattr(builtins, name)
-
-            # Allow collections.OrderedDict
-            if module == "collections" and name == "OrderedDict":
-                return OrderedDict
-
-            # Allow specific neural_compressor classes
-            if module.startswith("neural_compressor"):
-                # Validate class name exists in module
-                mod_path = module.replace(".__", " ")  # Handle submodules
-                for part in mod_path.split():
-                    try:
-                        __import__(part)
-                    except ImportError:
-                        continue
-                mod = sys.modules.get(module)
-                if mod and hasattr(mod, name):
-                    return getattr(mod, name)
-
-            # Allow all numpy classes
-            if module.startswith("numpy"):
-
-                mod = sys.modules.get(module)
-                if mod and hasattr(mod, name):
-                    return getattr(mod, name)
-
-            # Block all other classes
-            raise pickle.UnpicklingError(f"Unsafe class: {module}.{name}")
-
     with open(history_path, "rb") as f:
-        strategy_object = SafeUnpickler(f).load()
+        strategy_object = _safe_pickle_load(f)
     tuning_history = strategy_object.tuning_history
     return tuning_history


@@ -626,7 +646,7 @@ def load_data_from_pkl(path, filename):
     try:
         file_path = os.path.join(path, filename)
         with open(file_path, "rb") as fp:
-            data = pickle.load(fp)
+            data = _safe_pickle_load(fp)
             return data
     except FileExistsError:
         logging.getLogger("neural_compressor").info("Can not open %s." % path)
@@ -933,7 +953,7 @@ def get_tensors_info(workload_location, model_type: str = "optimized") -> dict:
     if not os.path.exists(tensors_path):
         raise Exception("Could not find tensor data for specified optimization.")
     with open(tensors_path, "rb") as tensors_pickle:
-        dump_tensor_result = pickle.load(tensors_pickle)
+        dump_tensor_result = _safe_pickle_load(tensors_pickle)
     return dump_tensor_result

@@ -1159,7 +1179,7 @@ def get_op_list(minmax_file_path, input_model_tensors, optimized_model_tensors)
        list of OpEntry elements
     """
     with open(minmax_file_path, "rb") as min_max_file:
-        min_max_data: dict = pickle.load(min_max_file)
+        min_max_data: dict = _safe_pickle_load(min_max_file)

     op_list: List[OpEntry] = []

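A minimal usage sketch of the new helper, assuming _safe_pickle_load is importable from the patched module neural_compressor.utils.utility; the pickled blobs below are illustrative only.

# Minimal usage sketch (not part of this patch).
import io
import pickle
from collections import OrderedDict

from neural_compressor.utils.utility import _safe_pickle_load

# Plain containers and collections.OrderedDict are on the allowlist, so they load.
safe_blob = pickle.dumps(OrderedDict(op_list=["conv1", "conv2"]))
print(_safe_pickle_load(io.BytesIO(safe_blob)))

# Globals outside the allowlist are rejected: builtins.print is pickled by
# reference, and find_class refuses to resolve it.
unsafe_blob = pickle.dumps(print)
try:
    _safe_pickle_load(io.BytesIO(unsafe_blob))
except pickle.UnpicklingError as err:
    print(err)  # Failed to unpickle file: Unsafe class: builtins.print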