@@ -51,25 +51,30 @@ def insert_identity_op(model, op, as_first_node, approx):
51
51
val = np .asarray ([zero_val ], dtype = np .float32 )
52
52
elif op in ["Mul" , "Div" ]:
53
53
val = np .asarray ([one_val ], dtype = np .float32 )
54
+ elif op in ["Identity" ]:
55
+ val = None
54
56
else :
55
57
return
56
58
57
59
graph = model .graph
60
+ if val is None :
61
+ inplist = ["inp" if as_first_node else "div_out" ]
62
+ else :
63
+ model .set_initializer ("value" , val )
64
+ inplist = ["inp" if as_first_node else "div_out" , "value" ]
65
+ identity_node = helper .make_node (op , inplist , ["ident_out" ])
58
66
if as_first_node :
59
- identity_node = helper .make_node (op , ["inp" , "value" ], ["ident_out" ])
60
67
graph .node .insert (0 , identity_node )
61
68
graph .node [1 ].input [0 ] = "ident_out"
62
69
else :
63
- identity_node = helper .make_node (op , ["div_out" , "value" ], ["ident_out" ])
64
70
graph .node .insert (3 , identity_node )
65
71
graph .node [- 1 ].input [0 ] = "ident_out"
66
- model .set_initializer ("value" , val )
67
72
68
73
return model
69
74
70
75
71
76
# identity operations to be inserted
72
- @pytest .mark .parametrize ("op" , ["Add" , "Sub" , "Mul" , "Div" ])
77
+ @pytest .mark .parametrize ("op" , ["Add" , "Sub" , "Mul" , "Div" , "Identity" ])
73
78
@pytest .mark .parametrize ("approx" , [False , True ])
74
79
@pytest .mark .parametrize ("as_first_node" , [False , True ])
75
80
def test_remove_identity_ops (op , as_first_node , approx ):
0 commit comments