1
1
import pytest
2
2
import torch
3
3
4
- from cellseg_models_pytorch .models import (
5
- MultiTaskUnet ,
6
- cellpose_base ,
7
- cellpose_plus ,
8
- hovernet_base ,
9
- hovernet_plus ,
10
- hovernet_small ,
11
- hovernet_small_plus ,
12
- omnipose_base ,
13
- omnipose_plus ,
14
- stardist_base ,
15
- stardist_base_multiclass ,
16
- stardist_plus ,
17
- )
18
-
19
-
20
- @pytest .mark .parametrize (
21
- "model" , [hovernet_base , hovernet_plus , hovernet_small_plus , hovernet_small ]
22
- )
4
+ from cellseg_models_pytorch .models import MultiTaskUnet , get_model
5
+
6
+
7
+ @pytest .mark .parametrize ("model_type" , ["base" , "plus" , "small_plus" , "small" ])
8
+ @pytest .mark .parametrize ("style_channels" , [None , 32 ])
9
+ @pytest .mark .parametrize ("enc_name" , ["sam_vit_b" , "sam_vit_h" , "sam_vit_l" ])
10
+ def test_cellvit_fwdbwd (model_type , style_channels , enc_name ):
11
+ x = torch .rand ([1 , 3 , 32 , 32 ])
12
+ model = get_model (
13
+ name = "cellvit" ,
14
+ type = model_type ,
15
+ ntypes = 3 ,
16
+ ntissues = 3 ,
17
+ style_channels = style_channels ,
18
+ enc_name = enc_name ,
19
+ enc_pretrain = False ,
20
+ )
21
+ model .freeze_encoder ()
22
+
23
+ y = model (x )
24
+ y ["hovernet" ].mean ().backward ()
25
+
26
+ assert y ["type" ].shape == x .shape
27
+
28
+ if "sem" in y .keys ():
29
+ assert y ["sem" ].shape == torch .Size ([1 , 3 , 32 , 32 ])
30
+
31
+
32
+ @pytest .mark .parametrize ("model_type" , ["base" , "plus" , "small_plus" , "small" ])
23
33
@pytest .mark .parametrize ("style_channels" , [None , 32 ])
24
- def test_hovernet_fwdbwd (model , style_channels ):
34
+ @pytest .mark .parametrize ("add_stem_skip" , [False , True ])
35
+ def test_hovernet_fwdbwd (model_type , style_channels , add_stem_skip ):
25
36
x = torch .rand ([1 , 3 , 64 , 64 ])
26
- m = model (
27
- type_classes = 3 ,
28
- sem_classes = 3 ,
37
+ model = get_model (
38
+ name = "hovernet" ,
39
+ type = model_type ,
40
+ ntypes = 3 ,
41
+ ntissues = 3 ,
29
42
style_channels = style_channels ,
43
+ add_stem_skip = add_stem_skip ,
30
44
)
31
- y = m (x )
45
+
46
+ y = model (x )
32
47
y ["hovernet" ].mean ().backward ()
33
48
34
49
assert y ["type" ].shape == x .shape
35
50
51
+ if "sem" in y .keys ():
52
+ assert y ["sem" ].shape == torch .Size ([1 , 3 , 64 , 64 ])
53
+
36
54
37
- @pytest .mark .parametrize (
38
- "model" , [stardist_base , stardist_plus , stardist_base_multiclass ]
39
- )
55
+ @pytest .mark .parametrize ("model_type" , ["base" , "plus" ])
40
56
@pytest .mark .parametrize ("style_channels" , [None , 32 ])
41
- def test_stardist_fwdbwd (model , style_channels ):
57
+ @pytest .mark .parametrize ("add_stem_skip" , [False , True ])
58
+ def test_stardist_fwdbwd (model_type , style_channels , add_stem_skip ):
42
59
n_rays = 3
43
- x = torch .rand ([1 , 3 , 64 , 64 ])
44
- m = model (
60
+ x = torch .rand ([1 , 3 , 32 , 32 ])
61
+ model = get_model (
62
+ name = "stardist" ,
63
+ type = model_type ,
45
64
n_rays = n_rays ,
46
- type_classes = 3 ,
47
- sem_classes = 3 ,
65
+ ntypes = 3 ,
66
+ ntissues = 3 ,
48
67
style_channels = style_channels ,
68
+ add_stem_skip = add_stem_skip ,
49
69
)
50
- y = m (x )
70
+
71
+ y = model (x )
51
72
y ["stardist" ].mean ().backward ()
52
73
53
- assert y ["stardist" ].shape == torch .Size ([1 , n_rays , 64 , 64 ])
74
+ assert y ["type" ].shape == x .shape
75
+ assert y ["stardist" ].shape == torch .Size ([1 , n_rays , 32 , 32 ])
54
76
77
+ if "sem" in y .keys ():
78
+ assert y ["sem" ].shape == torch .Size ([1 , 3 , 32 , 32 ])
55
79
56
- @pytest .mark .parametrize ("model" , [cellpose_base , cellpose_plus ])
57
- def test_cellpose_fwdbwd (model ):
80
+
81
+ @pytest .mark .parametrize ("model_type" , ["base" , "plus" ])
82
+ @pytest .mark .parametrize ("add_stem_skip" , [False , True ])
83
+ def test_cellpose_fwdbwd (model_type , add_stem_skip ):
58
84
x = torch .rand ([1 , 3 , 64 , 64 ])
59
- m = model (type_classes = 3 , sem_classes = 3 )
60
- y = m (x )
85
+ model = get_model (
86
+ name = "cellpose" ,
87
+ type = model_type ,
88
+ ntypes = 3 ,
89
+ ntissues = 3 ,
90
+ add_stem_skip = add_stem_skip ,
91
+ )
92
+
93
+ y = model (x )
61
94
y ["cellpose" ].mean ().backward ()
62
95
96
+ assert y ["type" ].shape == x .shape
97
+ assert y ["cellpose" ].shape == torch .Size ([1 , 2 , 64 , 64 ])
98
+
63
99
if "sem" in y .keys ():
64
100
assert y ["sem" ].shape == torch .Size ([1 , 3 , 64 , 64 ])
65
101
66
- assert y ["cellpose" ].shape == torch .Size ([1 , 2 , 64 , 64 ])
67
-
68
102
69
- @pytest .mark .parametrize ("model" , [omnipose_base , omnipose_plus ])
70
- def test_omnipose_fwdbwd (model ):
103
+ @pytest .mark .parametrize ("model_type" , ["base" , "plus" ])
104
+ @pytest .mark .parametrize ("add_stem_skip" , [False , True ])
105
+ def test_cellpose_fwdbwd (model_type , add_stem_skip ):
71
106
x = torch .rand ([1 , 3 , 64 , 64 ])
72
- m = model (type_classes = 3 , sem_classes = 3 )
73
- y = m (x )
107
+ model = get_model (
108
+ name = "omnipose" ,
109
+ type = model_type ,
110
+ ntypes = 3 ,
111
+ ntissues = 3 ,
112
+ add_stem_skip = add_stem_skip ,
113
+ )
114
+
115
+ y = model (x )
74
116
y ["omnipose" ].mean ().backward ()
75
117
118
+ assert y ["type" ].shape == x .shape
119
+ assert y ["omnipose" ].shape == torch .Size ([1 , 2 , 64 , 64 ])
120
+
76
121
if "sem" in y .keys ():
77
122
assert y ["sem" ].shape == torch .Size ([1 , 3 , 64 , 64 ])
78
123
79
- assert y ["omnipose" ].shape == torch .Size ([1 , 2 , 64 , 64 ])
80
-
81
124
82
- def test_multitaskunet_fwdbwd ():
125
+ @pytest .mark .parametrize ("add_stem_skip" , [False , True ])
126
+ def test_multitaskunet_fwdbwd (add_stem_skip ):
83
127
x = torch .rand ([1 , 3 , 64 , 64 ])
84
128
m = MultiTaskUnet (
85
129
decoders = ("sem" ,),
@@ -89,6 +133,7 @@ def test_multitaskunet_fwdbwd():
89
133
out_channels = {"sem" : (128 , 64 , 32 , 16 )},
90
134
long_skips = {"sem" : "unet" },
91
135
dec_params = {"sem" : None },
136
+ add_stem_skip = add_stem_skip ,
92
137
)
93
138
y = m (x )
94
139
y ["sem" ].mean ().backward ()
0 commit comments