Skip to content

Commit eed35ee

Browse files
committed
quantizer shrink corner case fix (sign bit)
1 parent f51f5e6 commit eed35ee

File tree

1 file changed

+25
-7
lines changed

1 file changed

+25
-7
lines changed

hls4ml/model/optimizer/passes/bit_exact.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -228,12 +228,14 @@ def _(layer: FixedPointQuantizer):
228228
else:
229229
_i += ((lf > f) & (i > li) & k).astype(np.int8)
230230

231-
# Preserve repr boundaries unless overflow never happens
232-
mask = (2.0**i - 2.0**-f >= 2.0**li - 2.0**-lf) & (k >= lk)
233-
i = np.where(mask, _i, i)
234-
f = np.where(mask, _f, f)
235-
k = np.where(mask, _k, k)
236-
231+
if layer.SAT in ('SAT', 'SAT_SM'):
232+
k, i, f = _k, _i, _f
233+
else:
234+
# Preserve repr boundaries unless overflow never happens
235+
mask = (2.0**i - 2.0**-f >= 2.0**li - 2.0**-lf) & (k >= lk)
236+
i = np.where(mask, _i, i)
237+
f = np.where(mask, _f, f)
238+
k = np.where(mask, _k, k)
237239
# Set zeros to zero
238240
idx_zeros = np.where(k + i + f <= 0)
239241
k[idx_zeros] = 0
@@ -583,10 +585,26 @@ def request_kif(layer: Layer) -> tuple[KIF_t, ...]:
583585
return kif
584586

585587

588+
def requested_by_quantizer(layer: Layer) -> bool:
    """Check whether any consumer of ``layer`` is a fixed-point quantizer.

    Iterates over the layers returned by ``get_output_layers(layer)``. A
    ``FixedPointQuantizer`` consumer counts directly; a ``Reshape`` consumer
    is looked through recursively, so a quantizer sitting behind one or more
    reshapes is still found.

    Args:
        layer: Layer whose output consumers are inspected.

    Returns:
        True if a quantizer is found among the consumers (directly or through
        reshapes), False otherwise.
    """
    for n in get_output_layers(layer):
        if isinstance(n, FixedPointQuantizer):
            return True
        # Only short-circuit on a positive answer: a reshape branch that has
        # no quantizer behind it must not hide quantizers among the remaining
        # sibling consumers (the result should not depend on output order).
        if isinstance(n, Reshape) and requested_by_quantizer(n):
            return True
    return False
596+
597+
586598
def default_register_precision(layer: Layer):
587599
_pk, _pi, _pf = produce_kif(layer) # Maximum possible k,i,f output from this layer
588600
_rk, _ri, _rf = requested_kif(layer) # Maximum possible k,i,f may be utilized by the next layer
589-
_ok, _oi, _of = _rk, np.minimum(_pi, _ri), np.minimum(_pf, _rf)
601+
_oi, _of = np.minimum(_pi, _ri), np.minimum(_pf, _rf)
602+
603+
if requested_by_quantizer(layer):
604+
_ok = _rk
605+
else:
606+
_ok = np.minimum(_pk, _rk)
607+
590608
ok, oi, of = kif_arrs_to_ints((_ok, _oi, _of))
591609

592610
result_t = to_hls4ml_fixed(ok, oi, of, f'{layer.name}_t')

0 commit comments

Comments
 (0)