Skip to content

Commit 14da6f5

Browse files
committed
update broadcasting for moving scales for conv
1 parent b0efdd6 commit 14da6f5

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

hls4ml/model/optimizer/passes/move_scales.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,9 +237,21 @@ def transform(self, model, node):
237237
# zero bias, propagate through, if possible
238238
# (always possible if scale is scalar)
239239
try:
240-
newscale = np.broadcast_to(scale, output.shape) # make sure broadcastable
240+
if scale.ndim > 1:
241+
# undo any broadcast_to
242+
reduced_scale = _remove_redundant_dims(scale)
243+
if reduced_scale.shape[-1] == 1:
244+
reduced_scale = reduced_scale[..., 0]
245+
if node.attributes['n_dim'] == 1:
246+
scale_trans = np.transpose(reduced_scale, (1, 0))
247+
else:
248+
scale_trans = np.transpose(reduced_scale, (1, 2, 0))
249+
newscale = np.broadcast_to(scale_trans, output.shape) # make sure broadcastable
250+
can_propagate = True
251+
else:
252+
newscale = np.broadcast_to(scale, output.shape) # make sure broadcastable
253+
can_propagate = True
241254
newbias = np.zeros(output.shape)
242-
can_propagate = True
243255
except ValueError:
244256
can_propagate = False
245257

@@ -309,3 +321,14 @@ def transform(self, model, node):
309321
new_node = model.make_node('ApplyAlpha', apply_alpha.name, new_attrs, [x for x in node.outputs])
310322
model.insert_node(new_node)
311323
return True
324+
325+
326+
def _remove_redundant_dims(X):
327+
"""This is somewhat of the inverse of broadcast-to. It sets the dimension size to 1 if all values are identical"""
328+
329+
shape = X.shape
330+
for i in range(len(shape)):
331+
reduced = np.expand_dims(np.take(X, 0, axis=i), axis=i)
332+
if np.all(reduced == X):
333+
X = reduced
334+
return X

0 commit comments

Comments
 (0)