Skip to content

Commit 62f379b

Browse files
committed
Fix to not break on Papercut model, but still not importing it correctly. Debugging statements enabled.
1 parent 8b73e58 commit 62f379b

File tree

1 file changed

+63
-15
lines changed

1 file changed

+63
-15
lines changed

backends/model_converter/no_pickle_fake_torch.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
else:
1212
from astunparse import unparse
1313

14+
NO_PICKLE_DEBUG = True
1415

1516
def extract_weights_from_checkpoint(fb0):
1617
torch_weights = {}
@@ -21,11 +22,13 @@ def extract_weights_from_checkpoint(fb0):
2122
raise ValueError("Looks like the checkpoints file is in the wrong format")
2223
folder_name = folder_name[0].replace("/data.pkl" , "").replace("\\data.pkl" , "")
2324
with myzip.open(folder_name+'/data.pkl') as myfile:
24-
instructions = examine_pickle(myfile)
25-
for sd_key,load_instruction in instructions.items():
25+
load_instructions, special_instructions = examine_pickle(myfile)
26+
for sd_key,load_instruction in load_instructions.items():
2627
with myzip.open(folder_name + f'/data/{load_instruction.obj_key}') as myfile:
2728
if (load_instruction.load_from_file_buffer(myfile)):
2829
torch_weights['state_dict'][sd_key] = load_instruction.get_data()
30+
for sd_key,special in special_instructions.items():
31+
torch_weights['state_dict'][sd_key] = special
2932
return torch_weights
3033

3134
def examine_pickle(fb0):
@@ -47,6 +50,7 @@ def examine_pickle(fb0):
4750
re_rebuild = re.compile('^_var\d+ = _rebuild_tensor_v2\(UNPICKLER\.persistent_load\(\(.*\)$')
4851
re_assign = re.compile('^_var\d+ = \{.*\}$')
4952
re_update = re.compile('^_var\d+\.update\(\{.*\}\)$')
53+
re_ordered_dict = re.compile('^_var\d+ = OrderedDict\(\)$')
5054

5155
load_instructions = {}
5256
assign_instructions = AssignInstructions()
@@ -61,18 +65,25 @@ def examine_pickle(fb0):
6165
assign_instructions.parse_assign_line(line)
6266
elif re_update.match(line):
6367
assign_instructions.parse_update_line(line)
64-
#else:
65-
# print('kicking rocks')
68+
elif re_ordered_dict.match(line):
69+
#do nothing
70+
continue
71+
else:
72+
print(f'unmatched line: {line}')
73+
6674

67-
#print(f"Found {len(load_instructions)} load instructions")
75+
if NO_PICKLE_DEBUG:
76+
print(f"Found {len(load_instructions)} load instructions")
6877

6978
assign_instructions.integrate(load_instructions)
7079

71-
return assign_instructions.integrated_instructions
80+
#return assign_instructions.integrated_instructions, assign_instructions.special_instructions
81+
return assign_instructions.integrated_instructions, {}
7282

7383
class AssignInstructions:
7484
def __init__(self):
7585
self.instructions = {}
86+
self.special_instructions = {}
7687
self.integrated_instructions = {}
7788

7889
def parse_assign_line(self, line):
@@ -84,20 +95,53 @@ def parse_assign_line(self, line):
8495
assignments = huge_mess.split(', ')
8596
del huge_mess
8697
assignments[-1] = assignments[-1].strip('}')
98+
re_var = re.compile('^_var\d+$')
99+
assignment_count
87100
for a in assignments:
88-
self._add_assignment(a)
89-
#print(f"Added/merged {len(assignments)} assignments. Total of {len(self.instructions)} assignment instructions")
90-
91-
def _add_assignment(self, assignment):
92-
sd_key, fickling_var = assignment.split(': ')
101+
if self._add_assignment(a, re_var):
102+
assignment_count = assignment_count + 1
103+
if NO_PICKLE_DEBUG:
104+
print(f"Added/merged {assignment_count} assignments. Total of {len(self.instructions)} assignment instructions")
105+
106+
def _add_assignment(self, assignment, re_var):
107+
# assignment can look like this:
108+
# 'cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight': _var2009
109+
# or assignment can look like this:
110+
# 'embedding_manager.embedder.transformer.text_model.encoder.layers.6.mlp.fc1': {'version': 1}
111+
sd_key, fickling_var = assignment.split(': ', 1)
93112
sd_key = sd_key.strip("'")
94-
self.instructions[sd_key] = fickling_var
113+
if re_var.match(fickling_var):
114+
self.instructions[sd_key] = fickling_var
115+
return True
116+
else:
117+
# now convert the string "{'version': 1}" into a dictionary {'version': 1}
118+
entries = fickling_var.split(',')
119+
special_dict = {}
120+
for e in entries:
121+
e = e.strip("{}")
122+
k, v = e.split(': ')
123+
k = k.strip("'")
124+
v = v.strip("'")
125+
special_dict[k] = v
126+
self.special_instructions[sd_key] = special_dict
127+
return False
95128

96129
def integrate(self, load_instructions):
    """Join the recorded assignments with the tensor load instructions.

    For every state-dict key in ``self.instructions`` whose fickling variable
    is present in ``load_instructions``, the matching load instruction is
    stored in ``self.integrated_instructions`` under that state-dict key.

    Parameters:
        load_instructions: mapping of fickling variable name -> load
            instruction.

    Returns:
        None. Results are written to ``self.integrated_instructions``.
    """
    for sd_key, fickling_var in self.instructions.items():
        if fickling_var not in load_instructions:
            # TODO: keys whose variable has no load instruction are dropped
            # here; decide how they relate to self.special_instructions.
            continue
        self.integrated_instructions[sd_key] = load_instructions[fickling_var]
        if NO_PICKLE_DEBUG and sd_key in self.special_instructions:
            print(f"Key found in both load and special instructions: {sd_key}")

    if NO_PICKLE_DEBUG:
        print(f"Have {len(self.integrated_instructions)} integrated load/assignment instructions")
101145

102146
def parse_update_line(self, line):
103147
# input looks like:
@@ -109,9 +153,13 @@ def parse_update_line(self, line):
109153
updates = huge_mess.split(', ')
110154
del huge_mess
111155
updates[-1] = updates[-1].strip('})')
156+
re_var = re.compile('^_var\d+$')
157+
update_count = 0
112158
for u in updates:
113-
self._add_assignment(u)
114-
#print(f"Added/merged {len(updates)} updates. Total of {len(self.instructions)} assignment instructions")
159+
if self._add_assignment(u, re_var):
160+
update_count = update_count + 1
161+
if NO_PICKLE_DEBUG:
162+
print(f"Added/merged {update_count} updates. Total of {len(self.instructions)} assignment instructions")
115163

116164
class LoadInstruction:
117165
def __init__(self, instruction_string):

0 commit comments

Comments
 (0)