 else:
     from astunparse import unparse
 
-NO_PICKLE_DEBUG = True
+NO_PICKLE_DEBUG = False
 
 def extract_weights_from_checkpoint(fb0):
     torch_weights = {}
@@ -22,32 +22,40 @@ def extract_weights_from_checkpoint(fb0):
             raise ValueError("Looks like the checkpoints file is in the wrong format")
         folder_name = folder_name[0].replace("/data.pkl", "").replace("\\data.pkl", "")
         with myzip.open(folder_name + '/data.pkl') as myfile:
-            load_instructions, special_instructions = examine_pickle(myfile)
+            load_instructions = examine_pickle(myfile)
         for sd_key, load_instruction in load_instructions.items():
             with myzip.open(folder_name + f'/data/{load_instruction.obj_key}') as myfile:
                 if (load_instruction.load_from_file_buffer(myfile)):
                     torch_weights['state_dict'][sd_key] = load_instruction.get_data()
-        if len(special_instructions) > 0:
-            torch_weights['state_dict']['_metadata'] = {}
-            for sd_key, special in special_instructions.items():
-                torch_weights['state_dict']['_metadata'][sd_key] = special
+        # if len(special_instructions) > 0:
+        #     torch_weights['state_dict']['_metadata'] = {}
+        #     for sd_key, special in special_instructions.items():
+        #         torch_weights['state_dict']['_metadata'][sd_key] = special
     return torch_weights
 
-def examine_pickle(fb0):
+def examine_pickle(fb0, return_special=False):
+    ## return_special:
+    ## A rabbit hole I chased trying to debug a model that wouldn't import and had 1300 useless metadata statements.
+    ## If for some reason it's needed in the future, turn it on. It is passed into the class AssignInstructions, and
+    ## if turned on, collect_special will be True.
+    ##
 
+    # turn the pickle file into text we can parse
     decompiled = unparse(Pickled.load(fb0).ast).splitlines()
 
-    ## LINES WE CARE ABOUT:
-    ## 1: this defines a data file and what kind of data is in it
-    ## _var1 = _rebuild_tensor_v2(UNPICKLER.persistent_load(('storage', HalfStorage, '0', 'cpu', 11520)), 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
-    ##
-    ## 2: this massive line assigns the previous data to dictionary entries
-    ## _var2262 = {'model.diffusion_model.input_blocks.0.0.weight': _var1, [..... continue for ever]}
-    ##
-    ## 3: this massive line also assigns values to keys, but does so differently
-    ## _var2262.update({ 'cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias': _var2001, [ .... and on and on ]})
-    ##
-    ## that's it
+    ## Parsing the decompiled pickle:
+    ## LINES WE CARE ABOUT:
+    ## 1: this defines a data file and what kind of data is in it
+    ## _var1 = _rebuild_tensor_v2(UNPICKLER.persistent_load(('storage', HalfStorage, '0', 'cpu', 11520)), 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
+    ##
+    ## 2: this massive line assigns the previous data to dictionary entries
+    ## _var2262 = {'model.diffusion_model.input_blocks.0.0.weight': _var1, [..... continue for ever]}
+    ##
+    ## 3: this massive line also assigns values to keys, but does so differently
+    ## _var2262.update({ 'cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias': _var2001, [ .... and on and on ]})
+    ##
+    ## that's it
+
     # make some REs to match the above.
     re_rebuild = re.compile('^_var\d+ = _rebuild_tensor_v2\(UNPICKLER\.persistent_load\(\(.*\)$')
     re_assign = re.compile('^_var\d+ = \{.*\}$')
@@ -62,15 +70,15 @@ def examine_pickle(fb0):
         line = line.strip()
         if re_rebuild.match(line):
             variable_name, load_instruction = line.split(' = ', 1)
-            load_instructions[variable_name] = LoadInstruction(line)
+            load_instructions[variable_name] = LoadInstruction(line, variable_name)
         elif re_assign.match(line):
             assign_instructions.parse_assign_line(line)
         elif re_update.match(line):
             assign_instructions.parse_update_line(line)
         elif re_ordered_dict.match(line):
             #do nothing
             continue
-        else:
+        elif NO_PICKLE_DEBUG:
             print(f'unmatched line: {line}')
 
 
@@ -79,14 +87,16 @@ def examine_pickle(fb0):
 
     assign_instructions.integrate(load_instructions)
 
-    return assign_instructions.integrated_instructions, assign_instructions.special_instructions
-    #return assign_instructions.integrated_instructions, {}
+    if return_special:
+        return assign_instructions.integrated_instructions, assign_instructions.special_instructions
+    return assign_instructions.integrated_instructions
 
 class AssignInstructions:
-    def __init__(self):
+    def __init__(self, collect_special=False):
         self.instructions = {}
         self.special_instructions = {}
         self.integrated_instructions = {}
+        self.collect_special = collect_special;
 
     def parse_assign_line(self, line):
         # input looks like this:
@@ -115,7 +125,7 @@ def _add_assignment(self, assignment, re_var):
         if re_var.match(fickling_var):
             self.instructions[sd_key] = fickling_var
             return True
-        else:
+        elif self.collect_special:
             # now convert the string "{'version': 1}" into a dictionary {'version': 1}
             entries = fickling_var.split(',')
             special_dict = {}
@@ -133,14 +143,9 @@ def integrate(self, load_instructions):
         for sd_key, fickling_var in self.instructions.items():
             if fickling_var in load_instructions:
                 self.integrated_instructions[sd_key] = load_instructions[fickling_var]
-                if sd_key in self.special_instructions:
-                    if NO_PICKLE_DEBUG:
-                        print(f"Key found in both load and special instructions: {sd_key}")
             else:
-                unfound_keys[sd_key] = True;
-        #for sd_key, special in self.special_instructions.items():
-        #    if sd_key in unfound_keys:
-        #        #todo
+                if NO_PICKLE_DEBUG:
+                    print(f"no load instruction found for {sd_key}")
 
         if NO_PICKLE_DEBUG:
             print(f"Have {len(self.integrated_instructions)} integrated load/assignment instructions")
@@ -164,14 +169,16 @@ def parse_update_line(self, line):
             print(f"Added/merged {update_count} updates. Total of {len(self.instructions)} assignment instructions")
 
 class LoadInstruction:
-    def __init__(self, instruction_string):
+    def __init__(self, instruction_string, variable_name, extra_debugging=False):
         self.ident = False
         self.storage_type = False
         self.obj_key = False
         self.location = False #unused
         self.obj_size = False
         self.stride = False #unused
-        self.data = False;
+        self.data = False
+        self.variable_name = variable_name
+        self.extra_debugging = extra_debugging
         self.parse_instruction(instruction_string)
 
     def parse_instruction(self, instruction_string):
@@ -185,12 +192,24 @@ def parse_instruction(self, instruction_string):
         #
         # the following comments will show the output of each string manipulation as if it started with the above.
 
+        if self.extra_debugging:
+            print(f"input: '{instruction_string}'")
+
         garbage, storage_etc = instruction_string.split('((', 1)
         # storage_etc = 'storage', HalfStorage, '0', 'cpu', 11520)), 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
+
+        if self.extra_debugging:
+            print("storage_etc, reference: ''storage', HalfStorage, '0', 'cpu', 11520)), 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)'")
+            print(f"storage_etc, actual: '{storage_etc}'\n")
 
         storage, etc = storage_etc.split('))', 1)
         # storage = 'storage', HalfStorage, '0', 'cpu', 11520
-        # etc = 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
+        # etc = , 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
+        if self.extra_debugging:
+            print("storage, reference: ''storage', HalfStorage, '0', 'cpu', 11520'")
+            print(f"storage, actual: '{storage}'\n")
+            print("etc, reference: ', 0, (320, 4, 3, 3), (36, 9, 3, 1), False, _var0)'")
+            print(f"etc, actual: '{etc}'\n")
 
         ## call below maps to: ('storage', HalfStorage, '0', 'cpu', 11520)
         self.ident, self.storage_type, self.obj_key, self.location, self.obj_size = storage.split(', ', 4)
@@ -201,10 +220,16 @@ def parse_instruction(self, instruction_string):
         self.obj_size = int(self.obj_size)
         self.storage_type = self._torch_to_numpy(self.storage_type)
 
+        if self.extra_debugging:
+            print(f"{self.ident}, {self.obj_key}, {self.location}, {self.obj_size}, {self.storage_type}")
+
         assert (self.ident == 'storage')
 
         garbage, etc = etc.split(', (', 1)
         # etc = 320, 4, 3, 3), (36, 9, 3, 1), False, _var0)
+        if self.extra_debugging:
+            print("etc, reference: '320, 4, 3, 3), (36, 9, 3, 1), False, _var0)'")
+            print(f"etc, actual: '{etc}'\n")
 
         size, stride, garbage = etc.split('), ', 2)
         # size = 320, 4, 3, 3
@@ -223,12 +248,19 @@ def parse_instruction(self, instruction_string):
         else:
             self.stride = tuple(map(int, stride.split(', ')))
 
+
+        if self.extra_debugging:
+            print(f"size: {self.size_tuple}, stride: {self.stride}")
+
         prod_size = prod(self.size_tuple)
         assert prod(self.size_tuple) == self.obj_size # does the size in the storage call match the size tuple
 
         # zero out the data
         self.data = np.zeros(self.size_tuple, dtype=self.storage_type)
 
+    def sayHi(self):
+        print(f"Hi, I'm an instance of LoadInstruction that will be used to load datafile {self.obj_key}")
+
     @staticmethod
     def _torch_to_numpy(storage_type):
         if storage_type == 'FloatStorage':
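
For reference, a minimal usage sketch of the patched loader (hypothetical, not part of the diff: the module name `no_pickle_fake_torch` is assumed, and fickling, astunparse, and numpy must be installed; `extract_weights_from_checkpoint` accepts anything `zipfile.ZipFile` can open, i.e. a path or a binary file object):

    # Hypothetical driver: read a .ckpt (a zip archive) and recover its tensors
    # as numpy arrays without ever unpickling untrusted data.
    from no_pickle_fake_torch import extract_weights_from_checkpoint

    with open("model.ckpt", "rb") as ckpt:
        weights = extract_weights_from_checkpoint(ckpt)

    # weights["state_dict"] maps checkpoint keys to numpy arrays
    for key, tensor in list(weights["state_dict"].items())[:5]:
        print(key, tensor.shape, tensor.dtype)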