@@ -77,7 +77,8 @@ def insert_identity_op(model, op, as_first_node, approx):
77
77
@pytest .mark .parametrize ("op" , ["Add" , "Sub" , "Mul" , "Div" , "Identity" ])
78
78
@pytest .mark .parametrize ("approx" , [False , True ])
79
79
@pytest .mark .parametrize ("as_first_node" , [False , True ])
80
- def test_remove_identity_ops (op , as_first_node , approx ):
80
+ @pytest .mark .parametrize ("fork_before_id" , [False , True ])
81
+ def test_remove_identity_ops (op , as_first_node , approx , fork_before_id ):
81
82
# set up onnx model
82
83
inp = helper .make_tensor_value_info ("inp" , TensorProto .FLOAT , [1 , 4 , 1 , 1 ])
83
84
mul = helper .make_tensor_value_info ("mul" , TensorProto .FLOAT , [])
@@ -114,14 +115,16 @@ def test_remove_identity_ops(op, as_first_node, approx):
114
115
model = model .transform (InferShapes ())
115
116
model = model .transform (InferDataTypes ())
116
117
idict = {"inp" : inp_values }
117
- odict = oxe .execute_onnx (model , idict )
118
- out_before = odict ["outp" ]
118
+ odict_before = oxe .execute_onnx (model , idict )
119
119
num_of_nodes_before = len (model .graph .node )
120
-
120
+ if fork_before_id and not as_first_node :
121
+ divout_vi = model .get_tensor_valueinfo ("div_out" )
122
+ model .graph .output .append (divout_vi )
123
+ model .graph .value_info .remove (divout_vi )
121
124
model = model .transform (RemoveIdentityOps ())
122
125
num_of_nodes_after = len (model .graph .node )
123
126
assert num_of_nodes_before - 1 == num_of_nodes_after
124
127
125
- odict = oxe .execute_onnx (model , idict )
126
- out_after = odict [ "outp" ]
127
- assert np . isclose ( out_before , out_after , atol = 1e-3 ). all ()
128
+ odict_after = oxe .execute_onnx (model , idict )
129
+ outputs_same = [ np . isclose ( odict_before [ tname ], odict_after [ tname ], atol = 1e-3 ). all () for tname in odict_before . keys () ]
130
+ assert all (outputs_same )
0 commit comments