
Commit c5bd87f

Merge pull request #102 from i-colbert/icolbert/nn-resize-to-deconv
Feat: Adding ResizeConvolutionToDeconvolution transformation
2 parents fd61cfe + b7eebaa commit c5bd87f

10 files changed (+635, -46 lines)
Binary file not shown.
Binary file not shown.

src/qonnx/transformation/lower_convs_to_matmul.py

Lines changed: 2 additions & 19 deletions
@@ -32,24 +32,7 @@

from qonnx.transformation.base import Transformation
from qonnx.transformation.extract_conv_bias import ExtractBiasFromConv
-from qonnx.util.basic import get_by_name
-
-
-def _auto_pad_to_explicit_padding(autopad_str, idim_h, idim_w, k_h, k_w, stride_h, stride_w, n_dims):
-    pad_total_h = (stride_h - 1) * idim_h - stride_h + k_h
-    pad_total_w = (stride_w - 1) * idim_w - stride_w + k_w
-    pad_half_small_h = int((pad_total_h / 2))
-    pad_half_small_w = int((pad_total_w / 2))
-    pad_half_large_h = pad_total_h - pad_half_small_h
-    pad_half_large_w = pad_total_w - pad_half_small_w
-    if autopad_str == "VALID":
-        return [0 for i in range(2 * n_dims)]
-    elif autopad_str == "SAME_UPPER":
-        return [pad_half_small_h, pad_half_small_w, pad_half_large_h, pad_half_large_w]
-    elif autopad_str == "SAME_LOWER":
-        return [pad_half_large_h, pad_half_large_w, pad_half_small_h, pad_half_small_w]
-    else:
-        raise Exception("Unsupported auto_pad: " + autopad_str)
+from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name


class LowerConvsToMatMul(Transformation):
@@ -100,7 +83,7 @@ def apply(self, model):
                        # use specified padding
                        pad = get_by_name(n.attribute, "pads").ints
                    else:
-                        pad = _auto_pad_to_explicit_padding(
+                        pad = auto_pad_to_explicit_padding(
                            auto_pad,
                            ifm_dim_h,
                            ifm_dim_w,
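Note: both this file and subpixel_to_deconv.py (below) now import a shared auto_pad_to_explicit_padding from qonnx.util.basic instead of keeping private copies. The qonnx/util/basic.py side of the change is not shown in this excerpt; a minimal sketch of what the consolidated helper presumably looks like, reconstructed from the private copy deleted above, is:

# Presumed consolidated helper in qonnx/util/basic.py, reconstructed from the
# private copy removed above (the util.basic diff is not part of this excerpt).
def auto_pad_to_explicit_padding(autopad_str, idim_h, idim_w, k_h, k_w, stride_h, stride_w, n_dims):
    pad_total_h = (stride_h - 1) * idim_h - stride_h + k_h
    pad_total_w = (stride_w - 1) * idim_w - stride_w + k_w
    pad_half_small_h = int(pad_total_h / 2)
    pad_half_small_w = int(pad_total_w / 2)
    pad_half_large_h = pad_total_h - pad_half_small_h
    pad_half_large_w = pad_total_w - pad_half_small_w
    if autopad_str == "VALID":
        return [0 for i in range(2 * n_dims)]
    elif autopad_str == "SAME_UPPER":
        return [pad_half_small_h, pad_half_small_w, pad_half_large_h, pad_half_large_w]
    elif autopad_str == "SAME_LOWER":
        return [pad_half_large_h, pad_half_large_w, pad_half_small_h, pad_half_small_w]
    else:
        raise Exception("Unsupported auto_pad: " + autopad_str)


# Worked example: a same-padded 3x3 convolution with stride 1 on an 8x8 input gives
# pad_total = 2 per spatial dimension, so SAME_UPPER yields [1, 1, 1, 1].
print(auto_pad_to_explicit_padding("SAME_UPPER", 8, 8, 3, 3, 1, 1, 2))  # [1, 1, 1, 1]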
Lines changed: 259 additions & 0 deletions
@@ -0,0 +1,259 @@
# Copyright (c) 2024, Advanced Micro Devices, Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of QONNX nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
import warnings
from onnx import helper

from qonnx.core.datatype import DataType
from qonnx.custom_op.general.quant import quant, resolve_rounding_mode
from qonnx.transformation.base import Transformation
from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name


def _weight_convolution(cnv_weights: np.ndarray, scale: int) -> np.ndarray:
    """Adaptation of the weight convolution algorithm as proposed in Colbert et al. (2021) - `An
    Energy-Efficient Edge Computing Paradigm for Convolution-Based Image Upsampling`"""
    ofm_ch = cnv_weights.shape[0]
    ifm_ch = cnv_weights.shape[1]
    kh_size = cnv_weights.shape[2]
    kw_size = cnv_weights.shape[3]
    assert kh_size == kw_size, "Only square channels supported currently."
    # NOTE - this is different than the convolution kernels, which are OC x IC x KH x KW
    # rather than IC x OC x KH x KW
    dcnv_weights = np.zeros((ifm_ch, ofm_ch, kh_size + scale - 1, kw_size + scale - 1))
    for oc in range(ofm_ch):
        for ic in range(ifm_ch):
            for i in range(scale):
                for j in range(scale):
                    dcnv_weights[ic, oc, i : i + kh_size, j : j + kw_size] += np.rot90(cnv_weights[oc, ic], 2, [0, 1])
    return dcnv_weights


class ResizeConvolutionToDeconvolution(Transformation):
    """Replaces resize convolution layers (e.g., nearest neighbor upsample + same-padded convolution)
    with deconvolution layers using the weight convolution algorithm. Currently does not support
    resize convolutions that use bilinear or bicubic upsampling"""

    def __init__(self, maintain_bit_width: bool = False):
        super().__init__()
        self.maintain_bit_width = maintain_bit_width

    def apply(self, model):
        graph = model.graph
        node_ind = 0
        graph_modified = False
        for n in graph.node:
            node_ind += 1
            if n.op_type == "Resize":
                resize_input = n.input[0]
                resize_output = n.output[0]
                consumers = model.find_consumers(resize_output)

                if len(consumers) == 0:
                    continue

                if len(consumers) > 1 and any([c.op_type == "Conv" for c in consumers]):
                    warnings.warn("Skipping resize conv that has resize with multiple consumers. Not yet supported.")
                    continue

                conv = consumers[0]
                if conv is not None and conv.op_type == "Conv":
                    # TODO: extend support to other resize convolutions
                    resize_mode = get_by_name(n.attribute, "mode").s.decode()
                    if resize_mode != "nearest":
                        warnings.warn(f"Skipping resize conv with resize_mode={resize_mode}. Not yet supported.")
                        continue

                    group = get_by_name(conv.attribute, "group").i
                    if group != 1:
                        warnings.warn("Skipping resize conv with group > 1. Not yet supported.")
                        continue

                    # The weights of the convolution can be generated by another input op if the model is
                    # quantized. Preliminary support for quantization focuses on QONNX ops (i.e., Quant)
                    weight_name = conv.input[1]
                    weight_prod = model.find_producer(weight_name)

                    # If the producer is None, then it is initialized by the Conv node
                    if weight_prod is None:
                        W_conv = model.get_initializer(weight_name)  # (OC, IC, KH, KW)

                    # If the convolution weights are not initialized by the convolution, then we need to
                    # find the node is producing the weights
                    else:
                        if weight_prod.op_type == "Quant":
                            [q_w_name, q_s_name, q_zp_name, q_bw_name] = weight_prod.input
                            W_conv = model.get_initializer(q_w_name)
                            W_scale = model.get_initializer(q_s_name)
                            if isinstance(W_scale, np.ndarray) and W_scale.ndim > 1:
                                W_scale = np.moveaxis(W_scale, 0, 1)
                            W_zeropt = model.get_initializer(q_zp_name)
                            if isinstance(W_zeropt, np.ndarray) and W_zeropt.ndim > 1:
                                W_zeropt = np.moveaxis(W_zeropt, 0, 1)
                            W_bitwidth = model.get_initializer(q_bw_name)
                            W_signed = get_by_name(weight_prod.attribute, "signed").i
                            W_narrow = get_by_name(weight_prod.attribute, "narrow").i
                            W_rounding_mode = get_by_name(weight_prod.attribute, "rounding_mode").s.decode()
                        else:
                            warnings.warn(
                                f"Weight producer is {weight_prod.op_type}, not a QONNX Quant node. Not yet supported."
                            )
                            continue

                    kshape = get_by_name(conv.attribute, "kernel_shape").ints
                    idim = model.get_tensor_shape(conv.input[0])  # require NCHW
                    odim = model.get_tensor_shape(conv.output[0])  # require NCHW
                    if not (len(odim) == len(idim) == 4):
                        warnings.warn("Skipping resize conv, only 2D convolutions supported.")
                        continue

                    [_, ifm_ch, ifm_dim_h, ifm_dim_w] = idim
                    [_, ofm_ch, ofm_dim_h, ofm_dim_w] = odim

                    if (ifm_dim_h != ofm_dim_h) or (ifm_dim_w != ofm_dim_w):
                        warnings.warn("Skipping resize conv, only same-padded convs supported.")
                        continue
                    dilation_attr = get_by_name(conv.attribute, "dilations")
                    if dilation_attr is not None:
                        dilation = dilation_attr.ints
                    else:
                        dilation = [1, 1]  # default value
                    if dilation != [1, 1]:
                        warnings.warn("Skipping resize conv, only supporting dilation=[1,1].")
                        continue
                    # get resize scaling attribute
                    resize_scales = model.get_initializer(n.input[2])  # assume NCHW
                    if not (resize_scales[0] == resize_scales[1] == 1):
                        warnings.warn("Skipping resize conv, scaling along batch or channel dimension not supported.")
                        continue
                    if resize_scales[2] != resize_scales[3]:
                        warnings.warn("Skipping resize conv, non-square scaling not yet supported.")
                        continue
                    resize_scale = int(resize_scales[2])  # TODO: extend to vector once non-square scaling supported

                    W_deconv = _weight_convolution(W_conv, resize_scale).astype(np.float32)
                    kh_size_deconv = kshape[0] + resize_scale - 1
                    kw_size_deconv = kshape[1] + resize_scale - 1
                    assert W_deconv.shape == (
                        ifm_ch,
                        ofm_ch,
                        kh_size_deconv,
                        kw_size_deconv,
                    ), "The resulting deconvolution weight shape is incorrect."

                    stride_h = get_by_name(conv.attribute, "strides").ints[0]
                    stride_w = get_by_name(conv.attribute, "strides").ints[1]
                    # handle both auto_pad and explicit padding
                    auto_pad = get_by_name(conv.attribute, "auto_pad")
                    if auto_pad is not None:
                        # find equivalent specified padding
                        auto_pad = auto_pad.s.decode("utf-8")
                        if auto_pad == "NOTSET":
                            # use specified padding
                            pad = get_by_name(conv.attribute, "pads").ints
                        else:
                            pad = auto_pad_to_explicit_padding(
                                auto_pad,
                                ifm_dim_h,
                                ifm_dim_w,
                                kshape[0],
                                kshape[1],
                                stride_h,
                                stride_w,
                                len(model.get_tensor_shape(n.input[0])) - 2,
                            )
                    else:
                        # use specified padding
                        pad = get_by_name(conv.attribute, "pads").ints

                    # if `maintain_bit_width`, then we use the quant parameters to
                    # re-quantize the weights after the weight convolution
                    if self.maintain_bit_width and (weight_prod is not None):
                        W_deconv_quant = quant(W_deconv, W_scale, W_zeropt, W_bitwidth, W_signed, W_narrow, W_rounding_mode)
                        if not np.allclose(W_deconv, W_deconv_quant):
                            warnings.warn("Clipping error introduced, consider `maintain_bit_width=False`.")

                    # if not `maintain_bit_width`, then we adjust the bit width to
                    # account for the clipping errors.
                    elif weight_prod is not None:
                        round_fnc = resolve_rounding_mode(W_rounding_mode)
                        W_int = (W_deconv / W_scale) + W_zeropt
                        W_int = round_fnc(W_int)  # handling rounding errors
                        W_min = W_int.min()
                        W_max = W_int.max()
                        if W_min < 0:
                            if abs(W_min) > W_max:
                                wdt = DataType.get_smallest_possible(W_min)
                            else:
                                wdt = DataType.get_smallest_possible(-W_max - 1)
                        else:
                            wdt = DataType.get_smallest_possible(W_max)
                        assert np.vectorize(wdt.allowed)(W_int).all(), "Error: issue finding data type to support."
                        if W_bitwidth != wdt.bitwidth():
                            W_bitwidth = np.array(wdt.bitwidth(), dtype=np.float32)
                        assert wdt.signed() == W_signed, "Error: should maintain sign of the weights."

                    deconv_inps = [resize_input, weight_name]
                    # Make sure to keep the biases from the convolution
                    if len(conv.input) == 3:
                        bias_name = conv.input[2]
                        bias_prod = model.find_producer(bias_name)
                        # If the producer is None, then it is initialized by the Conv node
                        # and we need to ensure it isn't removed with the Conv node
                        if bias_prod is None:
                            B_conv = model.get_initializer(bias_name)  # (OC,)
                            model.set_initializer(bias_name, B_conv)
                        deconv_inps.append(bias_name)  # add to the inputs
                    deconv_outs = conv.output
                    deconv_pad = pad
                    deconv_node = helper.make_node(
                        "ConvTranspose",
                        deconv_inps,
                        deconv_outs,
                        kernel_shape=[kh_size_deconv, kw_size_deconv],
                        strides=[resize_scale, resize_scale],
                        pads=deconv_pad,
                        group=group,
                        dilations=dilation,
                    )
                    W_deconv_init = weight_name
                    if weight_prod is not None:
                        W_deconv_init = q_w_name
                        model.set_initializer(q_zp_name, W_zeropt)
                        model.set_initializer(q_s_name, W_scale)
                        model.set_initializer(q_bw_name, W_bitwidth)
                    model.set_initializer(W_deconv_init, W_deconv)
                    model.set_tensor_shape(weight_name, list(W_deconv.shape))
                    graph.node.insert(node_ind, deconv_node)
                    # remove old nodes
                    graph.node.remove(n)
                    graph.node.remove(conv)
                    graph_modified = True

        return (model, graph_modified)
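For context, a minimal usage sketch of the new transformation follows. The module path (qonnx.transformation.resize_conv_to_deconv) and the model file names are assumptions for illustration and are not part of this diff; ModelWrapper.transform applies the pass and returns the rewritten model.

# Usage sketch (module path and file names are assumed, not taken from this diff).
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.resize_conv_to_deconv import ResizeConvolutionToDeconvolution

model = ModelWrapper("nn_resize_conv.onnx")  # hypothetical model containing Resize -> Conv pairs
# Each nearest-neighbor Resize followed by a same-padded Conv becomes a single ConvTranspose
# with kernel_shape = k + scale - 1 and strides = [scale, scale], as computed in apply() above.
model = model.transform(ResizeConvolutionToDeconvolution(maintain_bit_width=False))
model.save("nn_resize_deconv.onnx")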

src/qonnx/transformation/subpixel_to_deconv.py

Lines changed: 14 additions & 25 deletions
@@ -31,7 +31,7 @@
from onnx import helper

from qonnx.transformation.base import Transformation
-from qonnx.util.basic import get_by_name
+from qonnx.util.basic import auto_pad_to_explicit_padding, get_by_name


def _weight_shuffle(cnv_weights: np.ndarray, block_size: int) -> np.ndarray:
@@ -62,23 +62,6 @@ def _weight_shuffle(cnv_weights: np.ndarray, block_size: int) -> np.ndarray:
    return dcnv_weights


-def _auto_pad_to_explicit_padding(autopad_str, idim_h, idim_w, k_h, k_w, stride_h, stride_w, n_dims):
-    pad_total_h = (stride_h - 1) * idim_h - stride_h + k_h
-    pad_total_w = (stride_w - 1) * idim_w - stride_w + k_w
-    pad_half_small_h = int((pad_total_h / 2))
-    pad_half_small_w = int((pad_total_w / 2))
-    pad_half_large_h = pad_total_h - pad_half_small_h
-    pad_half_large_w = pad_total_w - pad_half_small_w
-    if autopad_str == "VALID":
-        return [0 for i in range(2 * n_dims)]
-    elif autopad_str == "SAME_UPPER":
-        return [pad_half_small_h, pad_half_small_w, pad_half_large_h, pad_half_large_w]
-    elif autopad_str == "SAME_LOWER":
-        return [pad_half_large_h, pad_half_large_w, pad_half_small_h, pad_half_small_w]
-    else:
-        raise Exception("Unsupported auto_pad: " + autopad_str)
-
-
class SubPixelToDeconvolution(Transformation):
    """Replaces sub-pixel convolution layers (i.e., same-padded convolution + depth2space)
    with deconvolution layers using the weight shuffle algorithm. Currently does not support
@@ -111,6 +94,7 @@ def apply(self, model):
                group = get_by_name(n.attribute, "group").i
                if group != 1:
                    warnings.warn("Skipping sub-pixel conv with group > 1. Not yet supported.")
+                    continue

                # The weights of the convolution can be generated by another input op if the model is
                # quantized. Preliminary support for quantization focuses on QONNX ops (i.e., Quant)
@@ -136,14 +120,18 @@ def apply(self, model):
                    continue

                kshape = get_by_name(n.attribute, "kernel_shape").ints
-                ifm_ch = model.get_tensor_shape(n.input[0])[1]  # assume NCHW
-                ofm_ch = model.get_tensor_shape(n.output[0])[1]  # assume NCHW
-                ifm_dim_h = model.get_tensor_shape(n.input[0])[2]  # assume NCHW
-                ifm_dim_w = model.get_tensor_shape(n.input[0])[3]  # assume NCHW
-                ofm_dim_h = model.get_tensor_shape(n.output[0])[2]  # assume NCHW
-                ofm_dim_w = model.get_tensor_shape(n.output[0])[3]
+                idim = model.get_tensor_shape(n.input[0])  # require NCHW
+                odim = model.get_tensor_shape(n.output[0])  # require NCHW
+                if not (len(odim) == len(idim) == 4):
+                    warnings.warn("Skipping sub-pixel conv, only 2D convolutions supported.")
+                    continue
+
+                [_, ifm_ch, ifm_dim_h, ifm_dim_w] = idim
+                [_, ofm_ch, ofm_dim_h, ofm_dim_w] = odim
+
                if (ifm_dim_h != ofm_dim_h) or (ifm_dim_w != ofm_dim_w):
                    warnings.warn("Skipping sub-pixel conv, only same-padded convs supported.")
+                    continue
                dilation_attr = get_by_name(n.attribute, "dilations")
                if dilation_attr is not None:
                    dilation = dilation_attr.ints
@@ -157,6 +145,7 @@ def apply(self, model):
                    warnings.warn(
                        "Skipping sub-pixel conv, the output channels and block size need to be evenly divisible."
                    )
+                    continue
                W_deconv = _weight_shuffle(W_conv, block_size).astype(np.float32)
                kh_size_deconv = kshape[0] * block_size
                kw_size_deconv = kshape[1] * block_size
@@ -178,7 +167,7 @@ def apply(self, model):
                        # use specified padding
                        pad = get_by_name(n.attribute, "pads").ints
                    else:
-                        pad = _auto_pad_to_explicit_padding(
+                        pad = auto_pad_to_explicit_padding(
                            auto_pad,
                            ifm_dim_h,
                            ifm_dim_w,
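A hedged sketch of an output-equivalence check for the sub-pixel rewrite follows; the file and tensor names are illustrative, and qonnx.core.onnx_exec.execute_onnx is assumed to be available as the reference executor.

# Equivalence-check sketch (illustrative names; not part of this PR's test suite).
import numpy as np

import qonnx.core.onnx_exec as oxe
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.transformation.subpixel_to_deconv import SubPixelToDeconvolution

model = ModelWrapper("espcn_depth2space.onnx")  # hypothetical Conv + DepthToSpace model
inp = model.graph.input[0].name
out = model.graph.output[0].name
x = np.random.randn(*model.get_tensor_shape(inp)).astype(np.float32)

y_ref = oxe.execute_onnx(model, {inp: x})[out]
rewritten = model.transform(SubPixelToDeconvolution())
y_new = oxe.execute_onnx(rewritten, {inp: x})[out]
assert np.allclose(y_ref, y_new, atol=1e-4)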
