Skip to content

Commit 1c71c39

Browse files
committed
test(models): update the test suite for the models module
1 parent a438757 commit 1c71c39

File tree

1 file changed

+41
-15
lines changed

1 file changed

+41
-15
lines changed

cellseg_models_pytorch/models/tests/test_models.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,44 +4,55 @@
44
from cellseg_models_pytorch.models import MultiTaskUnet, get_model
55

66

7+
@pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
78
@pytest.mark.parametrize("model_type", ["base", "plus"])
89
@pytest.mark.parametrize("style_channels", [None, 32])
910
@pytest.mark.parametrize("add_stem_skip", [False, True])
10-
def test_cppnet_fwdbwd(model_type, style_channels, add_stem_skip):
11+
def test_cppnet_fwdbwd(enc_name, model_type, style_channels, add_stem_skip):
1112
n_rays = 3
12-
x = torch.rand([1, 3, 32, 32])
13+
x = torch.rand([1, 3, 64, 64])
1314
model = get_model(
1415
name="cppnet",
1516
type=model_type,
17+
enc_name=enc_name,
1618
n_rays=n_rays,
1719
ntypes=3,
1820
ntissues=3,
1921
style_channels=style_channels,
2022
add_stem_skip=add_stem_skip,
23+
enc_pretrain=False,
2124
)
2225

2326
y = model(x)
2427
y["stardist_refined"].mean().backward()
2528

2629
assert y["type"].shape == x.shape
27-
assert y["stardist_refined"].shape == torch.Size([1, n_rays, 32, 32])
30+
assert y["stardist_refined"].shape == torch.Size([1, n_rays, 64, 64])
2831

2932
if "sem" in y.keys():
30-
assert y["sem"].shape == torch.Size([1, 3, 32, 32])
33+
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
3134

3235

36+
@pytest.mark.parametrize(
37+
"enc_name",
38+
[
39+
"samvit_base_patch16",
40+
"samvit_base_patch16_224",
41+
"samvit_huge_patch16",
42+
"samvit_large_patch16",
43+
],
44+
)
3345
@pytest.mark.parametrize("model_type", ["base", "plus", "small_plus", "small"])
3446
@pytest.mark.parametrize("style_channels", [None, 32])
35-
@pytest.mark.parametrize("enc_name", ["sam_vit_b", "sam_vit_h", "sam_vit_l"])
36-
def test_cellvit_fwdbwd(model_type, style_channels, enc_name):
47+
def test_cellvit_fwdbwd(enc_name, model_type, style_channels):
3748
x = torch.rand([1, 3, 32, 32])
3849
model = get_model(
3950
name="cellvit",
4051
type=model_type,
52+
enc_name=enc_name,
4153
ntypes=3,
4254
ntissues=3,
4355
style_channels=style_channels,
44-
enc_name=enc_name,
4556
enc_pretrain=False,
4657
)
4758
model.freeze_encoder()
@@ -55,18 +66,21 @@ def test_cellvit_fwdbwd(model_type, style_channels, enc_name):
5566
assert y["sem"].shape == torch.Size([1, 3, 32, 32])
5667

5768

69+
@pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
5870
@pytest.mark.parametrize("model_type", ["base", "plus", "small_plus", "small"])
5971
@pytest.mark.parametrize("style_channels", [None, 32])
6072
@pytest.mark.parametrize("add_stem_skip", [False, True])
61-
def test_hovernet_fwdbwd(model_type, style_channels, add_stem_skip):
73+
def test_hovernet_fwdbwd(enc_name, model_type, style_channels, add_stem_skip):
6274
x = torch.rand([1, 3, 64, 64])
6375
model = get_model(
6476
name="hovernet",
6577
type=model_type,
78+
enc_name=enc_name,
6679
ntypes=3,
6780
ntissues=3,
6881
style_channels=style_channels,
6982
add_stem_skip=add_stem_skip,
83+
enc_pretrain=False,
7084
)
7185

7286
y = model(x)
@@ -78,42 +92,48 @@ def test_hovernet_fwdbwd(model_type, style_channels, add_stem_skip):
7892
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
7993

8094

95+
@pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
8196
@pytest.mark.parametrize("model_type", ["base", "plus"])
8297
@pytest.mark.parametrize("style_channels", [None, 32])
8398
@pytest.mark.parametrize("add_stem_skip", [False, True])
84-
def test_stardist_fwdbwd(model_type, style_channels, add_stem_skip):
99+
def test_stardist_fwdbwd(enc_name, model_type, style_channels, add_stem_skip):
85100
n_rays = 3
86-
x = torch.rand([1, 3, 32, 32])
101+
x = torch.rand([1, 3, 64, 64])
87102
model = get_model(
88103
name="stardist",
89104
type=model_type,
90105
n_rays=n_rays,
106+
enc_name=enc_name,
91107
ntypes=3,
92108
ntissues=3,
93109
style_channels=style_channels,
94110
add_stem_skip=add_stem_skip,
111+
enc_pretrain=False,
95112
)
96113

97114
y = model(x)
98115
y["stardist"].mean().backward()
99116

100117
assert y["type"].shape == x.shape
101-
assert y["stardist"].shape == torch.Size([1, n_rays, 32, 32])
118+
assert y["stardist"].shape == torch.Size([1, n_rays, 64, 64])
102119

103120
if "sem" in y.keys():
104-
assert y["sem"].shape == torch.Size([1, 3, 32, 32])
121+
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
105122

106123

124+
@pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
107125
@pytest.mark.parametrize("model_type", ["base", "plus"])
108126
@pytest.mark.parametrize("add_stem_skip", [False, True])
109-
def test_cellpose_fwdbwd(model_type, add_stem_skip):
127+
def test_cellpose_fwdbwd(enc_name, model_type, add_stem_skip):
110128
x = torch.rand([1, 3, 64, 64])
111129
model = get_model(
112130
name="cellpose",
113131
type=model_type,
132+
enc_name=enc_name,
114133
ntypes=3,
115134
ntissues=3,
116135
add_stem_skip=add_stem_skip,
136+
enc_pretrain=False,
117137
)
118138

119139
y = model(x)
@@ -126,16 +146,19 @@ def test_cellpose_fwdbwd(model_type, add_stem_skip):
126146
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
127147

128148

149+
@pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
129150
@pytest.mark.parametrize("model_type", ["base", "plus"])
130151
@pytest.mark.parametrize("add_stem_skip", [False, True])
131-
def test_cellpose_fwdbwd(model_type, add_stem_skip):
152+
def test_cellpose_fwdbwd(enc_name, model_type, add_stem_skip):
132153
x = torch.rand([1, 3, 64, 64])
133154
model = get_model(
134155
name="omnipose",
135156
type=model_type,
157+
enc_name=enc_name,
136158
ntypes=3,
137159
ntissues=3,
138160
add_stem_skip=add_stem_skip,
161+
enc_pretrain=False,
139162
)
140163

141164
y = model(x)
@@ -148,8 +171,9 @@ def test_cellpose_fwdbwd(model_type, add_stem_skip):
148171
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
149172

150173

174+
@pytest.mark.parametrize("enc_name", ["resnet18", "samvit_base_patch16"])
151175
@pytest.mark.parametrize("add_stem_skip", [False, True])
152-
def test_multitaskunet_fwdbwd(add_stem_skip):
176+
def test_multitaskunet_fwdbwd(enc_name, add_stem_skip):
153177
x = torch.rand([1, 3, 64, 64])
154178
m = MultiTaskUnet(
155179
decoders=("sem",),
@@ -160,6 +184,8 @@ def test_multitaskunet_fwdbwd(add_stem_skip):
160184
long_skips={"sem": "unet"},
161185
dec_params={"sem": None},
162186
add_stem_skip=add_stem_skip,
187+
enc_name=enc_name,
188+
enc_pretrain=False,
163189
)
164190
y = m(x)
165191
y["sem"].mean().backward()

0 commit comments

Comments
 (0)