Skip to content

Commit 50dd782

Browse files
authored
Fix bias rounding (#137)
1 parent babf823 commit 50dd782

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

izer/checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -183,8 +183,8 @@ def load(
 if bias_name in checkpoint_state and seq not in no_bias:
     wb = checkpoint_state[wb_name].numpy().astype(np.int64) \
         if wb_name in checkpoint_state else 8
-    w = checkpoint_state[bias_name].numpy(). \
-        astype(np.int64) // 2**(wb - 1)
+    w = (checkpoint_state[bias_name] // 2**(wb - 1)).numpy(). \
+        astype(np.int64)

     if np.all(w == 0):
         wprint(f'All bias values for `{bias_name}` are zero.')

izer/quantize.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -221,6 +221,7 @@ def get_max_bit_shift(t, shift_quantile, return_bit_shift=False):
 bias = bias.add(.5).floor().clamp(min=-(2**(clamp_bits+tc.dev.ACTIVATION_BITS-2)),
                                   max=2**(clamp_bits+tc.dev.ACTIVATION_BITS-2)-1)

+bias = (bias // 128) * 128
 # Store modified bias back into model
 new_checkpoint_state[bias_name] = bias

0 commit comments

Comments
 (0)