Skip to content

Commit 5c98a7d

Browse files
committed
Restored fake_torch.py to master
1 parent 85b7f84 commit 5c98a7d

File tree

1 file changed

+4
-20
lines changed

1 file changed

+4
-20
lines changed

backends/model_converter/fake_torch.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import numpy as np
33
import math
44
import zipfile
5-
from collections import OrderedDict
65

76
def prod(x):
87
return math.prod(x)
@@ -30,38 +29,23 @@ class Dummy:
3029

3130
class MyPickle(pickle.Unpickler):
3231
def find_class(self, module, name):
33-
#making all of the following available will expose a vulnerability from 2011, unclear if patched
34-
#globals, getattr, dict, apply
35-
3632
#print(module, name)
3733
if name == 'FloatStorage':
3834
return np.float32
39-
if name == 'IntStorage':
40-
return np.int32
4135
if name == 'LongStorage':
4236
return np.int64
4337
if name == 'HalfStorage':
4438
return np.float16
45-
if module == 'numpy.core.multiarray' and name == 'scalar':
46-
return np.core.multiarray.scalar
47-
if module == 'numpy' and name == 'dtype':
48-
return np.dtype
4939
if module == "torch._utils":
5040
if name == "_rebuild_tensor_v2":
5141
return HackTensor
5242
elif name == "_rebuild_parameter":
5343
return HackParameter
54-
if module == "collections" and name == "OrderedDict":
55-
return OrderedDict
56-
if module == '_codecs' and name == 'encode':
57-
from _codecs import encode
58-
return encode
59-
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
60-
return Dummy
61-
if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
62-
return Dummy
6344
else:
64-
raise pickle.UnpicklingError("'%s.%s' is forbidden" % (module, name))
45+
try:
46+
return pickle.Unpickler.find_class(self, module, name)
47+
except Exception:
48+
return Dummy
6549

6650
def persistent_load(self, pid):
6751
return pid

0 commit comments

Comments
 (0)