@@ -85,12 +85,12 @@ def mean_n_stds_max_abs(t, n_stds=1):
     def get_const(_):
         return arguments.scale

-    def get_max_bit_shift(t, return_bit_shift=False):
-        float_scale = 1.0 / max_max(t)
-        bit_shift = torch.ceil(torch.log2(float_scale))
+    def get_max_bit_shift(t, shift_quantile, return_bit_shift=False):
+        float_scale = 1.0 / torch.quantile(t.abs(), shift_quantile)
+        bit_shift = torch.floor(torch.log2(float_scale))
         if return_bit_shift:
             return bit_shift
-        # else:
+
         return torch.pow(2., bit_shift)

     # If not using quantization-aware training (QAT),
@@ -128,6 +128,7 @@ def get_max_bit_shift(t, return_bit_shift=False):
             layer, operation, parameter = param_levels[0], None, param_levels[1]
         else:
             continue
+
         if parameter in ['w_zero_point', 'b_zero_point']:
             if checkpoint_state[k].nonzero().numel() != 0:
                 raise RuntimeError(f"\nParameter {k} is not zero.")
@@ -158,7 +159,23 @@ def get_max_bit_shift(t, return_bit_shift=False):
             else:
                 clamp_bits = tc.dev.DEFAULT_WEIGHT_BITS  # Default to 8 bits

-            factor = 2 ** (clamp_bits - 1) * sat_fn(checkpoint_state[k])
+            bias_name = '.'.join([layer, operation, 'bias'])
+            if sat_fn == get_max_bit_shift:
+                if bias_name in checkpoint_state:
+                    weight_r = torch.flatten(checkpoint_state[k])
+                    bias_r = torch.flatten(checkpoint_state[bias_name])
+                    params_r = torch.cat((weight_r, bias_r))
+                else:
+                    params_r = torch.flatten(checkpoint_state[k])
+
+                shift_quantile_name = '.'.join([layer, 'shift_quantile'])
+                shift_quantile = 1.0
+                if shift_quantile_name in checkpoint_state:
+                    shift_quantile = checkpoint_state[shift_quantile_name]
+
+                factor = 2 ** (clamp_bits - 1) * get_max_bit_shift(params_r, shift_quantile)
+            else:
+                factor = 2 ** (clamp_bits - 1) * sat_fn(checkpoint_state[k])

             if arguments.verbose:
                 print(k, 'avg_max:', unwrap(avg_max(checkpoint_state[k])),
@@ -187,7 +204,6 @@ def get_max_bit_shift(t, return_bit_shift=False):
                     torch.Tensor([CONV_DEFAULT_WEIGHT_BITS])

             # Is there a bias for this layer? Use the same factor as for weights.
-            bias_name = '.'.join([layer, operation, 'bias'])
             if bias_name in checkpoint_state:
                 bias_bits_name = '.'.join([layer, 'bias_bits'])
                 if arguments.verbose:
@@ -220,7 +236,8 @@ def get_max_bit_shift(t, return_bit_shift=False):
             # Set output shift
             if arguments.clip_mode is None:
                 out_shift_name = '.'.join([layer, 'output_shift'])
-                out_shift = torch.Tensor([-1 * get_max_bit_shift(checkpoint_state[k], True)])
+                out_shift = torch.Tensor([-1 * get_max_bit_shift(checkpoint_state[k],
+                                                                 shift_quantile, True)])
                 new_checkpoint_state[out_shift_name] = out_shift
                 if new_masks_dict is not None:
                     new_masks_dict[out_shift_name] = out_shift
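Note on the reworked saturation function: the commit derives the per-layer scale from a quantile of the parameter magnitudes (weights concatenated with the bias, when present) instead of their absolute maximum, and switches from ceil to floor so the resulting power-of-two scale never exceeds 1 / quantile(|t|). The self-contained sketch below replays that arithmetic outside the checkpoint-conversion flow; the example tensor and the 0.99 quantile are illustrative assumptions, not values from this commit (per the diff, shift_quantile falls back to 1.0 when the checkpoint carries no per-layer value).

import torch

def max_bit_shift(t, shift_quantile=1.0, return_bit_shift=False):
    # Scale chosen so that the given quantile of |t| maps to roughly 1.0.
    float_scale = 1.0 / torch.quantile(t.abs(), shift_quantile)
    # floor(log2(.)) yields the largest power-of-two scale that does not
    # exceed float_scale, so the scaled quantile value never exceeds 1.0.
    bit_shift = torch.floor(torch.log2(float_scale))
    if return_bit_shift:
        return bit_shift
    return torch.pow(2., bit_shift)

# Illustrative data (assumed): many small weights plus a single outlier.
w = torch.cat((torch.full((100,), 0.05), torch.tensor([0.9])))

print(max_bit_shift(w, 1.00, return_bit_shift=True))  # tensor(0.) -- outlier sets the scale
print(max_bit_shift(w, 0.99, return_bit_shift=True))  # tensor(4.) -- outlier ignored

With shift_quantile=1.0 the single outlier forces a shift of 0; trimming the top 1% of magnitudes yields a shift of 4, i.e. four extra bits of resolution for the bulk of the weights, with the outliers presumably saturated by the script's existing clamping.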