
Commit 2f6489a

Added options for collecting metadata that the conversion script doesn't care about, as well as fixing an error when importing a PaperCut model.

1 parent c9677e8 commit 2f6489a
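For orientation, a minimal sketch of how the entry point touched by this commit is typically driven (a hypothetical caller; the checkpoint name and import path assume the script runs from backends/model_converter):

    from no_pickle_fake_torch import extract_weights_from_checkpoint

    # Hypothetical driver, assuming a torch.save-style .ckpt zip on disk.
    with open("model.ckpt", "rb") as fb0:
        weights = extract_weights_from_checkpoint(fb0)
    print(f"loaded {len(weights['state_dict'])} tensors")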

File tree: 1 file changed (+66 −34 lines)

backends/model_converter/no_pickle_fake_torch.py
66 additions & 34 deletions
@@ -11,7 +11,7 @@
 else:
     from astunparse import unparse
 
-NO_PICKLE_DEBUG = True
+NO_PICKLE_DEBUG = False
 
 def extract_weights_from_checkpoint(fb0):
     torch_weights = {}
@@ -22,32 +22,40 @@ def extract_weights_from_checkpoint(fb0):
         raise ValueError("Looks like the checkpoints file is in the wrong format")
     folder_name = folder_name[0].replace("/data.pkl" , "").replace("\\data.pkl" , "")
     with myzip.open(folder_name+'/data.pkl') as myfile:
-        load_instructions, special_instructions = examine_pickle(myfile)
+        load_instructions = examine_pickle(myfile)
     for sd_key,load_instruction in load_instructions.items():
         with myzip.open(folder_name + f'/data/{load_instruction.obj_key}') as myfile:
             if (load_instruction.load_from_file_buffer(myfile)):
                 torch_weights['state_dict'][sd_key] = load_instruction.get_data()
-    if len(special_instructions) > 0:
-        torch_weights['state_dict']['_metadata'] = {}
-        for sd_key,special in special_instructions.items():
-            torch_weights['state_dict']['_metadata'][sd_key] = special
+    #if len(special_instructions) > 0:
+    #    torch_weights['state_dict']['_metadata'] = {}
+    #    for sd_key,special in special_instructions.items():
+    #        torch_weights['state_dict']['_metadata'][sd_key] = special
     return torch_weights
 
-def examine_pickle(fb0):
+def examine_pickle(fb0, return_special=False):
+    ## return_special:
+    ## A rabbit hole I chased trying to debug a model that wouldn't import because it had 1300 useless metadata statements.
+    ## If for some reason it's needed in the future, turn it on. It is passed into the class AssignInstructions, and
+    ## if turned on, collect_special will be True.
+    ##
 
+    # turn the pickle file into text we can parse
     decompiled = unparse(Pickled.load(fb0).ast).splitlines()
 
-    ## LINES WE CARE ABOUT:
-    ## 1: this defines a data file and what kind of data is in it
-    ## _var1 = _rebuild_tensor_v2(UNPICKLER.persistent_load(('storage', HalfStorage, '0', 'cpu', 11520)), 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
-    ##
-    ## 2: this massive line assigns the previous data to dictionary entries
-    ## _var2262 = {'model.diffusion_model.input_blocks.0.0.weight': _var1, [..... continue for ever]}
-    ##
-    ## 3: this massive line also assigns values to keys, but does so differently
-    ## _var2262.update({ 'cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias': _var2001, [ .... and on and on ]})
-    ##
-    ## that's it
+    ## Parsing the decompiled pickle:
+    ## LINES WE CARE ABOUT:
+    ## 1: this defines a data file and what kind of data is in it
+    ## _var1 = _rebuild_tensor_v2(UNPICKLER.persistent_load(('storage', HalfStorage, '0', 'cpu', 11520)), 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
+    ##
+    ## 2: this massive line assigns the previous data to dictionary entries
+    ## _var2262 = {'model.diffusion_model.input_blocks.0.0.weight': _var1, [..... continues forever]}
+    ##
+    ## 3: this massive line also assigns values to keys, but does so differently
+    ## _var2262.update({ 'cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias': _var2001, [ .... and on and on ]})
+    ##
+    ## that's it
+
     # make some REs to match the above.
     re_rebuild = re.compile('^_var\d+ = _rebuild_tensor_v2\(UNPICKLER\.persistent_load\(\(.*\)$')
     re_assign = re.compile('^_var\d+ = \{.*\}$')
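As a quick check that the two REs shown so far do match the sample lines quoted in the comments, a standalone sketch (re_update and re_ordered_dict appear later in the file and are omitted here):

    import re

    re_rebuild = re.compile(r'^_var\d+ = _rebuild_tensor_v2\(UNPICKLER\.persistent_load\(\(.*\)$')
    re_assign = re.compile(r'^_var\d+ = \{.*\}$')

    rebuild_line = ("_var1 = _rebuild_tensor_v2(UNPICKLER.persistent_load(("
                    "'storage', HalfStorage, '0', 'cpu', 11520)), 0, "
                    "(320, 4, 3, 3), (36, 9, 3, 1), False, _var0)")
    assign_line = "_var2262 = {'model.diffusion_model.input_blocks.0.0.weight': _var1}"

    assert re_rebuild.match(rebuild_line)
    assert re_assign.match(assign_line)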
@@ -62,15 +70,15 @@ def examine_pickle(fb0):
         line = line.strip()
         if re_rebuild.match(line):
             variable_name, load_instruction = line.split(' = ', 1)
-            load_instructions[variable_name] = LoadInstruction(line)
+            load_instructions[variable_name] = LoadInstruction(line, variable_name)
         elif re_assign.match(line):
             assign_instructions.parse_assign_line(line)
         elif re_update.match(line):
            assign_instructions.parse_update_line(line)
         elif re_ordered_dict.match(line):
             #do nothing
             continue
-        else:
+        elif NO_PICKLE_DEBUG:
             print(f'unmatched line: {line}')
 
 
@@ -79,14 +87,16 @@ def examine_pickle(fb0):
 
     assign_instructions.integrate(load_instructions)
 
-    return assign_instructions.integrated_instructions, assign_instructions.special_instructions
-    #return assign_instructions.integrated_instructions, {}
+    if return_special:
+        return assign_instructions.integrated_instructions, assign_instructions.special_instructions
+    return assign_instructions.integrated_instructions
 
 class AssignInstructions:
-    def __init__(self):
+    def __init__(self, collect_special=False):
         self.instructions = {}
         self.special_instructions = {}
         self.integrated_instructions = {}
+        self.collect_special = collect_special;
 
     def parse_assign_line(self, line):
         # input looks like this:
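Caller-side, the new flag behaves like this (a sketch; myfile is the data.pkl handle from the first hunk above):

    # Default: metadata is skipped entirely.
    load_instructions = examine_pickle(myfile)

    # Opt-in: also get the {'version': 1}-style metadata assignments back.
    load_instructions, special_instructions = examine_pickle(myfile, return_special=True)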
@@ -115,7 +125,7 @@ def _add_assignment(self, assignment, re_var):
         if re_var.match(fickling_var):
             self.instructions[sd_key] = fickling_var
             return True
-        else:
+        elif self.collect_special:
             # now convert the string "{'version': 1}" into a dictionary {'version': 1}
             entries = fickling_var.split(',')
             special_dict = {}
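The conversion described in that comment continues beyond the lines shown; as a standalone sketch of the idea (assuming the fickling output always has the flat "{'key': int}" shape):

    fickling_var = "{'version': 1}"
    special_dict = {}
    for entry in fickling_var.strip('{}').split(','):
        key, value = entry.split(':', 1)
        special_dict[key.strip().strip("'")] = int(value)
    assert special_dict == {'version': 1}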
@@ -133,14 +143,9 @@ def integrate(self, load_instructions):
         for sd_key, fickling_var in self.instructions.items():
             if fickling_var in load_instructions:
                 self.integrated_instructions[sd_key] = load_instructions[fickling_var]
-                if sd_key in self.special_instructions:
-                    if NO_PICKLE_DEBUG:
-                        print(f"Key found in both load and special instructions: {sd_key}")
             else:
-                unfound_keys[sd_key] = True;
-        #for sd_key, special in self.special_instructions.items():
-        #    if sd_key in unfound_keys:
-        #        #todo
+                if NO_PICKLE_DEBUG:
+                    print(f"no load instruction found for {sd_key}")
 
         if NO_PICKLE_DEBUG:
             print(f"Have {len(self.integrated_instructions)} integrated load/assignment instructions")
@@ -164,14 +169,16 @@ def parse_update_line(self, line):
         print(f"Added/merged {update_count} updates. Total of {len(self.instructions)} assignment instructions")
 
 class LoadInstruction:
-    def __init__(self, instruction_string):
+    def __init__(self, instruction_string, variable_name, extra_debugging = False):
         self.ident = False
         self.storage_type = False
         self.obj_key = False
         self.location = False #unused
         self.obj_size = False
         self.stride = False #unused
-        self.data = False;
+        self.data = False
+        self.variable_name = variable_name
+        self.extra_debugging = extra_debugging
         self.parse_instruction(instruction_string)
 
     def parse_instruction(self, instruction_string):
@@ -185,12 +192,24 @@ def parse_instruction(self, instruction_string):
         #
         # the following comments will show the output of each string manipulation as if it started with the above.
 
+        if self.extra_debugging:
+            print(f"input: '{instruction_string}'")
+
         garbage, storage_etc = instruction_string.split('((', 1)
         # storage_etc = 'storage', HalfStorage, '0', 'cpu', 11520)), 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
+
+        if self.extra_debugging:
+            print("storage_etc, reference: ''storage', HalfStorage, '0', 'cpu', 11520)), 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)'")
+            print(f"storage_etc, actual: '{storage_etc}'\n")
 
         storage, etc = storage_etc.split('))', 1)
         # storage = 'storage', HalfStorage, '0', 'cpu', 11520
-        # etc = 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
+        # etc = , 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
+        if self.extra_debugging:
+            print("storage, reference: ''storage', HalfStorage, '0', 'cpu', 11520'")
+            print(f"storage, actual: '{storage}'\n")
+            print("etc, reference: ', 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)'")
+            print(f"etc, actual: '{etc}'\n")
 
         ## call below maps to: ('storage', HalfStorage, '0', 'cpu', 11520)
         self.ident, self.storage_type, self.obj_key, self.location, self.obj_size = storage.split(', ', 4)
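The reference strings in those debug prints can be reproduced by hand; tracing the same two splits on the sample instruction from the comments (note the leading comma in etc, which the corrected comment above now reflects):

    s = ("_var1 = _rebuild_tensor_v2(UNPICKLER.persistent_load(("
         "'storage', HalfStorage, '0', 'cpu', 11520)), 0, "
         "(320, 4, 3, 3), (36, 9, 3, 1), False, _var0)")
    garbage, storage_etc = s.split('((', 1)
    storage, etc = storage_etc.split('))', 1)
    assert storage == "'storage', HalfStorage, '0', 'cpu', 11520"
    assert etc == ", 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)"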
@@ -201,10 +220,16 @@ def parse_instruction(self, instruction_string):
         self.obj_size = int(self.obj_size)
         self.storage_type = self._torch_to_numpy(self.storage_type)
 
+        if self.extra_debugging:
+            print(f"{self.ident}, {self.obj_key}, {self.location}, {self.obj_size}, {self.storage_type}")
+
         assert (self.ident == 'storage')
 
         garbage, etc = etc.split(', (', 1)
         # etc = 320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
+        if self.extra_debugging:
+            print("etc, reference: '320, 4, 3, 3), (36, 9, 3, 1), False, _var0)'")
+            print(f"etc, actual: '{etc}'\n")
 
         size, stride, garbage = etc.split('), ', 2)
         # size = 320, 4, 3, 3
@@ -223,12 +248,19 @@ def parse_instruction(self, instruction_string):
         else:
             self.stride = tuple(map(int, stride.split(', ')))
 
+
+        if self.extra_debugging:
+            print(f"size: {self.size_tuple}, stride: {self.stride}")
+
         prod_size = prod(self.size_tuple)
         assert prod(self.size_tuple) == self.obj_size # does the size in the storage call match the size tuple
 
         # zero out the data
         self.data = np.zeros(self.size_tuple, dtype=self.storage_type)
 
+    def sayHi(self):
+        print(f"Hi, I'm an instance of LoadInstruction that will be used to load datafile {self.obj_key}")
+
     @staticmethod
     def _torch_to_numpy(storage_type):
         if storage_type == 'FloatStorage':
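The diff view cuts off inside _torch_to_numpy; for orientation only, a plausible shape for that mapping (the real branches in the file may differ):

    import numpy as np

    def _torch_to_numpy(storage_type):
        # Map torch storage class names found in the pickle to numpy dtypes.
        if storage_type == 'FloatStorage':
            return np.float32
        if storage_type == 'HalfStorage':
            return np.float16
        raise ValueError(f"unsupported storage type: {storage_type}")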

0 commit comments