Skip to content

Commit c7cb71f

Browse files
committed
fix up automatic precision inferrence
1 parent 3a55983 commit c7cb71f

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

hls4ml/model/optimizer/passes/infer_precision.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ def _infer_precision(self, node, types_to_infer):
4949
if node_class in ['Conv1D', 'Conv2D', 'PointwiseConv1D', 'PointwiseConv2D', 'Conv2DBatchnorm']:
5050
return self._infer_conv_precision(node, types_to_infer)
5151

52-
if node_class in ['SeparableConv1D', 'SeparableConv2D', 'DepthwiseConv2D']:
52+
if node_class in ['DepthwiseConv1D', 'DepthwiseConv2D']:
53+
return self._infer_depthconv_precision(node, types_to_infer)
54+
55+
if node_class in ['SeparableConv1D', 'SeparableConv2D']:
5356
return self._infer_sepconv_precision(node, types_to_infer)
5457

5558
if node_class in ['Pooling1D', 'Pooling2D']:
@@ -166,6 +169,10 @@ def _infer_conv_precision(self, node, types_to_infer):
166169
n_ops = node.get_attr('n_chan') * node.get_attr('filt_height', 1) * node.get_attr('filt_width')
167170
return self._infer_common_precision(node, types_to_infer, n_ops)
168171

172+
def _infer_depthconv_precision(self, node, types_to_infer):
173+
n_ops = node.get_attr('filt_height', 1) * node.get_attr('filt_width')
174+
return self._infer_common_precision(node, types_to_infer, n_ops)
175+
169176
def _infer_sepconv_precision(self, node, types_to_infer):
170177
inferred_types = []
171178

hls4ml/model/optimizer/passes/seperable_to_dw_conv.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def transform(self, model, node):
7171
model.config.parse_name_config(dw_name, dw_layer_config)
7272

7373
# creating the attributes
74-
dw_attributes = {k: node.attributes.get(k, None) for k in SeperableToDepthwiseAndConv._dw_attributes}
74+
dw_attributes = {k: node.attributes[k] for k in SeperableToDepthwiseAndConv._dw_attributes if k in node.attributes}
7575

7676
dw_attributes['use_bias'] = False
7777

@@ -101,7 +101,7 @@ def transform(self, model, node):
101101
model.config.parse_name_config(pw_name, pw_layer_config)
102102

103103
# creating the attributes
104-
pw_attributes = {k: node.attributes.get(k, None) for k in SeperableToDepthwiseAndConv._pw_attributes}
104+
pw_attributes = {k: node.attributes[k] for k in SeperableToDepthwiseAndConv._pw_attributes if k in node.attributes}
105105
pw_attributes['filt_width'] = 1
106106
pw_attributes['filt_height'] = 1
107107
pw_attributes['stride_width'] = 1
@@ -111,7 +111,7 @@ def transform(self, model, node):
111111
pw_attributes['pad_top'] = 0
112112
pw_attributes['pad_bottom'] = 0
113113
pw_attributes['in_width'] = pw_attributes['out_width']
114-
pw_attributes['in_height'] = pw_attributes['out_height']
114+
pw_attributes['in_height'] = pw_attributes.get('out_height', 1)
115115
pw_attributes['n_chan'] = node.get_attr('n_chan') * node.get_attr('depth_multiplier')
116116
pw_attributes['weight_data'] = node.get_attr('pointwise_data')
117117
pw_attributes['weight_quantizer'] = node.get_attr('pointwise_quantizer')

0 commit comments

Comments
 (0)