|
41 | 41 | from qonnx.transformation.make_input_chanlast import MakeInputChannelsLast
|
42 | 42 | from qonnx.transformation.quant_constant_folding import FoldTransposeIntoQuantInit
|
43 | 43 | from qonnx.util.basic import get_by_name
|
| 44 | +from qonnx.util.onnx import is_eltwise_optype |
44 | 45 |
|
45 | 46 | # Standard ONNX nodes which require a ChannelsLast data format to function properly
|
46 | 47 | _channelsLast_node_types = list(channels_last.custom_op.keys())
|
|
53 | 54 | # And modify all values in the same way, if the second tensor is a scalar.
|
54 | 55 | _move_through_nodes_if_scalar = ["Mul", "Div", "Sub", "Add"]
|
55 | 56 |
|
56 |
| -# optypes that operate in an elementwise fashion |
57 |
| -# (with numpy-style broadcasting when shapes mismatch) |
58 |
| -_eltwise_optypes = ["Relu", "Quant", "Mul", "Div", "Sub", "Add"] |
59 |
| - |
60 | 57 |
|
61 | 58 | def get_transpose_perms(transpose_node, model):
|
62 | 59 | perm = get_by_name(transpose_node.attribute, "perm")
|
@@ -441,7 +438,7 @@ def apply(self, model):
|
441 | 438 | successor = successors[0]
|
442 | 439 | transpose_node = node
|
443 | 440 |
|
444 |
| - if successor.op_type in _eltwise_optypes: |
| 441 | + if is_eltwise_optype(successor.op_type): |
445 | 442 | model = move_transpose_past_eltwise(transpose_node, successor, model)
|
446 | 443 | graph_modified = True
|
447 | 444 | return model, graph_modified
|
|
0 commit comments