@@ -42,167 +42,143 @@ class LowerConvsToMatMul(Transformation):
     def apply(self, model):
         model = model.transform(ExtractBiasFromConv())
         graph = model.graph
-        node_ind = 0
         graph_modified = False
-        for n in graph.node:
-            node_ind += 1
-            if n.op_type == "Conv":
-                if len(n.input) == 3:
-                    warnings.warn("Found Conv node with bias, skipping")
-                    continue
-                cnv_input = n.input[0]
-                cnv_output = n.output[0]
-                idt = model.get_tensor_datatype(cnv_input)
-                odt = model.get_tensor_datatype(cnv_output)
-                # extract conv parameters
-                k = get_by_name(n.attribute, "kernel_shape").ints
-                k_h = k[0]
-                k_w = k[1]
-                stride_h = get_by_name(n.attribute, "strides").ints[0]
-                stride_w = get_by_name(n.attribute, "strides").ints[1]
-                group = get_by_name(n.attribute, "group").i
-                weight_name = n.input[1]
-                W_conv = model.get_initializer(weight_name)
-                ifm_ch = model.get_tensor_shape(n.input[0])[1]  # assume NCHW
-                ofm_ch = model.get_tensor_shape(n.output[0])[1]  # assume NCHW
-                ifm_dim_h = model.get_tensor_shape(n.input[0])[2]  # assume NCHW
-                ifm_dim_w = model.get_tensor_shape(n.input[0])[3]
-                ofm_dim_h = model.get_tensor_shape(n.output[0])[2]  # assume NCHW
-                ofm_dim_w = model.get_tensor_shape(n.output[0])[3]
-                dilation_attr = get_by_name(n.attribute, "dilations")
-                if dilation_attr is not None:
-                    dilation = dilation_attr.ints
-                else:
-                    dilation = [1, 1]  # default value
-                # handle both auto_pad and explicit padding
-                auto_pad = get_by_name(n.attribute, "auto_pad")
-                if auto_pad is not None:
-                    # find equivalent specified padding
-                    auto_pad = auto_pad.s.decode("utf-8")
-                    if auto_pad == "NOTSET":
-                        # use specified padding
-                        pad = get_by_name(n.attribute, "pads").ints
-                    else:
-                        pad = auto_pad_to_explicit_padding(
-                            auto_pad,
-                            ifm_dim_h,
-                            ifm_dim_w,
-                            k_h,
-                            k_w,
-                            stride_h,
-                            stride_w,
-                            len(model.get_tensor_shape(n.input[0])) - 2,
-                        )
-                else:
-                    # use specified padding
-                    pad = get_by_name(n.attribute, "pads").ints
-
-                # If len(pad) == 2, assume no padding for other dimension
-                if len(pad) == 2:  # only one dimension should be padded
-                    assert ifm_dim_h == 1 or ifm_dim_w == 1, "Padding is assumed to be 1D, image is 2D"
-
-                # if depthwise conv create sparse matrix and variable "dw"
-                # to store as attribute in Im2Col that indicates that the created
+        for node_ind, node in enumerate(graph.node, start=1):
+            if node.op_type != "Conv":
+                continue
+
+            if len(node.input) == 3:
+                warnings.warn("Found Conv node with bias, skipping")
+                continue
+
+            # extract parameters of node
+            (cnv_input, cnv_output, cnv_input_datatype, cnv_output_datatype,
+             k_h, k_w, stride_h, stride_w, group, weight_name, W_conv, ifm_ch,
+             ofm_ch, ifm_dim_h, ifm_dim_w, ofm_dim_h, ofm_dim_w, dilation, pad) = \
+                self.extract_conv_params(model, node)
+
+            # if depthwise conv create sparse matrix and variable "dw"
+            # to store as attribute in Im2Col that indicates that the created
+            # Im2Col node belongs to a depthwise convolution
+            dw = False
+            if group == ifm_ch and ofm_ch == ifm_ch:
+                W_sparse = np.zeros((ofm_ch, ifm_ch, k_h, k_w))  # (OFM, IFM, k_H, k_W)
+                for ch in range(ifm_ch):
+                    W_sparse[ch][ch] = W_conv[ch][0]  # W_conv = [OFM, IFM, k_H, k_W]
+                W_conv = W_sparse.astype(np.float32)
+                # we need to store information of the
+                # sparsity of the weight matrix. For this
+                # we use the sparsity annotation of the
+                # weight tensor
+                sparsity = {"dw": {"kernel_shape": [k_h, k_w]}}
+                model.set_tensor_sparsity(weight_name, sparsity)
+                # additionally create variable "dw" to store
+                # as attribute in Im2Col that indicates that the created
                 # Im2Col node belongs to a depthwise convolution
-                dw = False
-                if group == ifm_ch and ofm_ch == ifm_ch:
-                    W_sparse = np.zeros((ofm_ch, ifm_ch, k_h, k_w))  # (OFM, IFM, k_H, k_W)
-                    for ch in range(ifm_ch):
-                        W_sparse[ch][ch] = W_conv[ch][0]  # W_conv = [OFM, IFM, k_H, k_W]
-                    W_conv = W_sparse.astype(np.float32)
-                    # we need to store information of the
-                    # sparsity of the weight matrix. For this
-                    # we use the sparsity annotation of the
-                    # weight tensor
-                    sparsity = {"dw": {"kernel_shape": [k_h, k_w]}}
-                    model.set_tensor_sparsity(weight_name, sparsity)
-                    # additionally create variable "dw" to store
-                    # as attribute in Im2Col that indicates that the created
-                    # Im2Col node belongs to a depthwise convolution
-                    dw = True
-
-                # reuse conv weights for new matmul weights
-                # conv weights are [OFM][IFM][k][k]
-                # first convert to [OFM][k][k][IFM] (to remain compatible with
-                # finn-hlslib and how it does im2col/sliding window)
-                W_matmul = W_conv.transpose(0, 2, 3, 1)  # W_conv = [OFM, IFM, k_H, k_W]
-                # reshape into [OFM][k*k*IFM] matrix
-                W_matmul = W_matmul.reshape(ofm_ch, ifm_ch * k_h * k_w)
-                # transpose to get ONNX-compatible [k*k*IFM][OFM] matrix
-                W_matmul = W_matmul.T
-                model.set_initializer(weight_name, W_matmul)
-
-                # create new intermediate values
-                inp_trans_out = helper.make_tensor_value_info(
-                    model.make_new_valueinfo_name(),
-                    TensorProto.FLOAT,
-                    (1, ifm_dim_h, ifm_dim_w, ifm_ch),  # NHWC
+                dw = True
+
+            # reuse conv weights for new matmul weights
+            # conv weights are [OFM][IFM][k][k]
+            # first convert to [OFM][k_h][k_w][IFM] (to remain compatible with
+            # finn-hlslib and how it does im2col/sliding window)
+            W_matmul = W_conv.transpose(0, 2, 3, 1)  # W_conv = [OFM, IFM, k_H, k_W]
+            # reshape into [OFM][k_h*k_w*IFM] matrix
+            W_matmul = W_matmul.reshape(ofm_ch, ifm_ch * k_h * k_w)
+            # transpose to get ONNX-compatible [k_h*k_w*IFM][OFM] matrix
+            W_matmul = W_matmul.T
+            model.set_initializer(weight_name, W_matmul)
+
+            # create new intermediate values
+            inp_trans_out = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(),
+                TensorProto.FLOAT,
+                (1, ifm_dim_h, ifm_dim_w, ifm_ch),  # NHWC
+            )
+            graph.value_info.append(inp_trans_out)
+            inp_trans_out = inp_trans_out.name
+            model.set_tensor_datatype(inp_trans_out, cnv_input_datatype)
+
+            # k_h=k_w==1: pointwise convolution, thus no im2col needed
+            need_im2col = any(p != 0 for p in pad) or k_h != 1 or k_w != 1 or stride_h != 1 or stride_w != 1
+
+            # create new intermediate values
+            matmul_out = helper.make_tensor_value_info(
+                model.make_new_valueinfo_name(), TensorProto.FLOAT, (1, ofm_dim_h, ofm_dim_w, ofm_ch)
+            )
+            graph.value_info.append(matmul_out)
+            matmul_out = matmul_out.name
+            model.set_tensor_datatype(matmul_out, cnv_output_datatype)
+
+            # create new nodes
+            # NCHW -> NHWC
+            inp_trans_node = helper.make_node("Transpose", [cnv_input], [inp_trans_out], perm=[0, 2, 3, 1])
+            nodes_to_insert = [inp_trans_node]
+
+            if need_im2col:
+                im2col_out = helper.make_tensor_value_info(
+                    model.make_new_valueinfo_name(), TensorProto.FLOAT, (1, ofm_dim_h, ofm_dim_w, ifm_ch * k_h * k_w)
                 )
-                graph.value_info.append(inp_trans_out)
-                inp_trans_out = inp_trans_out.name
-                model.set_tensor_datatype(inp_trans_out, idt)
-
-                need_im2col = True
-                if all(p == 0 for p in pad):
-                    padding = 0
-
-                # k_h=k_w==1: pointwise convolution, thus no im2col needed
-                if k_h == 1 and k_w == 1 and padding == 0 and stride_h == 1 and stride_w == 1:
-                    need_im2col = False
-
-                if need_im2col:
-                    im2col_out = helper.make_tensor_value_info(
-                        model.make_new_valueinfo_name(),
-                        TensorProto.FLOAT,
-                        (1, ofm_dim_h, ofm_dim_w, ifm_ch * k_h * k_w),
-                    )
-                    graph.value_info.append(im2col_out)
-                    im2col_out = im2col_out.name
-                    model.set_tensor_datatype(im2col_out, idt)
-
-                matmul_out = helper.make_tensor_value_info(
-                    model.make_new_valueinfo_name(),
-                    TensorProto.FLOAT,
-                    (1, ofm_dim_h, ofm_dim_w, ofm_ch),
+                graph.value_info.append(im2col_out)
+                im2col_out = im2col_out.name
+                model.set_tensor_datatype(im2col_out, cnv_input_datatype)
+                im2col_node = helper.make_node(
+                    "Im2Col", [inp_trans_out], [im2col_out], domain="qonnx.custom_op.general",
+                    stride=[stride_h, stride_w], kernel_size=[k_h, k_w], pad_amount=pad,
+                    input_shape="(1,{},{},{})".format(ifm_dim_h, ifm_dim_w, ifm_ch), depthwise=dw, dilations=dilation
                 )
-                graph.value_info.append(matmul_out)
-                matmul_out = matmul_out.name
-                model.set_tensor_datatype(matmul_out, odt)
-
-                # create new nodes
-                # NCHW -> NHWC
-                inp_trans_node = helper.make_node("Transpose", [cnv_input], [inp_trans_out], perm=[0, 2, 3, 1])
-                # lower input tensor
-                matmul_input = inp_trans_out
-                if need_im2col:
-                    matmul_input = im2col_out
-                    im2col_node = helper.make_node(
-                        "Im2Col",
-                        [inp_trans_out],
-                        [im2col_out],
-                        domain="qonnx.custom_op.general",
-                        stride=[stride_h, stride_w],
-                        kernel_size=[k_h, k_w],
-                        pad_amount=pad,
-                        input_shape="(1,{},{},{})".format(ifm_dim_h, ifm_dim_w, ifm_ch),
-                        depthwise=dw,
-                        dilations=dilation,
-                    )
-
-                # do matmul
-                matmul_node = helper.make_node("MatMul", [matmul_input, weight_name], [matmul_out])
-                # NHWC -> NCHW
-                out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2])
-                # insert nodes where the conv is to preserve topological ordering
-                graph.node.insert(node_ind, inp_trans_node)
-                if need_im2col:
-                    graph.node.insert(node_ind + 1, im2col_node)
-                    graph.node.insert(node_ind + 2, matmul_node)
-                    graph.node.insert(node_ind + 3, out_trans_node)
-                else:
-                    graph.node.insert(node_ind + 1, matmul_node)
-                    graph.node.insert(node_ind + 2, out_trans_node)
-                # remove old nodes
-                graph.node.remove(n)
+                nodes_to_insert.append(im2col_node)
+
+            matmul_input = im2col_out if need_im2col else inp_trans_out
+            # do matmul
+            matmul_node = helper.make_node("MatMul", [matmul_input, weight_name], [matmul_out])
+            # NHWC -> NCHW
+            out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2])
+
+            nodes_to_insert.extend([matmul_node, out_trans_node])
+
+            # insert nodes where the conv is to preserve topological ordering
+            for i, insert_node in enumerate(nodes_to_insert):
+                graph.node.insert(node_ind + i, insert_node)
+            graph.node.remove(node)
 
         return (model, graph_modified)
+
+    def extract_conv_params(self, model, node):
+
+        cnv_input = node.input[0]
+        cnv_output = node.output[0]
+        cnv_input_datatype = model.get_tensor_datatype(cnv_input)
+        cnv_output_datatype = model.get_tensor_datatype(cnv_output)
+        k_h = get_by_name(node.attribute, "kernel_shape").ints[0]
+        k_w = get_by_name(node.attribute, "kernel_shape").ints[1]
+        stride_h = get_by_name(node.attribute, "strides").ints[0]
+        stride_w = get_by_name(node.attribute, "strides").ints[1]
+        group = get_by_name(node.attribute, "group").i
+        weight_name = node.input[1]
+        W_conv = model.get_initializer(weight_name)
+        ifm_ch = model.get_tensor_shape(cnv_input)[1]  # assume NCHW
+        ofm_ch = model.get_tensor_shape(cnv_output)[1]  # assume NCHW
+        ifm_dim_h = model.get_tensor_shape(cnv_input)[2]  # assume NCHW
+        ifm_dim_w = model.get_tensor_shape(cnv_input)[3]  # assume NCHW
+        ofm_dim_h = model.get_tensor_shape(cnv_output)[2]  # assume NCHW
+        ofm_dim_w = model.get_tensor_shape(cnv_output)[3]  # assume NCHW
+        dilation_attr = get_by_name(node.attribute, "dilations")
+        dilation = dilation_attr.ints if dilation_attr is not None else [1, 1]  # default value
+        auto_pad = get_by_name(node.attribute, "auto_pad")
+        if auto_pad is not None:
+            auto_pad = auto_pad.s.decode("utf-8")
+            if auto_pad == "NOTSET":
+                pad = get_by_name(node.attribute, "pads").ints
+            else:
+                pad = auto_pad_to_explicit_padding(
+                    auto_pad, ifm_dim_h, ifm_dim_w, k_h, k_w, stride_h, stride_w, len(model.get_tensor_shape(cnv_input)) - 2
+                )
+        else:
+            pad = get_by_name(node.attribute, "pads").ints
+
+        if len(pad) == 2:  # only one dimension should be padded
+            assert ifm_dim_h == 1 or ifm_dim_w == 1, "Padding is assumed to be 1D, image is 2D"
+
+        return (cnv_input, cnv_output, cnv_input_datatype, cnv_output_datatype, k_h, k_w, stride_h,
+                stride_w, group, weight_name, W_conv, ifm_ch, ofm_ch, ifm_dim_h, ifm_dim_w, ofm_dim_h,
+                ofm_dim_w, dilation, pad)
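A note on the depthwise branch above: when `group == ifm_ch == ofm_ch`, each output channel convolves only its own input channel, so the per-channel kernels are scattered onto the channel diagonal of a dense `(OFM, IFM, k_H, k_W)` tensor before the shared MatMul lowering is applied, and the sparsity annotation records the depthwise structure for downstream passes. A minimal numpy sketch of that scatter (shapes are illustrative):

```python
import numpy as np

ifm_ch, k_h, k_w = 3, 3, 3
rng = np.random.default_rng(0)
# depthwise Conv weights: one (1, k_h, k_w) kernel per channel
W_dw = rng.standard_normal((ifm_ch, 1, k_h, k_w)).astype(np.float32)

# scatter each kernel onto the channel diagonal, as the loop above does
W_sparse = np.zeros((ifm_ch, ifm_ch, k_h, k_w), dtype=np.float32)
for ch in range(ifm_ch):
    W_sparse[ch][ch] = W_dw[ch][0]

# only diagonal blocks are populated: output channel ch sees input channel ch
assert np.count_nonzero(W_sparse) == np.count_nonzero(W_dw)
```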
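The weight lowering (`transpose(0, 2, 3, 1)`, reshape to `[OFM][k_h*k_w*IFM]`, then transpose) is chosen to match the `[k_h][k_w][IFM]` ordering of the Im2Col patches, so that `im2col(x) @ W_matmul` reproduces the convolution. Below is a self-contained numpy check of that equivalence; `conv2d_nchw` and `im2col_nhwc` are hypothetical stand-ins for the ONNX Conv and qonnx Im2Col ops (stride 1, no padding, no dilation):

```python
import numpy as np

def conv2d_nchw(x, W):
    # reference convolution: x (1, IFM, H, W), W (OFM, IFM, kH, kW)
    _, ifm, h, w = x.shape
    ofm, _, kh, kw = W.shape
    out = np.zeros((1, ofm, h - kh + 1, w - kw + 1))
    for oc in range(ofm):
        for i in range(h - kh + 1):
            for j in range(w - kw + 1):
                out[0, oc, i, j] = np.sum(x[0, :, i : i + kh, j : j + kw] * W[oc])
    return out

def im2col_nhwc(x_nhwc, kh, kw):
    # gather patches in [kh][kw][IFM] order, matching the weight flattening
    _, h, w, ifm = x_nhwc.shape
    out = np.zeros((1, h - kh + 1, w - kw + 1, kh * kw * ifm))
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            out[0, i, j] = x_nhwc[0, i : i + kh, j : j + kw, :].reshape(-1)
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 8, 8))       # NCHW input
W_conv = rng.standard_normal((4, 3, 3, 3))  # (OFM, IFM, kH, kW)

# the lowering performed by the transformation
W_matmul = W_conv.transpose(0, 2, 3, 1).reshape(4, 3 * 3 * 3).T  # (kh*kw*IFM, OFM)

ref = conv2d_nchw(x, W_conv)
patches = im2col_nhwc(x.transpose(0, 2, 3, 1), 3, 3)  # NCHW -> NHWC first
lowered = (patches @ W_matmul).transpose(0, 3, 1, 2)  # NHWC -> NCHW back
assert np.allclose(ref, lowered)
```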
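On the `auto_pad` handling in `extract_conv_params`: for `SAME_UPPER`/`SAME_LOWER`, the ONNX spec puts the total padding per spatial axis at `max((ceil(d / s) - 1) * s + k - d, 0)` for input size `d`, kernel `k`, and stride `s`, with any odd leftover pixel going to the end (`SAME_UPPER`) or the beginning (`SAME_LOWER`). A one-axis sketch of that arithmetic; `same_padding_1d` is illustrative only, while the real `auto_pad_to_explicit_padding` helper returns the combined per-axis pads list:

```python
import math

def same_padding_1d(d, k, s, upper=True):
    # total padding so that the output has ceil(d / s) positions
    total = max((math.ceil(d / s) - 1) * s + k - d, 0)
    small, big = total // 2, total - total // 2
    return (small, big) if upper else (big, small)

# d=5, k=3, s=2: ceil(5/2) = 3 output positions, total pad = 2*2 + 3 - 5 = 2
assert same_padding_1d(5, 3, 2) == (1, 1)
```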
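Taken together, each Conv becomes a Transpose (NCHW -> NHWC), an optional Im2Col, a MatMul against the reshaped weights, and a Transpose back, spliced in at the Conv's position to keep the graph topologically sorted. A minimal usage sketch, assuming the usual qonnx module paths and a hypothetical `model.onnx` that contains Conv nodes:

```python
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.lower_convs_to_matmul import LowerConvsToMatMul

model = ModelWrapper("model.onnx")  # hypothetical input model
model = model.transform(LowerConvsToMatMul())
assert all(n.op_type != "Conv" for n in model.graph.node)  # all Convs lowered
model.save("model_lowered.onnx")
```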