34
34
import qonnx .core .onnx_exec as oxe
35
35
from qonnx .core .datatype import DataType
36
36
from qonnx .core .modelwrapper import ModelWrapper
37
+ from qonnx .transformation .general import SortGraph
37
38
from qonnx .transformation .infer_datatypes import InferDataTypes
38
39
from qonnx .transformation .infer_shapes import InferShapes
39
40
from qonnx .transformation .remove import RemoveIdentityOps
40
41
from qonnx .util .basic import gen_finn_dt_tensor , qonnx_make_model
41
42
42
43
43
- def insert_identity_op (model , op , as_first_node , approx ):
44
+ def insert_identity_op (model , op , as_first_node , approx , fork_after_id ):
44
45
kwargs = {}
45
46
inp_ndims = 4 if as_first_node else 2
46
47
if approx :
@@ -71,12 +72,24 @@ def insert_identity_op(model, op, as_first_node, approx):
71
72
model .set_initializer ("value" , val )
72
73
inplist = ["inp" if as_first_node else "div_out" , "value" ]
73
74
identity_node = helper .make_node (op , inplist , ["ident_out" ], ** kwargs )
75
+ old_2nd_node = graph .node [1 ]
76
+ old_last_node = graph .node [- 1 ]
77
+ graph .node .append (identity_node )
78
+ if fork_after_id :
79
+ graph .node .append (helper .make_node ("Mul" , ["ident_out" , "mul2" ], ["mulbranch0_out" ]))
80
+ model .set_initializer ("mul2" , np .asarray ([2.0 ], dtype = np .float32 ))
81
+ graph .node .append (helper .make_node ("Mul" , ["ident_out" , "mul3" ], ["mulbranch1_out" ]))
82
+ model .set_initializer ("mul3" , np .asarray ([3.0 ], dtype = np .float32 ))
83
+ graph .node .append (helper .make_node ("Add" , ["mulbranch0_out" , "mulbranch1_out" ], ["idfork_out" ]))
84
+ subgraph_out = "idfork_out"
85
+ else :
86
+ subgraph_out = "ident_out"
87
+
74
88
if as_first_node :
75
- graph .node .insert (0 , identity_node )
76
- graph .node [1 ].input [0 ] = "ident_out"
89
+ old_2nd_node .input [0 ] = subgraph_out
77
90
else :
78
- graph . node . insert ( 3 , identity_node )
79
- graph . node [ - 1 ]. input [ 0 ] = "ident_out"
91
+ old_last_node . input [ 0 ] = subgraph_out
92
+ model = model . transform ( SortGraph ())
80
93
81
94
return model
82
95
@@ -86,7 +99,10 @@ def insert_identity_op(model, op, as_first_node, approx):
86
99
@pytest .mark .parametrize ("approx" , [False , True ])
87
100
@pytest .mark .parametrize ("as_first_node" , [False , True ])
88
101
@pytest .mark .parametrize ("fork_before_id" , [False , True ])
89
- def test_remove_identity_ops (op , as_first_node , approx , fork_before_id ):
102
+ @pytest .mark .parametrize ("fork_after_id" , [False , True ])
103
+ def test_remove_identity_ops (op , as_first_node , approx , fork_before_id , fork_after_id ):
104
+ if approx and not (op in ["Add" , "Sub" , "Mul" , "Div" ]):
105
+ pytest .skip (f"approx=True not relevant for { op } " )
90
106
# set up onnx model
91
107
inp = helper .make_tensor_value_info ("inp" , TensorProto .FLOAT , [1 , 4 , 1 , 1 ])
92
108
mul = helper .make_tensor_value_info ("mul" , TensorProto .FLOAT , [])
@@ -119,7 +135,8 @@ def test_remove_identity_ops(op, as_first_node, approx, fork_before_id):
119
135
model .set_initializer ("shape" , shape_values )
120
136
model .set_initializer ("div" , div_values )
121
137
model .set_initializer ("matmul" , matmul_values )
122
- insert_identity_op (model , op , as_first_node , approx )
138
+ insert_identity_op (model , op , as_first_node , approx , fork_after_id )
139
+ model = model .transform (InferShapes ())
123
140
model = model .transform (InferShapes ())
124
141
model = model .transform (InferDataTypes ())
125
142
idict = {"inp" : inp_values }
0 commit comments