Skip to content

Commit d9269a9

Browse files
authored
Merge pull request #158 from fastmachinelearning/feature/remove_forked_identity
Remove identity nodes with output forking
2 parents 46721d9 + a5d5668 commit d9269a9

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

src/qonnx/transformation/remove.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,20 +113,20 @@ def apply(self, model):
113113
graph_modified = False
114114
for node in graph.node:
115115
node_ind += 1
116-
if node.op_type in ["Add", "Sub"] and not model.is_fork_node(node) and not model.is_join_node(node):
116+
if node.op_type in ["Add", "Sub"]:
117117
A = model.get_initializer(node.input[1])
118118
if A is not None and np.isclose(A, np.zeros_like(A), atol=self.atol).all():
119119
remove_node_and_rewire(model, node)
120120
graph_modified = True
121121
break
122122

123-
elif node.op_type in ["Mul", "Div"] and not model.is_fork_node(node) and not model.is_join_node(node):
123+
elif node.op_type in ["Mul", "Div"]:
124124
A = model.get_initializer(node.input[1])
125125
if A is not None and np.isclose(A, np.ones_like(A), atol=self.atol).all():
126126
remove_node_and_rewire(model, node)
127127
graph_modified = True
128128
break
129-
elif node.op_type == "Pad" and not model.is_fork_node(node) and not model.is_join_node(node):
129+
elif node.op_type == "Pad":
130130
pads = get_by_name(node.attribute, "pads")
131131
if pads is not None:
132132
# older versions of Pad op specify pads as attribute

tests/transformation/test_remove_identity_ops.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,14 @@
3434
import qonnx.core.onnx_exec as oxe
3535
from qonnx.core.datatype import DataType
3636
from qonnx.core.modelwrapper import ModelWrapper
37+
from qonnx.transformation.general import SortGraph
3738
from qonnx.transformation.infer_datatypes import InferDataTypes
3839
from qonnx.transformation.infer_shapes import InferShapes
3940
from qonnx.transformation.remove import RemoveIdentityOps
4041
from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model
4142

4243

43-
def insert_identity_op(model, op, as_first_node, approx):
44+
def insert_identity_op(model, op, as_first_node, approx, fork_after_id):
4445
kwargs = {}
4546
inp_ndims = 4 if as_first_node else 2
4647
if approx:
@@ -71,12 +72,24 @@ def insert_identity_op(model, op, as_first_node, approx):
7172
model.set_initializer("value", val)
7273
inplist = ["inp" if as_first_node else "div_out", "value"]
7374
identity_node = helper.make_node(op, inplist, ["ident_out"], **kwargs)
75+
old_2nd_node = graph.node[1]
76+
old_last_node = graph.node[-1]
77+
graph.node.append(identity_node)
78+
if fork_after_id:
79+
graph.node.append(helper.make_node("Mul", ["ident_out", "mul2"], ["mulbranch0_out"]))
80+
model.set_initializer("mul2", np.asarray([2.0], dtype=np.float32))
81+
graph.node.append(helper.make_node("Mul", ["ident_out", "mul3"], ["mulbranch1_out"]))
82+
model.set_initializer("mul3", np.asarray([3.0], dtype=np.float32))
83+
graph.node.append(helper.make_node("Add", ["mulbranch0_out", "mulbranch1_out"], ["idfork_out"]))
84+
subgraph_out = "idfork_out"
85+
else:
86+
subgraph_out = "ident_out"
87+
7488
if as_first_node:
75-
graph.node.insert(0, identity_node)
76-
graph.node[1].input[0] = "ident_out"
89+
old_2nd_node.input[0] = subgraph_out
7790
else:
78-
graph.node.insert(3, identity_node)
79-
graph.node[-1].input[0] = "ident_out"
91+
old_last_node.input[0] = subgraph_out
92+
model = model.transform(SortGraph())
8093

8194
return model
8295

@@ -86,7 +99,10 @@ def insert_identity_op(model, op, as_first_node, approx):
8699
@pytest.mark.parametrize("approx", [False, True])
87100
@pytest.mark.parametrize("as_first_node", [False, True])
88101
@pytest.mark.parametrize("fork_before_id", [False, True])
89-
def test_remove_identity_ops(op, as_first_node, approx, fork_before_id):
102+
@pytest.mark.parametrize("fork_after_id", [False, True])
103+
def test_remove_identity_ops(op, as_first_node, approx, fork_before_id, fork_after_id):
104+
if approx and not (op in ["Add", "Sub", "Mul", "Div"]):
105+
pytest.skip(f"approx=True not relevant for {op}")
90106
# set up onnx model
91107
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1])
92108
mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, [])
@@ -119,7 +135,8 @@ def test_remove_identity_ops(op, as_first_node, approx, fork_before_id):
119135
model.set_initializer("shape", shape_values)
120136
model.set_initializer("div", div_values)
121137
model.set_initializer("matmul", matmul_values)
122-
insert_identity_op(model, op, as_first_node, approx)
138+
insert_identity_op(model, op, as_first_node, approx, fork_after_id)
139+
model = model.transform(InferShapes())
123140
model = model.transform(InferShapes())
124141
model = model.transform(InferDataTypes())
125142
idict = {"inp": inp_values}

0 commit comments

Comments
 (0)