Skip to content

Commit 31728e2

Browse files
committed
Cleaned up comments; added more classes to find_class so that SD v1.4 now imports.
1 parent ed81f11 commit 31728e2

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

backends/model_converter/fake_torch.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class Dummy:
3030

3131
class MyPickle(pickle.Unpickler):
3232
def find_class(self, module, name):
33-
#making the following available will expose a vulnerability from 2011:
33+
#making all of the following available will expose a vulnerability from 2011, unclear if patched
3434
#globals, getattr, dict, apply
3535

3636
#print(module, name)
@@ -42,16 +42,26 @@ def find_class(self, module, name):
4242
return np.int64
4343
if name == 'HalfStorage':
4444
return np.float16
45+
if module == 'numpy.core.multiarray' and name == 'scalar':
46+
return np.core.multiarray.scalar
47+
if module == 'numpy' and name == 'dtype':
48+
return np.dtype
4549
if module == "torch._utils":
4650
if name == "_rebuild_tensor_v2":
4751
return HackTensor
4852
elif name == "_rebuild_parameter":
4953
return HackParameter
5054
if module == "collections" and name == "OrderedDict":
5155
return OrderedDict
56+
if module == '_codecs' and name == 'encode':
57+
from _codecs import encode
58+
return encode
59+
if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
60+
return Dummy
61+
if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
62+
return Dummy
5263
else:
53-
#return Dummy
54-
raise pickle.UnpicklingError("'%s.%s' is forbidden" % (module, name))
64+
raise pickle.UnpicklingError("'%s.%s' is forbidden" % (module, name))
5565

5666
def persistent_load(self, pid):
5767
return pid

0 commit comments

Comments
 (0)