@@ -84,10 +84,15 @@ def load(
84
84
while seq < len (operator ) and (operator [seq ] == opn .NONE or bypass [seq ]):
85
85
seq += 1
86
86
87
- operation , parameter = k .rsplit (sep = '.' , maxsplit = 1 )
87
+ param_levels = k .rsplit (sep = '.' , maxsplit = 2 )
88
+ if len (param_levels ) == 3 :
89
+ layer , op , parameter = param_levels [0 ], param_levels [1 ], param_levels [2 ]
90
+ elif len (param_levels ) == 2 :
91
+ layer , op , parameter = param_levels [0 ], None , param_levels [1 ]
92
+ else :
93
+ continue
94
+
88
95
if parameter in ['weight' ]:
89
- _ , op = k .split (sep = '.' , maxsplit = 1 )
90
- op = op .rsplit (sep = '.' , maxsplit = 1 )[0 ]
91
96
if layers >= num_conv_layers or seq >= num_conv_layers :
92
97
continue
93
98
@@ -171,11 +176,12 @@ def load(
171
176
weight_keys .append (k )
172
177
173
178
# Is there a bias for this layer?
174
- bias_name = operation + '.bias'
179
+ bias_name = '.' .join ([layer , op , 'bias' ])
180
+ wb_name = '.' .join ([layer , 'weight_bits' ])
175
181
176
182
if bias_name in checkpoint_state and seq not in no_bias :
177
183
w = checkpoint_state [bias_name ].numpy (). \
178
- astype (np .int64 ) // tc . dev . BIAS_DIV
184
+ astype (np .int64 ) // 2 ** ( checkpoint_state [ wb_name ]. numpy (). astype ( np . int64 ) - 1 )
179
185
180
186
if np .all (w == 0 ):
181
187
wprint (f'All bias values for `{ bias_name } ` are zero.' )
@@ -207,7 +213,7 @@ def load(
207
213
208
214
# Not overriding output_shift?
209
215
if output_shift [seq ] is None :
210
- output_shift_name = operation . rsplit ( sep = '.' , maxsplit = 1 )[ 0 ] + '. output_shift'
216
+ output_shift_name = '.' . join ([ layer , ' output_shift'])
211
217
# Is there an output_shift for this layer?
212
218
if output_shift_name in checkpoint_state :
213
219
w = checkpoint_state [output_shift_name ].numpy ().astype (np .int64 )
0 commit comments