@@ -85,6 +85,84 @@ def version1_lte_version2(version1, version2):
85
85
return parse_version (version1 ) < parse_version (version2 ) or parse_version (version1 ) == parse_version (version2 )
86
86
87
87
88
+ class _SafeUnpickler (pickle .Unpickler ):
89
+ def find_class (self , module , name ):
90
+ """Find a class in a module.
91
+
92
+ Args:
93
+ module (str): The module name.
94
+ name (str): The class name.
95
+
96
+ Returns:
97
+ The class if it is safe to unpickle, otherwise raises UnpicklingError.
98
+ """
99
+ # Allowed built-in types
100
+ allowed_builtins = {
101
+ "dict" ,
102
+ "list" ,
103
+ "tuple" ,
104
+ "set" ,
105
+ "frozenset" ,
106
+ "str" ,
107
+ "bytes" ,
108
+ "int" ,
109
+ "float" ,
110
+ "complex" ,
111
+ "bool" ,
112
+ "NoneType" ,
113
+ "slice" ,
114
+ "type" ,
115
+ "object" ,
116
+ "bytearray" ,
117
+ "ellipsis" ,
118
+ "filter" ,
119
+ "map" ,
120
+ "range" ,
121
+ "reversed" ,
122
+ "zip" ,
123
+ }
124
+ if module == "builtins" and name in allowed_builtins :
125
+ return getattr (builtins , name )
126
+
127
+ # Allow collections.OrderedDict
128
+ if module == "collections" and name == "OrderedDict" :
129
+ return OrderedDict
130
+
131
+ # Allow specific neural_compressor classes
132
+ if module .startswith ("neural_compressor" ):
133
+ # Validate class name exists in module
134
+ mod_path = module .replace (".__" , " " ) # Handle submodules
135
+ for part in mod_path .split ():
136
+ try :
137
+ __import__ (part )
138
+ except ImportError :
139
+ continue
140
+ mod = sys .modules .get (module )
141
+ if mod and hasattr (mod , name ):
142
+ return getattr (mod , name )
143
+
144
+ # Allow all numpy classes
145
+ allowed_classes = ["numpy" , "torch" , "tensorflow" , "onnx" , "onnxruntime" ]
146
+ for allowed_class in allowed_classes :
147
+ if module .startswith (allowed_class ):
148
+ try :
149
+ mod = importlib .import_module (module )
150
+ return getattr (mod , name )
151
+ except (ImportError , AttributeError ):
152
+ continue
153
+
154
+ # Block all other classes
155
+ raise pickle .UnpicklingError (f"Unsafe class: { module } .{ name } " )
156
+
157
+
158
def _safe_pickle_load(fp):
    """Load a pickle file safely via the restricted ``_SafeUnpickler``.

    Args:
        fp: A readable binary file-like object positioned at the pickle data.

    Returns:
        The deserialized object.

    Raises:
        pickle.UnpicklingError: If the payload references a disallowed class
            or cannot be decoded. The original exception is chained as
            ``__cause__`` so the root cause stays visible in tracebacks.
    """
    try:
        return _SafeUnpickler(fp).load()
    except Exception as e:
        # "from e" preserves the exception chain that the bare raise dropped.
        raise pickle.UnpicklingError(f"Failed to unpickle file: {e}") from e
164
+
165
+
88
166
class LazyImport (object ):
89
167
"""Lazy import python module till use."""
90
168
@@ -398,66 +476,8 @@ def get_tuning_history(history_path):
398
476
Args:
399
477
history_path: The tuning history path, which need users to assign
400
478
"""
401
-
402
- class SafeUnpickler (pickle .Unpickler ):
403
- def find_class (self , module , name ):
404
- # Allowed built-in types
405
- allowed_builtins = {
406
- "dict" ,
407
- "list" ,
408
- "tuple" ,
409
- "set" ,
410
- "frozenset" ,
411
- "str" ,
412
- "bytes" ,
413
- "int" ,
414
- "float" ,
415
- "complex" ,
416
- "bool" ,
417
- "NoneType" ,
418
- "slice" ,
419
- "type" ,
420
- "object" ,
421
- "bytearray" ,
422
- "ellipsis" ,
423
- "filter" ,
424
- "map" ,
425
- "range" ,
426
- "reversed" ,
427
- "zip" ,
428
- }
429
- if module == "builtins" and name in allowed_builtins :
430
- return getattr (builtins , name )
431
-
432
- # Allow collections.OrderedDict
433
- if module == "collections" and name == "OrderedDict" :
434
- return OrderedDict
435
-
436
- # Allow specific neural_compressor classes
437
- if module .startswith ("neural_compressor" ):
438
- # Validate class name exists in module
439
- mod_path = module .replace (".__" , " " ) # Handle submodules
440
- for part in mod_path .split ():
441
- try :
442
- __import__ (part )
443
- except ImportError :
444
- continue
445
- mod = sys .modules .get (module )
446
- if mod and hasattr (mod , name ):
447
- return getattr (mod , name )
448
-
449
- # Allow all numpy classes
450
- if module .startswith ("numpy" ):
451
-
452
- mod = sys .modules .get (module )
453
- if mod and hasattr (mod , name ):
454
- return getattr (mod , name )
455
-
456
- # Block all other classes
457
- raise pickle .UnpicklingError (f"Unsafe class: { module } .{ name } " )
458
-
459
479
with open (history_path , "rb" ) as f :
460
- strategy_object = SafeUnpickler ( f ). load ( )
480
+ strategy_object = _safe_pickle_load ( f )
461
481
tuning_history = strategy_object .tuning_history
462
482
return tuning_history
463
483
@@ -626,7 +646,7 @@ def load_data_from_pkl(path, filename):
626
646
try :
627
647
file_path = os .path .join (path , filename )
628
648
with open (file_path , "rb" ) as fp :
629
- data = pickle . load (fp )
649
+ data = _safe_pickle_load (fp )
630
650
return data
631
651
except FileExistsError :
632
652
logging .getLogger ("neural_compressor" ).info ("Can not open %s." % path )
@@ -933,7 +953,7 @@ def get_tensors_info(workload_location, model_type: str = "optimized") -> dict:
933
953
if not os .path .exists (tensors_path ):
934
954
raise Exception ("Could not find tensor data for specified optimization." )
935
955
with open (tensors_path , "rb" ) as tensors_pickle :
936
- dump_tensor_result = pickle . load (tensors_pickle )
956
+ dump_tensor_result = _safe_pickle_load (tensors_pickle )
937
957
return dump_tensor_result
938
958
939
959
@@ -1159,7 +1179,7 @@ def get_op_list(minmax_file_path, input_model_tensors, optimized_model_tensors)
1159
1179
list of OpEntry elements
1160
1180
"""
1161
1181
with open (minmax_file_path , "rb" ) as min_max_file :
1162
- min_max_data : dict = pickle . load (min_max_file )
1182
+ min_max_data : dict = _safe_pickle_load (min_max_file )
1163
1183
1164
1184
op_list : List [OpEntry ] = []
1165
1185
0 commit comments