4
4
from cellseg_models_pytorch .models import MultiTaskUnet , get_model
5
5
6
6
7
+ @pytest .mark .parametrize ("enc_name" , ["resnet18" , "samvit_base_patch16" ])
7
8
@pytest .mark .parametrize ("model_type" , ["base" , "plus" ])
8
9
@pytest .mark .parametrize ("style_channels" , [None , 32 ])
9
10
@pytest .mark .parametrize ("add_stem_skip" , [False , True ])
10
- def test_cppnet_fwdbwd (model_type , style_channels , add_stem_skip ):
11
+ def test_cppnet_fwdbwd (enc_name , model_type , style_channels , add_stem_skip ):
11
12
n_rays = 3
12
- x = torch .rand ([1 , 3 , 32 , 32 ])
13
+ x = torch .rand ([1 , 3 , 64 , 64 ])
13
14
model = get_model (
14
15
name = "cppnet" ,
15
16
type = model_type ,
17
+ enc_name = enc_name ,
16
18
n_rays = n_rays ,
17
19
ntypes = 3 ,
18
20
ntissues = 3 ,
19
21
style_channels = style_channels ,
20
22
add_stem_skip = add_stem_skip ,
23
+ enc_pretrain = False ,
21
24
)
22
25
23
26
y = model (x )
24
27
y ["stardist_refined" ].mean ().backward ()
25
28
26
29
assert y ["type" ].shape == x .shape
27
- assert y ["stardist_refined" ].shape == torch .Size ([1 , n_rays , 32 , 32 ])
30
+ assert y ["stardist_refined" ].shape == torch .Size ([1 , n_rays , 64 , 64 ])
28
31
29
32
if "sem" in y .keys ():
30
- assert y ["sem" ].shape == torch .Size ([1 , 3 , 32 , 32 ])
33
+ assert y ["sem" ].shape == torch .Size ([1 , 3 , 64 , 64 ])
31
34
32
35
36
+ @pytest .mark .parametrize (
37
+ "enc_name" ,
38
+ [
39
+ "samvit_base_patch16" ,
40
+ "samvit_base_patch16_224" ,
41
+ "samvit_huge_patch16" ,
42
+ "samvit_large_patch16" ,
43
+ ],
44
+ )
33
45
@pytest .mark .parametrize ("model_type" , ["base" , "plus" , "small_plus" , "small" ])
34
46
@pytest .mark .parametrize ("style_channels" , [None , 32 ])
35
- @pytest .mark .parametrize ("enc_name" , ["sam_vit_b" , "sam_vit_h" , "sam_vit_l" ])
36
- def test_cellvit_fwdbwd (model_type , style_channels , enc_name ):
47
+ def test_cellvit_fwdbwd (enc_name , model_type , style_channels ):
37
48
x = torch .rand ([1 , 3 , 32 , 32 ])
38
49
model = get_model (
39
50
name = "cellvit" ,
40
51
type = model_type ,
52
+ enc_name = enc_name ,
41
53
ntypes = 3 ,
42
54
ntissues = 3 ,
43
55
style_channels = style_channels ,
44
- enc_name = enc_name ,
45
56
enc_pretrain = False ,
46
57
)
47
58
model .freeze_encoder ()
@@ -55,18 +66,21 @@ def test_cellvit_fwdbwd(model_type, style_channels, enc_name):
55
66
assert y ["sem" ].shape == torch .Size ([1 , 3 , 32 , 32 ])
56
67
57
68
69
+ @pytest .mark .parametrize ("enc_name" , ["resnet18" , "samvit_base_patch16" ])
58
70
@pytest .mark .parametrize ("model_type" , ["base" , "plus" , "small_plus" , "small" ])
59
71
@pytest .mark .parametrize ("style_channels" , [None , 32 ])
60
72
@pytest .mark .parametrize ("add_stem_skip" , [False , True ])
61
- def test_hovernet_fwdbwd (model_type , style_channels , add_stem_skip ):
73
+ def test_hovernet_fwdbwd (enc_name , model_type , style_channels , add_stem_skip ):
62
74
x = torch .rand ([1 , 3 , 64 , 64 ])
63
75
model = get_model (
64
76
name = "hovernet" ,
65
77
type = model_type ,
78
+ enc_name = enc_name ,
66
79
ntypes = 3 ,
67
80
ntissues = 3 ,
68
81
style_channels = style_channels ,
69
82
add_stem_skip = add_stem_skip ,
83
+ enc_pretrain = False ,
70
84
)
71
85
72
86
y = model (x )
@@ -78,42 +92,48 @@ def test_hovernet_fwdbwd(model_type, style_channels, add_stem_skip):
78
92
assert y ["sem" ].shape == torch .Size ([1 , 3 , 64 , 64 ])
79
93
80
94
95
+ @pytest .mark .parametrize ("enc_name" , ["resnet18" , "samvit_base_patch16" ])
81
96
@pytest .mark .parametrize ("model_type" , ["base" , "plus" ])
82
97
@pytest .mark .parametrize ("style_channels" , [None , 32 ])
83
98
@pytest .mark .parametrize ("add_stem_skip" , [False , True ])
84
- def test_stardist_fwdbwd (model_type , style_channels , add_stem_skip ):
99
+ def test_stardist_fwdbwd (enc_name , model_type , style_channels , add_stem_skip ):
85
100
n_rays = 3
86
- x = torch .rand ([1 , 3 , 32 , 32 ])
101
+ x = torch .rand ([1 , 3 , 64 , 64 ])
87
102
model = get_model (
88
103
name = "stardist" ,
89
104
type = model_type ,
90
105
n_rays = n_rays ,
106
+ enc_name = enc_name ,
91
107
ntypes = 3 ,
92
108
ntissues = 3 ,
93
109
style_channels = style_channels ,
94
110
add_stem_skip = add_stem_skip ,
111
+ enc_pretrain = False ,
95
112
)
96
113
97
114
y = model (x )
98
115
y ["stardist" ].mean ().backward ()
99
116
100
117
assert y ["type" ].shape == x .shape
101
- assert y ["stardist" ].shape == torch .Size ([1 , n_rays , 32 , 32 ])
118
+ assert y ["stardist" ].shape == torch .Size ([1 , n_rays , 64 , 64 ])
102
119
103
120
if "sem" in y .keys ():
104
- assert y ["sem" ].shape == torch .Size ([1 , 3 , 32 , 32 ])
121
+ assert y ["sem" ].shape == torch .Size ([1 , 3 , 64 , 64 ])
105
122
106
123
124
+ @pytest .mark .parametrize ("enc_name" , ["resnet18" , "samvit_base_patch16" ])
107
125
@pytest .mark .parametrize ("model_type" , ["base" , "plus" ])
108
126
@pytest .mark .parametrize ("add_stem_skip" , [False , True ])
109
- def test_cellpose_fwdbwd (model_type , add_stem_skip ):
127
+ def test_cellpose_fwdbwd (enc_name , model_type , add_stem_skip ):
110
128
x = torch .rand ([1 , 3 , 64 , 64 ])
111
129
model = get_model (
112
130
name = "cellpose" ,
113
131
type = model_type ,
132
+ enc_name = enc_name ,
114
133
ntypes = 3 ,
115
134
ntissues = 3 ,
116
135
add_stem_skip = add_stem_skip ,
136
+ enc_pretrain = False ,
117
137
)
118
138
119
139
y = model (x )
@@ -126,16 +146,19 @@ def test_cellpose_fwdbwd(model_type, add_stem_skip):
126
146
assert y ["sem" ].shape == torch .Size ([1 , 3 , 64 , 64 ])
127
147
128
148
149
+ @pytest .mark .parametrize ("enc_name" , ["resnet18" , "samvit_base_patch16" ])
129
150
@pytest .mark .parametrize ("model_type" , ["base" , "plus" ])
130
151
@pytest .mark .parametrize ("add_stem_skip" , [False , True ])
131
- def test_cellpose_fwdbwd (model_type , add_stem_skip ):
152
+ def test_cellpose_fwdbwd (enc_name , model_type , add_stem_skip ):
132
153
x = torch .rand ([1 , 3 , 64 , 64 ])
133
154
model = get_model (
134
155
name = "omnipose" ,
135
156
type = model_type ,
157
+ enc_name = enc_name ,
136
158
ntypes = 3 ,
137
159
ntissues = 3 ,
138
160
add_stem_skip = add_stem_skip ,
161
+ enc_pretrain = False ,
139
162
)
140
163
141
164
y = model (x )
@@ -148,8 +171,9 @@ def test_cellpose_fwdbwd(model_type, add_stem_skip):
148
171
assert y ["sem" ].shape == torch .Size ([1 , 3 , 64 , 64 ])
149
172
150
173
174
+ @pytest .mark .parametrize ("enc_name" , ["resnet18" , "samvit_base_patch16" ])
151
175
@pytest .mark .parametrize ("add_stem_skip" , [False , True ])
152
- def test_multitaskunet_fwdbwd (add_stem_skip ):
176
+ def test_multitaskunet_fwdbwd (enc_name , add_stem_skip ):
153
177
x = torch .rand ([1 , 3 , 64 , 64 ])
154
178
m = MultiTaskUnet (
155
179
decoders = ("sem" ,),
@@ -160,6 +184,8 @@ def test_multitaskunet_fwdbwd(add_stem_skip):
160
184
long_skips = {"sem" : "unet" },
161
185
dec_params = {"sem" : None },
162
186
add_stem_skip = add_stem_skip ,
187
+ enc_name = enc_name ,
188
+ enc_pretrain = False ,
163
189
)
164
190
y = m (x )
165
191
y ["sem" ].mean ().backward ()
0 commit comments