Skip to content

Commit 4b9b2de

Browse files
committed
test: update model tests
1 parent 071fdd3 commit 4b9b2de

File tree

1 file changed

+92
-47
lines changed

1 file changed

+92
-47
lines changed

cellseg_models_pytorch/models/tests/test_models.py

Lines changed: 92 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,85 +1,129 @@
11
import pytest
22
import torch
33

4-
from cellseg_models_pytorch.models import (
5-
MultiTaskUnet,
6-
cellpose_base,
7-
cellpose_plus,
8-
hovernet_base,
9-
hovernet_plus,
10-
hovernet_small,
11-
hovernet_small_plus,
12-
omnipose_base,
13-
omnipose_plus,
14-
stardist_base,
15-
stardist_base_multiclass,
16-
stardist_plus,
17-
)
18-
19-
20-
@pytest.mark.parametrize(
21-
"model", [hovernet_base, hovernet_plus, hovernet_small_plus, hovernet_small]
22-
)
4+
from cellseg_models_pytorch.models import MultiTaskUnet, get_model
5+
6+
7+
@pytest.mark.parametrize("model_type", ["base", "plus", "small_plus", "small"])
8+
@pytest.mark.parametrize("style_channels", [None, 32])
9+
@pytest.mark.parametrize("enc_name", ["sam_vit_b", "sam_vit_h", "sam_vit_l"])
10+
def test_cellvit_fwdbwd(model_type, style_channels, enc_name):
11+
x = torch.rand([1, 3, 32, 32])
12+
model = get_model(
13+
name="cellvit",
14+
type=model_type,
15+
ntypes=3,
16+
ntissues=3,
17+
style_channels=style_channels,
18+
enc_name=enc_name,
19+
enc_pretrain=False,
20+
)
21+
model.freeze_encoder()
22+
23+
y = model(x)
24+
y["hovernet"].mean().backward()
25+
26+
assert y["type"].shape == x.shape
27+
28+
if "sem" in y.keys():
29+
assert y["sem"].shape == torch.Size([1, 3, 32, 32])
30+
31+
32+
@pytest.mark.parametrize("model_type", ["base", "plus", "small_plus", "small"])
2333
@pytest.mark.parametrize("style_channels", [None, 32])
24-
def test_hovernet_fwdbwd(model, style_channels):
34+
@pytest.mark.parametrize("add_stem_skip", [False, True])
35+
def test_hovernet_fwdbwd(model_type, style_channels, add_stem_skip):
2536
x = torch.rand([1, 3, 64, 64])
26-
m = model(
27-
type_classes=3,
28-
sem_classes=3,
37+
model = get_model(
38+
name="hovernet",
39+
type=model_type,
40+
ntypes=3,
41+
ntissues=3,
2942
style_channels=style_channels,
43+
add_stem_skip=add_stem_skip,
3044
)
31-
y = m(x)
45+
46+
y = model(x)
3247
y["hovernet"].mean().backward()
3348

3449
assert y["type"].shape == x.shape
3550

51+
if "sem" in y.keys():
52+
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
53+
3654

37-
@pytest.mark.parametrize(
38-
"model", [stardist_base, stardist_plus, stardist_base_multiclass]
39-
)
55+
@pytest.mark.parametrize("model_type", ["base", "plus"])
4056
@pytest.mark.parametrize("style_channels", [None, 32])
41-
def test_stardist_fwdbwd(model, style_channels):
57+
@pytest.mark.parametrize("add_stem_skip", [False, True])
58+
def test_stardist_fwdbwd(model_type, style_channels, add_stem_skip):
4259
n_rays = 3
43-
x = torch.rand([1, 3, 64, 64])
44-
m = model(
60+
x = torch.rand([1, 3, 32, 32])
61+
model = get_model(
62+
name="stardist",
63+
type=model_type,
4564
n_rays=n_rays,
46-
type_classes=3,
47-
sem_classes=3,
65+
ntypes=3,
66+
ntissues=3,
4867
style_channels=style_channels,
68+
add_stem_skip=add_stem_skip,
4969
)
50-
y = m(x)
70+
71+
y = model(x)
5172
y["stardist"].mean().backward()
5273

53-
assert y["stardist"].shape == torch.Size([1, n_rays, 64, 64])
74+
assert y["type"].shape == x.shape
75+
assert y["stardist"].shape == torch.Size([1, n_rays, 32, 32])
5476

77+
if "sem" in y.keys():
78+
assert y["sem"].shape == torch.Size([1, 3, 32, 32])
5579

56-
@pytest.mark.parametrize("model", [cellpose_base, cellpose_plus])
57-
def test_cellpose_fwdbwd(model):
80+
81+
@pytest.mark.parametrize("model_type", ["base", "plus"])
82+
@pytest.mark.parametrize("add_stem_skip", [False, True])
83+
def test_cellpose_fwdbwd(model_type, add_stem_skip):
5884
x = torch.rand([1, 3, 64, 64])
59-
m = model(type_classes=3, sem_classes=3)
60-
y = m(x)
85+
model = get_model(
86+
name="cellpose",
87+
type=model_type,
88+
ntypes=3,
89+
ntissues=3,
90+
add_stem_skip=add_stem_skip,
91+
)
92+
93+
y = model(x)
6194
y["cellpose"].mean().backward()
6295

96+
assert y["type"].shape == x.shape
97+
assert y["cellpose"].shape == torch.Size([1, 2, 64, 64])
98+
6399
if "sem" in y.keys():
64100
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
65101

66-
assert y["cellpose"].shape == torch.Size([1, 2, 64, 64])
67-
68102

69-
@pytest.mark.parametrize("model", [omnipose_base, omnipose_plus])
70-
def test_omnipose_fwdbwd(model):
103+
@pytest.mark.parametrize("model_type", ["base", "plus"])
104+
@pytest.mark.parametrize("add_stem_skip", [False, True])
105+
def test_cellpose_fwdbwd(model_type, add_stem_skip):
71106
x = torch.rand([1, 3, 64, 64])
72-
m = model(type_classes=3, sem_classes=3)
73-
y = m(x)
107+
model = get_model(
108+
name="omnipose",
109+
type=model_type,
110+
ntypes=3,
111+
ntissues=3,
112+
add_stem_skip=add_stem_skip,
113+
)
114+
115+
y = model(x)
74116
y["omnipose"].mean().backward()
75117

118+
assert y["type"].shape == x.shape
119+
assert y["omnipose"].shape == torch.Size([1, 2, 64, 64])
120+
76121
if "sem" in y.keys():
77122
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
78123

79-
assert y["omnipose"].shape == torch.Size([1, 2, 64, 64])
80-
81124

82-
def test_multitaskunet_fwdbwd():
125+
@pytest.mark.parametrize("add_stem_skip", [False, True])
126+
def test_multitaskunet_fwdbwd(add_stem_skip):
83127
x = torch.rand([1, 3, 64, 64])
84128
m = MultiTaskUnet(
85129
decoders=("sem",),
@@ -89,6 +133,7 @@ def test_multitaskunet_fwdbwd():
89133
out_channels={"sem": (128, 64, 32, 16)},
90134
long_skips={"sem": "unet"},
91135
dec_params={"sem": None},
136+
add_stem_skip=add_stem_skip,
92137
)
93138
y = m(x)
94139
y["sem"].mean().backward()

0 commit comments

Comments
 (0)