Skip to content

Commit c418bdf

Browse files
committed
test: add decoder and transformer tests
1 parent 01dc5dc commit c418bdf

File tree

2 files changed

+50
-8
lines changed

2 files changed

+50
-8
lines changed

cellseg_models_pytorch/decoders/tests/test_decoders.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66

77
@pytest.mark.parametrize("long_skip", ["unet", "unetpp", "unet3p", "unet3p-lite"])
88
@pytest.mark.parametrize("merge_policy", ["cat", "sum"])
9-
def test_decoder_fwdbwd(long_skip, merge_policy):
10-
enc_channels = (64, 32, 16, 8, 8)
11-
out_dims = [256 // 2**i for i in range(6)][::-1]
9+
@pytest.mark.parametrize("use_conv", [True, False])
10+
@pytest.mark.parametrize("use_tr", [True, False])
11+
def test_decoder_fwdbwd(long_skip, merge_policy, use_conv, use_tr):
12+
enc_channels = (64, 32, 64, 32)
13+
out_dims = [32 // 2**i for i in range(4)][::-1]
1214

1315
decoder1_kwargs = {"merge_policy": merge_policy}
1416
decoder2_kwargs = {"merge_policy": merge_policy}
@@ -23,17 +25,30 @@ def test_decoder_fwdbwd(long_skip, merge_policy):
2325
decoder5_kwargs,
2426
)
2527

28+
n_layers = None
29+
n_blocks = None
30+
if use_conv:
31+
n_layers = (1, 1, 1, 1)
32+
n_blocks = ((2,), (2,), (2,), (2,))
33+
34+
n_tr_layers = None
35+
n_tr_blocks = None
36+
if use_tr:
37+
n_tr_layers = (1, 1, 1, 1)
38+
n_tr_blocks = ((1,), (1,), (1,), (1,))
39+
2640
decoder = Decoder(
2741
enc_channels=enc_channels,
28-
model_input_size=256,
29-
out_channels=(64, 32, 16, 8, 8),
30-
n_layers=(1, 1, 1, 1, 1),
31-
n_blocks=((2,), (2,), (2,), (2,), (2,)),
42+
out_channels=(32, 32, 32, 32),
43+
n_layers=n_layers,
44+
n_blocks=n_blocks,
45+
n_transformers=n_tr_layers,
46+
n_transformer_blocks=n_tr_blocks,
3247
long_skip=long_skip,
3348
stage_params=stage_params,
3449
)
3550

36-
x = [torch.rand([1, enc_channels[i], out_dims[i], out_dims[i]]) for i in range(5)]
51+
x = [torch.rand([1, enc_channels[i], out_dims[i], out_dims[i]]) for i in range(4)]
3752
out = decoder(*x)
3853

3954
out[-1].mean().backward()
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
import torch
3+
4+
from cellseg_models_pytorch.modules import Transformer2D
5+
6+
7+
@pytest.mark.parametrize("block_type", ["basic", "slice"])
8+
def test_transformer(block_type):
9+
in_channels = 64
10+
B = 4
11+
H = W = 32
12+
13+
x = torch.rand([B, in_channels, H, W])
14+
tr = Transformer2D(
15+
in_channels=in_channels,
16+
num_heads=4,
17+
head_dim=32,
18+
n_blocks=1,
19+
block_types=(block_type,),
20+
biases=(False,),
21+
dropouts=(0.0,),
22+
slice_size=4,
23+
)
24+
25+
out = tr(x)
26+
27+
assert out.shape == x.shape

0 commit comments

Comments
 (0)