@@ -64,6 +64,7 @@ def apply(self, model):
64
64
group ,
65
65
weight_name ,
66
66
conv_weight_inp_name ,
67
+ conv_weight_q_scale_name ,
67
68
W_conv ,
68
69
ifm_ch ,
69
70
ofm_ch ,
@@ -110,7 +111,19 @@ def apply(self, model):
110
111
W_matmul = W_matmul .T
111
112
model .set_initializer (weight_name , W_matmul )
112
113
if weight_name != conv_weight_inp_name :
114
+ # required for convs with quantized weights
113
115
model .set_tensor_shape (conv_weight_inp_name , W_matmul .shape )
116
+ if conv_weight_q_scale_name is not None :
117
+ # required for convs with quantized weights
118
+ scale_weight_q = model .get_initializer (conv_weight_q_scale_name )
119
+ # scale shape is originally [OFM, IFM, k_H, k_W]
120
+ # transpose into [OFM, k_H, k_W, IFM]
121
+ scale_weight_q = scale_weight_q .transpose (0 , 2 , 3 , 1 )
122
+ # reshape into [OFM][k_h*k_w*IFM] matrix
123
+ scale_weight_q = scale_weight_q .reshape (ofm_ch , - 1 )
124
+ # transpose to be shape-compatible with weight matrix
125
+ scale_weight_q = scale_weight_q .T
126
+ model .set_initializer (conv_weight_q_scale_name , scale_weight_q )
114
127
115
128
# create new intermediate values
116
129
inp_trans_out = helper .make_tensor_value_info (
@@ -186,13 +199,15 @@ def extract_conv_params(self, model, node):
186
199
group = get_by_name (node .attribute , "group" ).i
187
200
weight_name = node .input [1 ]
188
201
conv_weight_inp_name = node .input [1 ]
202
+ conv_weight_q_scale_name = None
189
203
W_conv = model .get_initializer (weight_name )
190
204
if W_conv is None :
191
205
# check to see if there is an immediate quantizer node feeding the weight input
192
206
w_producer = model .find_producer (weight_name )
193
207
if not (w_producer is None ) and w_producer .op_type == "Quant" :
194
208
W_conv = model .get_initializer (w_producer .input [0 ])
195
209
weight_name = w_producer .input [0 ]
210
+ conv_weight_q_scale_name = w_producer .input [1 ]
196
211
ifm_ch = model .get_tensor_shape (cnv_input )[1 ] # assume NCHW
197
212
ofm_ch = model .get_tensor_shape (cnv_output )[1 ] # assume NCHW
198
213
ifm_dim_h = model .get_tensor_shape (cnv_input )[2 ] # assume NCHW
@@ -228,6 +243,7 @@ def extract_conv_params(self, model, node):
228
243
group ,
229
244
weight_name ,
230
245
conv_weight_inp_name ,
246
+ conv_weight_q_scale_name ,
231
247
W_conv ,
232
248
ifm_ch ,
233
249
ofm_ch ,
0 commit comments