Skip to content
This repository was archived by the owner on Feb 7, 2023. It is now read-only.

Commit 2a3865e

Browse files
authored
Don't add crop layer if output_padding is not used for deconvolution (#496)
1 parent 96773d1 commit 2a3865e

File tree

2 files changed

+18
-13
lines changed

2 files changed

+18
-13
lines changed

onnx_coreml/_operators.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -302,14 +302,6 @@ def _get_conv_params(builder, node, graph, err, params_dict, axis=None):
302302
pads = [0, pads[0], 0, pads[1]]
303303
params_dict['pads'] = pads
304304

305-
if params_dict['is_deconv']:
306-
params_dict['crops'] = copy.copy(params_dict['pads'])
307-
params_dict['pads'] = [0, 0, 0, 0]
308-
if sum(params_dict['crops']) == 0:
309-
params_dict['is_post_crop'] = False
310-
else:
311-
params_dict['is_post_crop'] = True
312-
313305
if "kernel_shape" in node.attrs:
314306
params_dict['kernel_shape'] = node.attrs["kernel_shape"]
315307
else:
@@ -332,7 +324,6 @@ def _get_conv_params(builder, node, graph, err, params_dict, axis=None):
332324
params_dict['strides'].insert(0,1)
333325
params_dict['kernel_shape'].insert(0,1)
334326

335-
336327
params_dict['out_shape'] = None
337328
params_dict['padding_type'] = 'valid'
338329
params_dict['same_padding_asymmetry_mode'] = 'BOTTOM_RIGHT_HEAVY'
@@ -358,6 +349,8 @@ def _get_conv_params(builder, node, graph, err, params_dict, axis=None):
358349
else:
359350
params_dict['out_shape'] = (node.attrs['output_shape'][-2], node.attrs['output_shape'][-1]) # (Hout, wout)
360351
elif 'output_padding' in node.attrs:
352+
params_dict['crops'] = copy.copy(params_dict['pads'])
353+
params_dict['pads'] = [0, 0, 0, 0]
361354
post_pads = node.attrs['output_padding']
362355
if sum(post_pads) != 0:
363356
t = l = b = r = 0
@@ -432,7 +425,6 @@ def _add_conv(input_names, output_names, **kwargs):
432425
output_name=input_names[0],
433426
value=0
434427
)
435-
436428
builder.add_convolution(
437429
name=node.name,
438430
kernel_channels=kc,

tests/pytorch_model_test.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def _test_torch_model_single_io(torch_model, torch_input_shape, coreml_input_sha
4646

4747
# convert to coreml and run
4848
coreml_model = convert(onnx_model, target_ios=target_ios)
49-
5049
output_name = [o.name for o in onnx_model.graph.output][0]
5150
initializer_names = {t.name for t in onnx_model.graph.initializer}
5251
input_name = [i.name for i in onnx_model.graph.input if i.name not in initializer_names][0]
@@ -157,15 +156,29 @@ def test_conv2D_transpose(self): # type: () -> None
157156
class Net(nn.Module):
158157
def __init__(self):
159158
super(Net, self).__init__()
160-
self.convT = torch.nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2, output_padding=1, padding=1, groups=1)
159+
self.convT = torch.nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2, output_padding=0, padding=3, groups=1)
160+
161+
def forward(self, x):
162+
y = self.convT(x)
163+
return y
164+
165+
torch_model = Net() # type: ignore
166+
torch_model.train(False)
167+
_test_torch_model_single_io(torch_model, (1, 1, 64, 64), (1, 64, 64)) # type: ignore
168+
169+
def test_conv2D_transpose_output_padding(self): # type: () -> None
170+
class Net(nn.Module):
171+
def __init__(self):
172+
super(Net, self).__init__()
173+
self.convT = torch.nn.ConvTranspose2d(1, 1, kernel_size=3, stride=2, output_padding=1, padding=3, groups=1)
161174

162175
def forward(self, x):
163176
y = self.convT(x)
164177
return y
165178

166179
torch_model = Net() # type: ignore
167180
torch_model.train(False)
168-
_test_torch_model_single_io(torch_model, (1, 1, 2, 2), (1, 2, 2)) # type: ignore
181+
_test_torch_model_single_io(torch_model, (1, 1, 64, 64), (1, 64, 64)) # type: ignore
169182

170183
def test_conv2D_transpose_groups(self): # type: () -> None
171184
class Net(nn.Module):

0 commit comments

Comments
 (0)