Skip to content

Commit bf1840d

Browse files
committed
Fix: adding check and warning for dimension mismatch
1 parent 8d66e89 commit bf1840d

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

src/qonnx/transformation/resize_conv_to_deconv.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,15 @@ def apply(self, model):
124124
continue
125125

126126
kshape = get_by_name(conv.attribute, "kernel_shape").ints
127-
ifm_ch = model.get_tensor_shape(conv.input[0])[1] # assume NCHW
128-
ofm_ch = model.get_tensor_shape(conv.output[0])[1] # assume NCHW
129-
ifm_dim_h = model.get_tensor_shape(conv.input[0])[2] # assume NCHW
130-
ifm_dim_w = model.get_tensor_shape(conv.input[0])[3] # assume NCHW
131-
ofm_dim_h = model.get_tensor_shape(conv.output[0])[2] # assume NCHW
132-
ofm_dim_w = model.get_tensor_shape(conv.output[0])[3]
127+
idim = model.get_tensor_shape(conv.input[0]) # require NCHW
128+
odim = model.get_tensor_shape(conv.output[0]) # require NCHW
129+
if not (len(odim) == len(idim) == 4):
130+
warnings.warn("Skipping resize conv, only 2D convolutions supported.")
131+
continue
132+
133+
[_, ifm_ch, ifm_dim_h, ifm_dim_w] = idim
134+
[_, ofm_ch, ofm_dim_h, ofm_dim_w] = odim
135+
133136
if (ifm_dim_h != ofm_dim_h) or (ifm_dim_w != ofm_dim_w):
134137
warnings.warn("Skipping resize conv, only same-padded convs supported.")
135138
continue

src/qonnx/transformation/subpixel_to_deconv.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,15 @@ def apply(self, model):
120120
continue
121121

122122
kshape = get_by_name(n.attribute, "kernel_shape").ints
123-
ifm_ch = model.get_tensor_shape(n.input[0])[1] # assume NCHW
124-
ofm_ch = model.get_tensor_shape(n.output[0])[1] # assume NCHW
125-
ifm_dim_h = model.get_tensor_shape(n.input[0])[2] # assume NCHW
126-
ifm_dim_w = model.get_tensor_shape(n.input[0])[3] # assume NCHW
127-
ofm_dim_h = model.get_tensor_shape(n.output[0])[2] # assume NCHW
128-
ofm_dim_w = model.get_tensor_shape(n.output[0])[3]
123+
idim = model.get_tensor_shape(n.input[0]) # require NCHW
124+
odim = model.get_tensor_shape(n.output[0]) # require NCHW
125+
if not (len(odim) == len(idim) == 4):
126+
warnings.warn("Skipping sub-pixel conv, only 2D convolutions supported.")
127+
continue
128+
129+
[_, ifm_ch, ifm_dim_h, ifm_dim_w] = idim
130+
[_, ofm_ch, ofm_dim_h, ofm_dim_w] = odim
131+
129132
if (ifm_dim_h != ofm_dim_h) or (ifm_dim_w != ofm_dim_w):
130133
warnings.warn("Skipping sub-pixel conv, only same-padded convs supported.")
131134
continue

0 commit comments

Comments
 (0)