Skip to content

Commit 9125d5a

Browse files
committed
find_class now only allows a very narrow range of items.
added IntStorage to list.
1 parent e6625cd commit 9125d5a

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

backends/model_converter/fake_torch.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
import math
44
import zipfile
5+
from collections import OrderedDict
56

67
def prod(x):
78
return math.prod(x)
@@ -29,9 +30,11 @@ class Dummy:
2930

3031
class MyPickle(pickle.Unpickler):
3132
def find_class(self, module, name):
32-
#print(module, name)
33+
print(module, name)
3334
if name == 'FloatStorage':
3435
return np.float32
36+
if name == 'IntStorage':
37+
return np.int32
3538
if name == 'LongStorage':
3639
return np.int64
3740
if name == 'HalfStorage':
@@ -41,11 +44,11 @@ def find_class(self, module, name):
4144
return HackTensor
4245
elif name == "_rebuild_parameter":
4346
return HackParameter
47+
if module == "collections" and name == "OrderedDict":
48+
return OrderedDict
4449
else:
45-
try:
46-
return pickle.Unpickler.find_class(self, module, name)
47-
except Exception:
48-
return Dummy
50+
#return Dummy
51+
raise pickle.UnpicklingError("'%s.%s' is forbidden" % (module, name))
4952

5053
def persistent_load(self, pid):
5154
return pid

0 commit comments

Comments
 (0)