Skip to content
This repository was archived by the owner on Feb 7, 2023. It is now read-only.

Commit d5795a8

Browse files
authored
adding conv2d with dilation test (#479)
1 parent 0da14f8 commit d5795a8

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

tests/pytorch_model_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,24 @@ def forward(self, x):
343343
H, W = 6, 3
344344
_test_torch_model_single_io(torch_model, (1,1,H,W), (1, H, W)) # type: ignore
345345

346+
def test_conv2d_dilation(self):
347+
class TestModule(torch.nn.Module):
348+
def __init__(self):
349+
in_channels = 1
350+
out_channels = 3
351+
bsz = 1 # batch size
352+
super(TestModule, self).__init__()
353+
self.conv1 = torch.nn.Conv2d(in_channels, out_channels,
354+
kernel_size=(3, 4), stride=2, dilation=2)
355+
356+
def forward(self, x):
357+
return self.conv1(x)
358+
359+
torch_model = TestModule() # type: ignore
360+
torch_model.train(False)
361+
H, W = 64, 64
362+
_test_torch_model_single_io(torch_model, (1,1,H,W), (1, H, W)) # type: ignore
363+
346364

347365
def test_bachnorm_after_reshape(self): # type: () -> None
348366
class Net(nn.Module):

0 commit comments

Comments
 (0)