@@ -30,7 +30,7 @@ class Dummy:
30
30
31
31
class MyPickle (pickle .Unpickler ):
32
32
def find_class (self , module , name ):
33
- #making the following available will expose a vulnerability from 2011:
33
+ #making all of the following available will expose a vulnerability from 2011, unclear if patched
34
34
#globals, getattr, dict, apply
35
35
36
36
#print(module, name)
@@ -42,16 +42,26 @@ def find_class(self, module, name):
42
42
return np .int64
43
43
if name == 'HalfStorage' :
44
44
return np .float16
45
+ if module == 'numpy.core.multiarray' and name == 'scalar' :
46
+ return np .core .multiarray .scalar
47
+ if module == 'numpy' and name == 'dtype' :
48
+ return np .dtype
45
49
if module == "torch._utils" :
46
50
if name == "_rebuild_tensor_v2" :
47
51
return HackTensor
48
52
elif name == "_rebuild_parameter" :
49
53
return HackParameter
50
54
if module == "collections" and name == "OrderedDict" :
51
55
return OrderedDict
56
+ if module == '_codecs' and name == 'encode' :
57
+ from _codecs import encode
58
+ return encode
59
+ if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint' :
60
+ return Dummy
61
+ if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint' :
62
+ return Dummy
52
63
else :
53
- #return Dummy
54
- raise pickle .UnpicklingError ("'%s.%s' is forbidden" % (module , name ))
64
+ raise pickle .UnpicklingError ("'%s.%s' is forbidden" % (module , name ))
55
65
56
66
def persistent_load (self , pid ):
57
67
return pid
0 commit comments