11
11
else :
12
12
from astunparse import unparse
13
13
14
+ NO_PICKLE_DEBUG = True
14
15
15
16
def extract_weights_from_checkpoint (fb0 ):
16
17
torch_weights = {}
@@ -21,11 +22,13 @@ def extract_weights_from_checkpoint(fb0):
21
22
raise ValueError ("Looks like the checkpoints file is in the wrong format" )
22
23
folder_name = folder_name [0 ].replace ("/data.pkl" , "" ).replace ("\\ data.pkl" , "" )
23
24
with myzip .open (folder_name + '/data.pkl' ) as myfile :
24
- instructions = examine_pickle (myfile )
25
- for sd_key ,load_instruction in instructions .items ():
25
+ load_instructions , special_instructions = examine_pickle (myfile )
26
+ for sd_key ,load_instruction in load_instructions .items ():
26
27
with myzip .open (folder_name + f'/data/{ load_instruction .obj_key } ' ) as myfile :
27
28
if (load_instruction .load_from_file_buffer (myfile )):
28
29
torch_weights ['state_dict' ][sd_key ] = load_instruction .get_data ()
30
+ for sd_key ,special in special_instructions .items ():
31
+ torch_weights ['state_dict' ][sd_key ] = special
29
32
return torch_weights
30
33
31
34
def examine_pickle (fb0 ):
@@ -47,6 +50,7 @@ def examine_pickle(fb0):
47
50
re_rebuild = re .compile ('^_var\d+ = _rebuild_tensor_v2\(UNPICKLER\.persistent_load\(\(.*\)$' )
48
51
re_assign = re .compile ('^_var\d+ = \{.*\}$' )
49
52
re_update = re .compile ('^_var\d+\.update\(\{.*\}\)$' )
53
+ re_ordered_dict = re .compile ('^_var\d+ = OrderedDict\(\)$' )
50
54
51
55
load_instructions = {}
52
56
assign_instructions = AssignInstructions ()
@@ -61,18 +65,25 @@ def examine_pickle(fb0):
61
65
assign_instructions .parse_assign_line (line )
62
66
elif re_update .match (line ):
63
67
assign_instructions .parse_update_line (line )
64
- #else:
65
- # print('kicking rocks')
68
+ elif re_ordered_dict .match (line ):
69
+ #do nothing
70
+ continue
71
+ else :
72
+ print (f'unmatched line: { line } ' )
73
+
66
74
67
- #print(f"Found {len(load_instructions)} load instructions")
75
+ if NO_PICKLE_DEBUG :
76
+ print (f"Found { len (load_instructions )} load instructions" )
68
77
69
78
assign_instructions .integrate (load_instructions )
70
79
71
- return assign_instructions .integrated_instructions
80
+ #return assign_instructions.integrated_instructions, assign_instructions.special_instructions
81
+ return assign_instructions .integrated_instructions , {}
72
82
73
83
class AssignInstructions :
74
84
def __init__(self):
    """Create an instruction collector with three empty lookup tables."""
    # state-dict key -> name of the pickle variable assigned to it
    self.instructions = dict()
    # state-dict key -> small literal dict assigned to it instead of a variable
    self.special_instructions = dict()
    # state-dict key -> resolved load instruction; filled in by integrate()
    self.integrated_instructions = dict()
77
88
78
89
def parse_assign_line (self , line ):
@@ -84,20 +95,53 @@ def parse_assign_line(self, line):
84
95
assignments = huge_mess .split (', ' )
85
96
del huge_mess
86
97
assignments [- 1 ] = assignments [- 1 ].strip ('}' )
98
+ re_var = re .compile ('^_var\d+$' )
99
+ assignment_count
87
100
for a in assignments :
88
- self ._add_assignment (a )
89
- #print(f"Added/merged {len(assignments)} assignments. Total of {len(self.instructions)} assignment instructions")
90
-
91
- def _add_assignment (self , assignment ):
92
- sd_key , fickling_var = assignment .split (': ' )
101
+ if self ._add_assignment (a , re_var ):
102
+ assignment_count = assignment_count + 1
103
+ if NO_PICKLE_DEBUG :
104
+ print (f"Added/merged { assignment_count } assignments. Total of { len (self .instructions )} assignment instructions" )
105
+
106
def _add_assignment(self, assignment, re_var):
    """Record one ``key: value`` pair taken from an assign/update line.

    ``assignment`` can look like this:
        'cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight': _var2009
    or like this:
        'embedding_manager.embedder.transformer.text_model.encoder.layers.6.mlp.fc1': {'version': 1}

    Args:
        assignment: text of the form ``'<state dict key>': <value>``.
        re_var: compiled pattern that matches a bare pickle variable name.

    Returns:
        True when the value is a variable reference (stored in
        ``self.instructions``); False when it is a literal dict (stored in
        ``self.special_instructions`` with string keys and string values).

    Raises:
        ValueError: if ``assignment`` has no ``': '`` separator, or a
            non-empty entry of the literal dict has none.
    """
    # Split on the first ': ' only; the value itself may contain ': '.
    sd_key, fickling_var = assignment.split(': ', 1)
    sd_key = sd_key.strip("'")

    if re_var.match(fickling_var):
        self.instructions[sd_key] = fickling_var
        return True

    # Turn the text "{'version': 1}" into the dict {'version': '1'}.
    # Values are deliberately kept as strings, as before.
    special_dict = {}
    for entry in fickling_var.split(','):
        # Strip braces *and* surrounding blanks so that an entry following a
        # comma does not keep a leading space in front of its quoted key.
        entry = entry.strip(' {}')
        if not entry:
            # '{}' or a trailing comma produces an empty piece: nothing to add.
            continue
        # maxsplit=1 keeps values that themselves contain ': ' intact.
        key, value = entry.split(': ', 1)
        special_dict[key.strip().strip("'")] = value.strip().strip("'")
    self.special_instructions[sd_key] = special_dict
    return False
95
128
96
129
def integrate(self, load_instructions):
    """Resolve recorded state-dict keys to their load instructions.

    For every ``sd_key -> variable name`` pair in ``self.instructions`` whose
    variable name is a key of ``load_instructions``, the matching value is
    stored in ``self.integrated_instructions[sd_key]``. Pairs whose variable
    name is not present are left out.

    Args:
        load_instructions: mapping from pickle variable name to the load
            instruction created for it.
    """
    for sd_key, fickling_var in self.instructions.items():
        if fickling_var not in load_instructions:
            # TODO: keys without a load instruction may still have an entry
            # in self.special_instructions; they are not integrated yet.
            continue
        self.integrated_instructions[sd_key] = load_instructions[fickling_var]
        # A key present in both tables is unexpected; only report it.
        if NO_PICKLE_DEBUG and sd_key in self.special_instructions:
            print(f"Key found in both load and special instructions: {sd_key}")

    if NO_PICKLE_DEBUG:
        print(f"Have {len(self.integrated_instructions)} integrated load/assignment instructions")
101
145
102
146
def parse_update_line (self , line ):
103
147
# input looks like:
@@ -109,9 +153,13 @@ def parse_update_line(self, line):
109
153
updates = huge_mess .split (', ' )
110
154
del huge_mess
111
155
updates [- 1 ] = updates [- 1 ].strip ('})' )
156
+ re_var = re .compile ('^_var\d+$' )
157
+ update_count = 0
112
158
for u in updates :
113
- self ._add_assignment (u )
114
- #print(f"Added/merged {len(updates)} updates. Total of {len(self.instructions)} assignment instructions")
159
+ if self ._add_assignment (u , re_var ):
160
+ update_count = update_count + 1
161
+ if NO_PICKLE_DEBUG :
162
+ print (f"Added/merged { update_count } updates. Total of { len (self .instructions )} assignment instructions" )
115
163
116
164
class LoadInstruction :
117
165
def __init__ (self , instruction_string ):
0 commit comments