Skip to content

Commit 01f8b4f

Browse files
committed
[ChanLast] use util fxn for eltwise-ness check
1 parent 1dfc28a commit 01f8b4f

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

src/qonnx/transformation/channels_last.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast
4242
from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
4343
from qonnx.util.basic import get_by_name
44+
from qonnx.util.onnx import is_eltwise_optype
4445

4546
# Standard ONNX nodes which require a ChannelsLast data format to function properly
4647
_channelsLast_node_types = list(channels_last.custom_op.keys())
@@ -53,10 +54,6 @@
5354
# And modify all values in the same way, if the second tensor is a scalar.
5455
_move_through_nodes_if_scalar = ["Mul", "Div", "Sub", "Add"]
5556

56-
# optypes that operate in an elementwise fashion
57-
# (with numpy-style broadcasting when shapes mismatch)
58-
_eltwise_optypes = ["Relu", "Quant", "Mul", "Div", "Sub", "Add"]
59-
6057

6158
def get_transpose_perms(transpose_node, model):
6259
perm = get_by_name(transpose_node.attribute, "perm")
@@ -441,7 +438,7 @@ def apply(self, model):
441438
successor = successors[0]
442439
transpose_node = node
443440

444-
if successor.op_type in _eltwise_optypes:
441+
if is_eltwise_optype(successor.op_type):
445442
model = move_transpose_past_eltwise(transpose_node, successor, model)
446443
graph_modified = True
447444
return model, graph_modified

0 commit comments

Comments
 (0)