Skip to content

Commit 8d1ee1d

Browse files
committed
[LowerConv] support reshaping quant conv weight scales
1 parent c54f142 commit 8d1ee1d

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

src/qonnx/transformation/lower_convs_to_matmul.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def apply(self, model):
6464
group,
6565
weight_name,
6666
conv_weight_inp_name,
67+
conv_weight_q_scale_name,
6768
W_conv,
6869
ifm_ch,
6970
ofm_ch,
@@ -110,7 +111,19 @@ def apply(self, model):
110111
W_matmul = W_matmul.T
111112
model.set_initializer(weight_name, W_matmul)
112113
if weight_name != conv_weight_inp_name:
114+
# required for convs with quantized weights
113115
model.set_tensor_shape(conv_weight_inp_name, W_matmul.shape)
116+
if conv_weight_q_scale_name is not None:
117+
# the Quant scale initializer must follow the same re-layout as the weights
118+
scale_weight_q = model.get_initializer(conv_weight_q_scale_name)
119+
# scale is assumed to be 4-D, laid out as [OFM, IFM, k_H, k_W]
120+
# transpose into [OFM, k_H, k_W, IFM]
121+
scale_weight_q = scale_weight_q.transpose(0, 2, 3, 1)
122+
# reshape into [OFM][k_H*k_W*IFM] matrix
123+
scale_weight_q = scale_weight_q.reshape(ofm_ch, -1)
124+
# transpose to be shape-compatible with weight matrix
125+
scale_weight_q = scale_weight_q.T
126+
model.set_initializer(conv_weight_q_scale_name, scale_weight_q)
114127

115128
# create new intermediate values
116129
inp_trans_out = helper.make_tensor_value_info(
@@ -186,13 +199,15 @@ def extract_conv_params(self, model, node):
186199
group = get_by_name(node.attribute, "group").i
187200
weight_name = node.input[1]
188201
conv_weight_inp_name = node.input[1]
202+
conv_weight_q_scale_name = None
189203
W_conv = model.get_initializer(weight_name)
190204
if W_conv is None:
191205
# check to see if there is an immediate quantizer node feeding the weight input
192206
w_producer = model.find_producer(weight_name)
193207
if not (w_producer is None) and w_producer.op_type == "Quant":
194208
W_conv = model.get_initializer(w_producer.input[0])
195209
weight_name = w_producer.input[0]
210+
conv_weight_q_scale_name = w_producer.input[1]
196211
ifm_ch = model.get_tensor_shape(cnv_input)[1] # assume NCHW
197212
ofm_ch = model.get_tensor_shape(cnv_output)[1] # assume NCHW
198213
ifm_dim_h = model.get_tensor_shape(cnv_input)[2] # assume NCHW
@@ -228,6 +243,7 @@ def extract_conv_params(self, model, node):
228243
group,
229244
weight_name,
230245
conv_weight_inp_name,
246+
conv_weight_q_scale_name,
231247
W_conv,
232248
ifm_ch,
233249
ofm_ch,

0 commit comments

Comments
 (0)