@@ -51,10 +51,6 @@ def apply(self, model):
51
51
warnings .warn ("Found Conv node with bias, skipping" )
52
52
continue
53
53
54
- if model .get_initializer (node .input [1 ]) is None :
55
- warnings .warn ("Found Conv node with non-initialized weight, skipping" )
56
- continue
57
-
58
54
# extract parameters of node
59
55
(
60
56
cnv_input ,
@@ -67,6 +63,7 @@ def apply(self, model):
67
63
stride_w ,
68
64
group ,
69
65
weight_name ,
66
+ conv_weight_inp_name ,
70
67
W_conv ,
71
68
ifm_ch ,
72
69
ofm_ch ,
@@ -78,6 +75,10 @@ def apply(self, model):
78
75
pad ,
79
76
) = self .extract_conv_params (model , node )
80
77
78
+ if W_conv is None :
79
+ warnings .warn ("Found Conv node with non-initialized weight, skipping" )
80
+ continue
81
+
81
82
# if depthwise conv create sparse matrix and variable "dw"
82
83
# to store as attribute in Im2Col that indicates that the created
83
84
# Im2Col node belongs to a depthwise convolution
@@ -108,6 +109,8 @@ def apply(self, model):
108
109
# transpose to get ONNX-compatible [k_h*k_w*IFM][OFM] matrix
109
110
W_matmul = W_matmul .T
110
111
model .set_initializer (weight_name , W_matmul )
112
+ if weight_name != conv_weight_inp_name :
113
+ model .set_tensor_shape (conv_weight_inp_name , W_matmul .shape )
111
114
112
115
# create new intermediate values
113
116
inp_trans_out = helper .make_tensor_value_info (
@@ -158,7 +161,7 @@ def apply(self, model):
158
161
159
162
matmul_input = im2col_out if need_im2col else inp_trans_out
160
163
# do matmul
161
- matmul_node = helper .make_node ("MatMul" , [matmul_input , weight_name ], [matmul_out ])
164
+ matmul_node = helper .make_node ("MatMul" , [matmul_input , conv_weight_inp_name ], [matmul_out ])
162
165
# NHWC -> NCHW
163
166
out_trans_node = helper .make_node ("Transpose" , [matmul_out ], [cnv_output ], perm = [0 , 3 , 1 , 2 ])
164
167
@@ -182,7 +185,14 @@ def extract_conv_params(self, model, node):
182
185
stride_w = get_by_name (node .attribute , "strides" ).ints [1 ]
183
186
group = get_by_name (node .attribute , "group" ).i
184
187
weight_name = node .input [1 ]
188
+ conv_weight_inp_name = node .input [1 ]
185
189
W_conv = model .get_initializer (weight_name )
190
+ if W_conv is None :
191
+ # check to see if there is an immediate quantizer node feeding the weight input
192
+ w_producer = model .find_producer (weight_name )
193
+ if not (w_producer is None ) and w_producer .op_type == "Quant" :
194
+ W_conv = model .get_initializer (w_producer .input [0 ])
195
+ weight_name = w_producer .input [0 ]
186
196
ifm_ch = model .get_tensor_shape (cnv_input )[1 ] # assume NCHW
187
197
ofm_ch = model .get_tensor_shape (cnv_output )[1 ] # assume NCHW
188
198
ifm_dim_h = model .get_tensor_shape (cnv_input )[2 ] # assume NCHW
@@ -217,6 +227,7 @@ def extract_conv_params(self, model, node):
217
227
stride_w ,
218
228
group ,
219
229
weight_name ,
230
+ conv_weight_inp_name ,
220
231
W_conv ,
221
232
ifm_ch ,
222
233
ofm_ch ,
0 commit comments