Skip to content

Commit 51b46db

Browse files
committed
add pooling precision setting in override
1 parent 4586d4c commit 51b46db

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

hls4ml/model/optimizer/passes/hgq_proxy_model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from hls4ml.backends.fpga.fpga_types import NamedType
55
from hls4ml.model.layers import Layer, register_layer
66
from hls4ml.model.optimizer import OptimizerPass, register_pass
7-
from hls4ml.model.types import FixedPrecisionType, WeightVariable
7+
from hls4ml.model.types import FixedPrecisionType, UnspecifiedPrecisionType, WeightVariable
88

99
re_purge_prefix = re.compile(r'(?<!\w)(?:ap_|ac_)', re.IGNORECASE)
1010
re_parse_fixed = re.compile(r'\s*(u?)fixed<([^>]+)>\s*', re.IGNORECASE)
@@ -95,6 +95,12 @@ def transform(self, model, node: FixedPointQuantizer):
9595
# Some layer may be removed by other passes. (e.g. Final flatten layer)
9696
continue
9797
target_node: Layer = model.graph[name]
98+
99+
# Invoke automatic precision derivation for pooling layers accum_t, if undefined.
100+
if 'pool' in target_node.__class__.__name__.lower():
101+
if not userconf_ifdef('accum_t', name, model):
102+
target_node.attributes['accum_t'].precision = UnspecifiedPrecisionType()
103+
98104
for k, v in conf.items():
99105
if userconf_ifdef(k, name, model):
100106
warn(

0 commit comments

Comments
 (0)