
Commit bb712b9

Improvements for CIFAR-100 models (#121)
* Improved CIFAR-100 models added; quantize.py changed with respect to the model parameter changes defined in the training repo.
1 parent d5cbf03 commit bb712b9

9 files changed: 22002 additions & 31239 deletions

izer/quantize.py

Lines changed: 24 additions & 7 deletions
@@ -85,12 +85,12 @@ def mean_n_stds_max_abs(t, n_stds=1):
 def get_const(_):
     return arguments.scale
 
-def get_max_bit_shift(t, return_bit_shift=False):
-    float_scale = 1.0 / max_max(t)
-    bit_shift = torch.ceil(torch.log2(float_scale))
+def get_max_bit_shift(t, shift_quantile, return_bit_shift=False):
+    float_scale = 1.0 / torch.quantile(t.abs(), shift_quantile)
+    bit_shift = torch.floor(torch.log2(float_scale))
     if return_bit_shift:
         return bit_shift
-    # else:
+
     return torch.pow(2., bit_shift)
 
 # If not using quantization-aware training (QAT),

@@ -128,6 +128,7 @@ def get_max_bit_shift(t, return_bit_shift=False):
             layer, operation, parameter = param_levels[0], None, param_levels[1]
         else:
             continue
+
         if parameter in ['w_zero_point', 'b_zero_point']:
             if checkpoint_state[k].nonzero().numel() != 0:
                 raise RuntimeError(f"\nParameter {k} is not zero.")

@@ -158,7 +159,23 @@ def get_max_bit_shift(t, return_bit_shift=False):
             else:
                 clamp_bits = tc.dev.DEFAULT_WEIGHT_BITS  # Default to 8 bits
 
-            factor = 2**(clamp_bits-1) * sat_fn(checkpoint_state[k])
+            bias_name = '.'.join([layer, operation, 'bias'])
+            if sat_fn == get_max_bit_shift:
+                if bias_name in checkpoint_state:
+                    weight_r = torch.flatten(checkpoint_state[k])
+                    bias_r = torch.flatten(checkpoint_state[bias_name])
+                    params_r = torch.cat((weight_r, bias_r))
+                else:
+                    params_r = torch.flatten(checkpoint_state[k])
+
+                shift_quantile_name = '.'.join([layer, 'shift_quantile'])
+                shift_quantile = 1.0
+                if shift_quantile_name in checkpoint_state:
+                    shift_quantile = checkpoint_state[shift_quantile_name]
+
+                factor = 2**(clamp_bits-1) * get_max_bit_shift(params_r, shift_quantile)
+            else:
+                factor = 2**(clamp_bits-1) * sat_fn(checkpoint_state[k])
 
             if arguments.verbose:
                 print(k, 'avg_max:', unwrap(avg_max(checkpoint_state[k])),

@@ -187,7 +204,6 @@ def get_max_bit_shift(t, return_bit_shift=False):
                 torch.Tensor([CONV_DEFAULT_WEIGHT_BITS])
 
             # Is there a bias for this layer? Use the same factor as for weights.
-            bias_name = '.'.join([layer, operation, 'bias'])
             if bias_name in checkpoint_state:
                 bias_bits_name = '.'.join([layer, 'bias_bits'])
                 if arguments.verbose:

@@ -220,7 +236,8 @@ def get_max_bit_shift(t, return_bit_shift=False):
             # Set output shift
             if arguments.clip_mode is None:
                 out_shift_name = '.'.join([layer, 'output_shift'])
-                out_shift = torch.Tensor([-1 * get_max_bit_shift(checkpoint_state[k], True)])
+                out_shift = torch.Tensor([-1 * get_max_bit_shift(checkpoint_state[k],
+                                                                 shift_quantile, True)])
                 new_checkpoint_state[out_shift_name] = out_shift
                 if new_masks_dict is not None:
                     new_masks_dict[out_shift_name] = out_shift
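
In summary, get_max_bit_shift now derives the power-of-two scale from a configurable quantile of the combined weight and bias magnitudes rather than from the absolute maximum, with the per-layer shift_quantile read from the checkpoint and defaulting to 1.0 (the old max-based behavior). The switch from ceil to floor selects the largest power-of-two scale that keeps values at or below the quantile in range; outliers beyond it saturate. Below is a minimal standalone sketch of that computation, assuming only PyTorch 1.7+ for torch.quantile; the sample data and seed are illustrative, not part of the commit.

# Minimal sketch of the new quantile-based shift computation; assumes
# PyTorch 1.7+ for torch.quantile. Sample data and seed are illustrative only.
import torch

torch.manual_seed(0)  # deterministic sample for the example below

def get_max_bit_shift(t, shift_quantile, return_bit_shift=False):
    # Scale from the shift_quantile-th quantile of |t| instead of the
    # absolute maximum; shift_quantile=1.0 reproduces max-based scaling.
    float_scale = 1.0 / torch.quantile(t.abs(), shift_quantile)
    # floor() (previously ceil()) picks the largest power-of-two scale not
    # exceeding float_scale, so values at or below the quantile stay in
    # range and anything larger saturates.
    bit_shift = torch.floor(torch.log2(float_scale))
    if return_bit_shift:
        return bit_shift
    return torch.pow(2., bit_shift)

# Mostly small weights plus one large outlier: the max-based shift is
# dominated by the outlier, while a 0.985 quantile ignores the tail and
# preserves resolution for the bulk of the values.
w = torch.cat((torch.randn(1000) * 0.05, torch.tensor([0.9])))
print(get_max_bit_shift(w, 1.0, return_bit_shift=True))    # shift driven by the outlier
print(get_max_bit_shift(w, 0.985, return_bit_shift=True))  # larger shift, outlier tail ignored

With shift_quantile = 1.0 the quantile equals the maximum of |t|, so checkpoints that define no per-layer shift_quantile quantize as before, apart from the ceil-to-floor rounding and weights and bias now being scaled together.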
Binary file not shown (4.75 KB)

trained/ai85-cifar100-qat-mixed.log

Lines changed: 10528 additions & 31232 deletions
Large diffs are not rendered by default.
Binary file not shown (4.94 KB)

trained/ai85-cifar100-qat8-q.pth.tar

Binary file not shown (3.47 KB)

trained/ai85-cifar100-qat8.pth.tar

Binary file not shown (3.6 KB)
Binary file not shown.

trained/ai85-cifar100-simplenetwide2x-qat-mixed.log

Lines changed: 11450 additions & 0 deletions
Large diffs are not rendered by default.
Binary file not shown.

0 commit comments