|
2 | 2 | import numpy as np
|
3 | 3 | import math
|
4 | 4 | import zipfile
|
5 |
| -from collections import OrderedDict |
6 | 5 |
|
7 | 6 | def prod(x):
|
8 | 7 | return math.prod(x)
|
@@ -30,38 +29,23 @@ class Dummy:
|
30 | 29 |
|
31 | 30 | class MyPickle(pickle.Unpickler):
|
32 | 31 | def find_class(self, module, name):
|
33 |
| - #making all of the following available will expose a vulnerability from 2011, unclear if patched |
34 |
| - #globals, getattr, dict, apply |
35 |
| - |
36 | 32 | #print(module, name)
|
37 | 33 | if name == 'FloatStorage':
|
38 | 34 | return np.float32
|
39 |
| - if name == 'IntStorage': |
40 |
| - return np.int32 |
41 | 35 | if name == 'LongStorage':
|
42 | 36 | return np.int64
|
43 | 37 | if name == 'HalfStorage':
|
44 | 38 | return np.float16
|
45 |
| - if module == 'numpy.core.multiarray' and name == 'scalar': |
46 |
| - return np.core.multiarray.scalar |
47 |
| - if module == 'numpy' and name == 'dtype': |
48 |
| - return np.dtype |
49 | 39 | if module == "torch._utils":
|
50 | 40 | if name == "_rebuild_tensor_v2":
|
51 | 41 | return HackTensor
|
52 | 42 | elif name == "_rebuild_parameter":
|
53 | 43 | return HackParameter
|
54 |
| - if module == "collections" and name == "OrderedDict": |
55 |
| - return OrderedDict |
56 |
| - if module == '_codecs' and name == 'encode': |
57 |
| - from _codecs import encode |
58 |
| - return encode |
59 |
| - if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint': |
60 |
| - return Dummy |
61 |
| - if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint': |
62 |
| - return Dummy |
63 | 44 | else:
|
64 |
| - raise pickle.UnpicklingError("'%s.%s' is forbidden" % (module, name)) |
| 45 | + try: |
| 46 | + return pickle.Unpickler.find_class(self, module, name) |
| 47 | + except Exception: |
| 48 | + return Dummy |
65 | 49 |
|
66 | 50 | def persistent_load(self, pid):
|
67 | 51 | return pid
|
|
0 commit comments