Skip to content

Commit bdf9405

Browse files
authored
Merge pull request #132 from fastmachinelearning/feature/convlower_qnt
Preserve weight quantizer while lowering convolutions
2 parents 84ad7ae + 032681c commit bdf9405

File tree

3 files changed

+56
-4
lines changed

3 files changed

+56
-4
lines changed

src/qonnx/transformation/lower_convs_to_matmul.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def apply(self, model):
6363
stride_w,
6464
group,
6565
weight_name,
66+
conv_weight_inp_name,
67+
conv_weight_q_scale_name,
6668
W_conv,
6769
ifm_ch,
6870
ofm_ch,
@@ -74,12 +76,18 @@ def apply(self, model):
7476
pad,
7577
) = self.extract_conv_params(model, node)
7678

79+
if W_conv is None:
80+
warnings.warn("Found Conv node with non-initialized weight, skipping")
81+
continue
82+
7783
# if depthwise conv create sparse matrix and variable "dw"
7884
# to store as attribute in Im2Col that indicates that the created
7985
# Im2Col node belongs to a depthwise convolution
8086
dw = False
8187
if group == ifm_ch and ofm_ch == ifm_ch:
8288
W_sparse = np.zeros((ofm_ch, ifm_ch, k_h, k_w)) # (OFM, IFM, k_H, k_W)
89+
# TODO: if the convolution is quantized with a non-zero zeropoint we
90+
# should be using the zeropoint value here instead of np.zeros
8391
for ch in range(ifm_ch):
8492
W_sparse[ch][ch] = W_conv[ch][0] # W_conv = [OFM, IFM, k_H, k_W]
8593
W_conv = W_sparse.astype(np.float32)
@@ -104,6 +112,21 @@ def apply(self, model):
104112
# transpose to get ONNX-compatible [k_h*k_w*IFM][OFM] matrix
105113
W_matmul = W_matmul.T
106114
model.set_initializer(weight_name, W_matmul)
115+
if weight_name != conv_weight_inp_name:
116+
# required for convs with quantized weights
117+
model.set_tensor_shape(conv_weight_inp_name, W_matmul.shape)
118+
if conv_weight_q_scale_name is not None:
119+
# required for convs with quantized weights
120+
scale_weight_q = model.get_initializer(conv_weight_q_scale_name)
121+
if scale_weight_q.ndim > 0:
122+
# scale shape is originally [OFM, IFM, k_H, k_W]
123+
# transpose into [OFM, k_H, k_W, IFM]
124+
scale_weight_q = scale_weight_q.transpose(0, 2, 3, 1)
125+
# reshape into [OFM][k_h*k_w*IFM] matrix
126+
scale_weight_q = scale_weight_q.reshape(ofm_ch, -1)
127+
# transpose to be shape-compatible with weight matrix
128+
scale_weight_q = scale_weight_q.T
129+
model.set_initializer(conv_weight_q_scale_name, scale_weight_q)
107130

108131
# create new intermediate values
109132
inp_trans_out = helper.make_tensor_value_info(
@@ -154,7 +177,7 @@ def apply(self, model):
154177

155178
matmul_input = im2col_out if need_im2col else inp_trans_out
156179
# do matmul
157-
matmul_node = helper.make_node("MatMul", [matmul_input, weight_name], [matmul_out])
180+
matmul_node = helper.make_node("MatMul", [matmul_input, conv_weight_inp_name], [matmul_out])
158181
# NHWC -> NCHW
159182
out_trans_node = helper.make_node("Transpose", [matmul_out], [cnv_output], perm=[0, 3, 1, 2])
160183

@@ -178,7 +201,16 @@ def extract_conv_params(self, model, node):
178201
stride_w = get_by_name(node.attribute, "strides").ints[1]
179202
group = get_by_name(node.attribute, "group").i
180203
weight_name = node.input[1]
204+
conv_weight_inp_name = node.input[1]
205+
conv_weight_q_scale_name = None
181206
W_conv = model.get_initializer(weight_name)
207+
if W_conv is None:
208+
# check to see if there is an immediate quantizer node feeding the weight input
209+
w_producer = model.find_producer(weight_name)
210+
if not (w_producer is None) and w_producer.op_type == "Quant":
211+
W_conv = model.get_initializer(w_producer.input[0])
212+
weight_name = w_producer.input[0]
213+
conv_weight_q_scale_name = w_producer.input[1]
182214
ifm_ch = model.get_tensor_shape(cnv_input)[1] # assume NCHW
183215
ofm_ch = model.get_tensor_shape(cnv_output)[1] # assume NCHW
184216
ifm_dim_h = model.get_tensor_shape(cnv_input)[2] # assume NCHW
@@ -213,6 +245,8 @@ def extract_conv_params(self, model, node):
213245
stride_w,
214246
group,
215247
weight_name,
248+
conv_weight_inp_name,
249+
conv_weight_q_scale_name,
216250
W_conv,
217251
ifm_ch,
218252
ofm_ch,

src/qonnx/util/test.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,20 @@ def qonnx_download_model():
145145
clize.run(download_model)
146146

147147

148-
def get_golden_in_and_output(test_model):
149-
model = download_model(test_model, do_cleanup=True, return_modelwrapper=True)
150-
rng = np.random.RandomState(42)
148+
def get_random_input(test_model, seed=42):
149+
rng = np.random.RandomState(seed)
151150
input_shape = test_model_details[test_model]["input_shape"]
152151
(low, high) = test_model_details[test_model]["input_range"]
153152
size = np.prod(np.asarray(input_shape))
154153
input_tensor = rng.uniform(low=low, high=high, size=size)
155154
input_tensor = input_tensor.astype(np.float32)
156155
input_tensor = input_tensor.reshape(input_shape)
156+
return input_tensor
157+
158+
159+
def get_golden_in_and_output(test_model, seed=42):
160+
model = download_model(test_model, do_cleanup=True, return_modelwrapper=True)
161+
input_tensor = get_random_input(test_model, seed=seed)
157162
input_dict = {model.graph.input[0].name: input_tensor}
158163
golden_output_dict = oxe.execute_onnx(model, input_dict)
159164
golden_result = golden_output_dict[model.graph.output[0].name]

tests/transformation/test_conv_lowering.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,19 @@
4343
from qonnx.transformation.infer_shapes import InferShapes
4444
from qonnx.transformation.lower_convs_to_matmul import LowerConvsToMatMul
4545
from qonnx.util.basic import gen_finn_dt_tensor, qonnx_make_model
46+
from qonnx.util.test import download_model, get_golden_in_and_output
47+
48+
49+
@pytest.mark.parametrize("model_name", ["FINN-CNV_W2A2", "MobileNetv1-w4a4"])
50+
def test_conv_lowering_quant_weights(model_name):
51+
model = download_model(model_name, return_modelwrapper=True, do_cleanup=True)
52+
input_t, golden_t = get_golden_in_and_output(model_name, seed=0)
53+
input_dict = {model.graph.input[0].name: input_t}
54+
model = model.transform(LowerConvsToMatMul())
55+
assert model.get_nodes_by_op_type("Conv") == []
56+
prod_dict = oxe.execute_onnx(model, input_dict)
57+
prod_t = prod_dict[model.graph.output[0].name]
58+
assert np.isclose(golden_t, prod_t, atol=1e-04).all()
4659

4760

4861
def test_conv_lowering_convmnist():

0 commit comments

Comments
 (0)