Skip to content

Commit 71ee780

Browse files
committed
[Test] add Identity op case to test_remove_identity_ops
1 parent 8bad7e7 commit 71ee780

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

tests/transformation/test_remove_identity_ops.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,25 +51,30 @@ def insert_identity_op(model, op, as_first_node, approx):
5151
val = np.asarray([zero_val], dtype=np.float32)
5252
elif op in ["Mul", "Div"]:
5353
val = np.asarray([one_val], dtype=np.float32)
54+
elif op in ["Identity"]:
55+
val = None
5456
else:
5557
return
5658

5759
graph = model.graph
60+
if val is None:
61+
inplist = ["inp" if as_first_node else "div_out"]
62+
else:
63+
model.set_initializer("value", val)
64+
inplist = ["inp" if as_first_node else "div_out", "value"]
65+
identity_node = helper.make_node(op, inplist, ["ident_out"])
5866
if as_first_node:
59-
identity_node = helper.make_node(op, ["inp", "value"], ["ident_out"])
6067
graph.node.insert(0, identity_node)
6168
graph.node[1].input[0] = "ident_out"
6269
else:
63-
identity_node = helper.make_node(op, ["div_out", "value"], ["ident_out"])
6470
graph.node.insert(3, identity_node)
6571
graph.node[-1].input[0] = "ident_out"
66-
model.set_initializer("value", val)
6772

6873
return model
6974

7075

7176
# identity operations to be inserted
72-
@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div"])
77+
@pytest.mark.parametrize("op", ["Add", "Sub", "Mul", "Div", "Identity"])
7378
@pytest.mark.parametrize("approx", [False, True])
7479
@pytest.mark.parametrize("as_first_node", [False, True])
7580
def test_remove_identity_ops(op, as_first_node, approx):

0 commit comments

Comments
 (0)