Skip to content

Commit 5e5bb55

Browse files
committed
[LowerConv] support lowering Conv with Quant node on weights
1 parent a92093c commit 5e5bb55

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

src/qonnx/transformation/lower_convs_to_matmul.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,6 @@ def apply(self, model):
5151
warnings.warn("Found Conv node with bias, skipping")
5252
continue
5353

54-
if model.get_initializer(node.input[1]) is None:
55-
warnings.warn("Found Conv node with non-initialized weight, skipping")
56-
continue
57-
5854
# extract parameters of node
5955
(
6056
cnv_input,
@@ -67,6 +63,7 @@ def apply(self, model):
6763
stride_w,
6864
group,
6965
weight_name,
66+
conv_weight_inp_name,
7067
W_conv,
7168
ifm_ch,
7269
ofm_ch,
@@ -78,6 +75,10 @@ def apply(self, model):
7875
pad,
7976
) = self.extract_conv_params(model, node)
8077

78+
if W_conv is None:
79+
warnings.warn("Found Conv node with non-initialized weight, skipping")
80+
continue
81+
8182
# if depthwise conv create sparse matrix and variable "dw"
8283
# to store as attribute in Im2Col that indicates that the created
8384
# Im2Col node belongs to a depthwise convolution
@@ -108,6 +109,8 @@ def apply(self, model):
108109
# transpose to get ONNX-compatible [k_h*k_w*IFM][OFM] matrix
109110
W_matmul = W_matmul.T
110111
model.set_initializer(weight_name, W_matmul)
112+
if weight_name != conv_weight_inp_name:
113+
model.set_tensor_shape(conv_weight_inp_name, W_matmul.shape)
111114

112115
# create new intermediate values
113116
inp_trans_out = helper.make_tensor_value_info(
@@ -158,7 +161,7 @@ def apply(self, model):
158161

159162
matmul_input = im2col_out if need_im2col else inp_trans_out
160163
# do matmul
161-
matmul_node = helper.make_node("MatMul", [matmul_input, weight_name], [matmul_out])
164+
matmul_node = helper.make_node("MatMul", [matmul_input, conv_weight_inp_name], [matmul_out])
162165
# NHWC -> NCHW
163166
out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2])
164167

@@ -182,7 +185,14 @@ def extract_conv_params(self, model, node):
182185
stride_w = get_by_name(node.attribute, "strides").ints[1]
183186
group = get_by_name(node.attribute, "group").i
184187
weight_name = node.input[1]
188+
conv_weight_inp_name = node.input[1]
185189
W_conv = model.get_initializer(weight_name)
190+
if W_conv is None:
191+
# check to see if there is an immediate quantizer node feeding the weight input
192+
w_producer = model.find_producer(weight_name)
193+
if not (w_producer is None) and w_producer.op_type == "Quant":
194+
W_conv = model.get_initializer(w_producer.input[0])
195+
weight_name = w_producer.input[0]
186196
ifm_ch = model.get_tensor_shape(cnv_input)[1] # assume NCHW
187197
ofm_ch = model.get_tensor_shape(cnv_output)[1] # assume NCHW
188198
ifm_dim_h = model.get_tensor_shape(cnv_input)[2] # assume NCHW
@@ -217,6 +227,7 @@ def extract_conv_params(self, model, node):
217227
stride_w,
218228
group,
219229
weight_name,
230+
conv_weight_inp_name,
220231
W_conv,
221232
ifm_ch,
222233
ofm_ch,

0 commit comments

Comments
 (0)