Skip to content

Commit 2d09341

Browse files
committed
[Test] add fork cases to RemoveIdentityOps test
1 parent 0a4d5c5 commit 2d09341

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

tests/transformation/test_remove_identity_ops.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ def insert_identity_op(model, op, as_first_node, approx):
7777
@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity"])
7878
@pytest.mark.parametrize("approx", [False, True])
7979
@pytest.mark.parametrize("as_first_node", [False, True])
80-
def test_remove_identity_ops(op, as_first_node, approx):
80+
@pytest.mark.parametrize("fork_before_id", [False, True])
81+
def test_remove_identity_ops(op, as_first_node, approx, fork_before_id):
8182
# set up onnx model
8283
inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [1, 4, 1, 1])
8384
mul = helper.make_tensor_value_info("mul", TensorProto.FLOAT, [])
@@ -114,14 +115,16 @@ def test_remove_identity_ops(op, as_first_node, approx):
114115
model = model.transform(InferShapes())
115116
model = model.transform(InferDataTypes())
116117
idict = {"inp": inp_values}
117-
odict = oxe.execute_onnx(model, idict)
118-
out_before = odict["outp"]
118+
odict_before = oxe.execute_onnx(model, idict)
119119
num_of_nodes_before = len(model.graph.node)
120-
120+
if fork_before_id and not as_first_node:
121+
divout_vi = model.get_tensor_valueinfo("div_out")
122+
model.graph.output.append(divout_vi)
123+
model.graph.value_info.remove(divout_vi)
121124
model = model.transform(RemoveIdentityOps())
122125
num_of_nodes_after = len(model.graph.node)
123126
assert num_of_nodes_before - 1 == num_of_nodes_after
124127

125-
odict = oxe.execute_onnx(model, idict)
126-
out_after = odict["outp"]
127-
assert np.isclose(out_before, out_after, atol=1e-3).all()
128+
odict_after = oxe.execute_onnx(model, idict)
129+
outputs_same = [np.isclose(odict_before[tname], odict_after[tname], atol=1e-3).all() for tname in odict_before.keys()]
130+
assert all(outputs_same)

0 commit comments

Comments
 (0)