Skip to content

Commit a438757

Browse files
committed
fix(models): modify seg models to use the new encoder API
1 parent 78fe346 commit a438757

File tree

8 files changed

+142
-76
lines changed

8 files changed

+142
-76
lines changed

cellseg_models_pytorch/models/base/_base_model.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@ def forward_features(
1717
1818
NOTE: Returns both encoder and decoder features, not style.
1919
"""
20-
feats = self.forward_encoder(x)
20+
enc_output, feats = self.forward_encoder(x)
2121
style = self.forward_style(feats[0])
2222
dec_feats = self.forward_dec_features(feats, style)
2323

2424
# final input resolution skip connection
2525
if self.add_stem_skip:
2626
dec_feats = self.forward_stem_skip(x, dec_feats)
2727

28-
return feats, dec_feats
28+
return enc_output, feats, dec_feats
2929

3030
def forward_stem_skip(
3131
self, x: torch.Tensor, dec_feats: Dict[str, torch.Tensor]
@@ -38,12 +38,14 @@ def forward_stem_skip(
3838

3939
return dec_feats
4040

41-
def forward_encoder(self, x: torch.Tensor) -> List[torch.Tensor]:
41+
def forward_encoder(
42+
self, x: torch.Tensor
43+
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
4244
"""Forward the model encoder."""
4345
self._check_input_shape(x)
44-
feats = self.encoder(x)
46+
output, feats = self.encoder(x)
4547

46-
return feats
48+
return output, feats
4749

4850
def forward_style(self, feat: torch.Tensor) -> torch.Tensor:
4951
"""Forward the style domain adaptation layer.

cellseg_models_pytorch/models/base/_multitask_unet.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ def __init__(
3737
out_size: Optional[int] = None,
3838
stem_params: Dict[str, Any] = None,
3939
encoder_params: Optional[Dict] = None,
40-
unettr_kwargs: Optional[Dict] = None,
4140
**kwargs,
4241
) -> None:
4342
"""Create a universal multi-task (2D) unet.
@@ -52,9 +51,9 @@ def __init__(
5251
Names of the decoder branches (has to match `decoders`) mapped to dicts
5352
of output name - number of output classes. E.g.
5453
{"cellpose": {"type": 4, "cellpose": 2}, "sem": {"sem": 5}}
55-
out_channels : Tuple[int, ...]
54+
out_channels : Dict[str, Dict[str, int]]
5655
Out channels for each decoder stage.
57-
long_skips : Dict[str, str]
56+
long_skips : Dict[str, Union[str, Tuple[str, ...]]]
5857
Dictionary mapping decoder branch-names to tuples defining the long skip
5958
method to be used inside each of the decoder stages.
6059
Allowed: "cross-attn", "unet", "unetpp", "unet3p", "unet3p-lite", None
@@ -118,13 +117,20 @@ def __init__(
118117
self.add_stem_skip = add_stem_skip
119118

120119
# set encoder
120+
# self.encoder = Encoder(
121+
# enc_name,
122+
# depth=depth,
123+
# pretrained=enc_pretrain,
124+
# checkpoint_path=kwargs.get("checkpoint_path", None),
125+
# unettr_kwargs=unettr_kwargs,
126+
# **encoder_params if encoder_params is not None else {},
127+
# )
121128
self.encoder = Encoder(
122-
enc_name,
123-
depth=depth,
124-
pretrained=enc_pretrain,
125-
checkpoint_path=kwargs.get("checkpoint_path", None),
126-
unettr_kwargs=unettr_kwargs,
127-
**encoder_params if encoder_params is not None else {},
129+
timm_encoder_name=enc_name,
130+
timm_encoder_out_indices=tuple(range(depth)),
131+
pixel_decoder_out_channels=tuple(out_channels.values())[0],
132+
timm_encoder_pretrained=enc_pretrain,
133+
timm_extra_kwargs=encoder_params,
128134
)
129135

130136
# get the reduction factors for the encoder
@@ -202,7 +208,7 @@ def from_yaml(cls, yaml_path: str) -> nn.Module:
202208

203209
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
204210
"""Forward pass of Multi-task U-net."""
205-
feats = self.forward_encoder(x)
211+
_, feats = self.forward_encoder(x)
206212
style = self.forward_style(feats[0])
207213
dec_feats = self.forward_dec_features(feats, style)
208214
out = self.forward_heads(dec_feats)

cellseg_models_pytorch/models/cellpose/cellpose.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -178,18 +178,25 @@ def __init__(
178178
}
179179

180180
# set encoder
181+
# self.encoder = Encoder(
182+
# enc_name,
183+
# depth=depth,
184+
# pretrained=enc_pretrain,
185+
# checkpoint_path=kwargs.get("checkpoint_path", None),
186+
# unettr_kwargs={ # Only used for transformer encoders
187+
# "convolution": convolution,
188+
# "activation": activation,
189+
# "normalization": normalization,
190+
# "attention": attention,
191+
# },
192+
# **encoder_params if encoder_params is not None else {},
193+
# )
181194
self.encoder = Encoder(
182-
enc_name,
183-
depth=depth,
184-
pretrained=enc_pretrain,
185-
checkpoint_path=kwargs.get("checkpoint_path", None),
186-
unettr_kwargs={ # Only used for transformer encoders
187-
"convolution": convolution,
188-
"activation": activation,
189-
"normalization": normalization,
190-
"attention": attention,
191-
},
192-
**encoder_params if encoder_params is not None else {},
195+
timm_encoder_name=enc_name,
196+
timm_encoder_out_indices=tuple(range(depth)),
197+
pixel_decoder_out_channels=out_channels,
198+
timm_encoder_pretrained=enc_pretrain,
199+
timm_extra_kwargs=encoder_params,
193200
)
194201

195202
# get the reduction factors for the encoder
@@ -286,7 +293,7 @@ def forward(
286293
returns also the encoder features in a list, decoder features as a dict
287294
mapping decoder names to outputs and the final head outputs dict.
288295
"""
289-
feats, dec_feats = self.forward_features(x)
296+
_, feats, dec_feats = self.forward_features(x)
290297
out = self.forward_heads(dec_feats)
291298

292299
if return_feats:

cellseg_models_pytorch/models/cellvit/cellvit.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55

66
from cellseg_models_pytorch.decoders import UnetDecoder
77
from cellseg_models_pytorch.decoders.long_skips import StemSkip
8-
from cellseg_models_pytorch.encoders import EncoderUnetTR
9-
from cellseg_models_pytorch.encoders.vit_det_SAM import build_sam_encoder
8+
from cellseg_models_pytorch.encoders import Encoder
109
from cellseg_models_pytorch.modules.misc_modules import StyleReshape
1110

1211
from ..base._base_model import BaseMultiTaskSegModel
@@ -47,6 +46,7 @@ def __init__(
4746
add_stem_skip: Optional[bool] = True,
4847
out_size: Optional[int] = None,
4948
skip_params: Optional[Dict] = None,
49+
encoder_params: Optional[Dict] = None,
5050
**kwargs,
5151
) -> None:
5252
"""Create a CellVit model.
@@ -163,21 +163,35 @@ def __init__(
163163
for d in decoders
164164
}
165165

166-
if enc_name not in ("sam_vit_b", "sam_vit_l", "sam_vit_h"):
166+
allowed = (
167+
"samvit_base_patch16",
168+
"samvit_base_patch16_224",
169+
"samvit_huge_patch16",
170+
"samvit_large_patch16",
171+
)
172+
if enc_name not in allowed:
167173
raise ValueError(
168174
f"Wrong encoder name. Got: {enc_name}. "
169-
"Allowed encoder for CellVit: sam_vit_b, sam_vit_l, sam_vit_h."
175+
f"Allowed encoder for CellVit: {allowed}"
170176
)
171177

172178
# set encoder
173-
self.encoder = EncoderUnetTR(
174-
backbone=build_sam_encoder(name=enc_name, pretrained=enc_pretrain),
175-
out_channels=encoder_out_channels,
176-
up_method="conv_transpose",
177-
convolution=convolution,
178-
activation=activation,
179-
normalization=normalization,
180-
attention=attention,
179+
# self.encoder = EncoderUnetTR(
180+
# backbone=build_sam_encoder(name=enc_name, pretrained=enc_pretrain),
181+
# out_channels=encoder_out_channels,
182+
# up_method="conv_transpose",
183+
# convolution=convolution,
184+
# activation=activation,
185+
# normalization=normalization,
186+
# attention=attention,
187+
# )
188+
189+
self.encoder = Encoder(
190+
timm_encoder_name=enc_name,
191+
timm_encoder_out_indices=tuple(range(len(encoder_out_channels))),
192+
pixel_decoder_out_channels=encoder_out_channels,
193+
timm_encoder_pretrained=enc_pretrain,
194+
timm_extra_kwargs=encoder_params,
181195
)
182196

183197
# get the reduction factors for the encoder
@@ -275,7 +289,7 @@ def forward(
275289
returns also the encoder features in a list, decoder features as a dict
276290
mapping decoder names to outputs and the final head outputs dict.
277291
"""
278-
feats, dec_feats = self.forward_features(x)
292+
_, feats, dec_feats = self.forward_features(x)
279293
out = self.forward_heads(dec_feats)
280294

281295
if return_feats:

cellseg_models_pytorch/models/cppnet/cppnet.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -179,18 +179,25 @@ def __init__(
179179
}
180180

181181
# set encoder
182+
# self.encoder = Encoder(
183+
# enc_name,
184+
# depth=depth,
185+
# pretrained=enc_pretrain,
186+
# checkpoint_path=kwargs.get("checkpoint_path", None),
187+
# unettr_kwargs={ # Only used for transformer encoders
188+
# "convolution": convolution,
189+
# "activation": activation,
190+
# "normalization": normalization,
191+
# "attention": attention,
192+
# },
193+
# **encoder_params if encoder_params is not None else {},
194+
# )
182195
self.encoder = Encoder(
183-
enc_name,
184-
depth=depth,
185-
pretrained=enc_pretrain,
186-
checkpoint_path=kwargs.get("checkpoint_path", None),
187-
unettr_kwargs={ # Only used for transformer encoders
188-
"convolution": convolution,
189-
"activation": activation,
190-
"normalization": normalization,
191-
"attention": attention,
192-
},
193-
**encoder_params if encoder_params is not None else {},
196+
timm_encoder_name=enc_name,
197+
timm_encoder_out_indices=tuple(range(depth)),
198+
pixel_decoder_out_channels=out_channels,
199+
timm_encoder_pretrained=enc_pretrain,
200+
timm_extra_kwargs=encoder_params,
194201
)
195202

196203
# get the reduction factors for the encoder
@@ -358,7 +365,7 @@ def forward(
358365
returns also the encoder features in a list, decoder features as a dict
359366
mapping decoder names to outputs and the final head outputs dict.
360367
"""
361-
feats, dec_feats = self.forward_features(x)
368+
_, feats, dec_feats = self.forward_features(x)
362369
out = self.forward_heads(dec_feats)
363370

364371
# cppnet specific

cellseg_models_pytorch/models/hovernet/hovernet.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -172,18 +172,26 @@ def __init__(
172172
}
173173

174174
# set encoder
175+
# self.encoder = Encoder(
176+
# enc_name,
177+
# depth=depth,
178+
# pretrained=enc_pretrain,
179+
# checkpoint_path=kwargs.get("checkpoint_path", None),
180+
# unettr_kwargs={ # Only used for transformer encoders
181+
# "convolution": convolution,
182+
# "activation": activation,
183+
# "normalization": normalization,
184+
# "attention": attention,
185+
# },
186+
# **encoder_params if encoder_params is not None else {},
187+
# )
188+
175189
self.encoder = Encoder(
176-
enc_name,
177-
depth=depth,
178-
pretrained=enc_pretrain,
179-
checkpoint_path=kwargs.get("checkpoint_path", None),
180-
unettr_kwargs={ # Only used for transformer encoders
181-
"convolution": convolution,
182-
"activation": activation,
183-
"normalization": normalization,
184-
"attention": attention,
185-
},
186-
**encoder_params if encoder_params is not None else {},
190+
timm_encoder_name=enc_name,
191+
timm_encoder_out_indices=tuple(range(depth)),
192+
pixel_decoder_out_channels=out_channels,
193+
timm_encoder_pretrained=enc_pretrain,
194+
timm_extra_kwargs=encoder_params,
187195
)
188196

189197
# get the reduction factors for the encoder
@@ -281,7 +289,7 @@ def forward(
281289
returns also the encoder features in a list, decoder features as a dict
282290
mapping decoder names to outputs and the final head outputs dict.
283291
"""
284-
feats, dec_feats = self.forward_features(x)
292+
_, feats, dec_feats = self.forward_features(x)
285293
out = self.forward_heads(dec_feats)
286294

287295
if return_feats:

cellseg_models_pytorch/models/stardist/stardist.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -177,18 +177,25 @@ def __init__(
177177
}
178178

179179
# set encoder
180+
# self.encoder = Encoder(
181+
# enc_name,
182+
# depth=depth,
183+
# pretrained=enc_pretrain,
184+
# checkpoint_path=kwargs.get("checkpoint_path", None),
185+
# unettr_kwargs={ # Only used for transformer encoders, ignored otherwise
186+
# "convolution": convolution,
187+
# "activation": activation,
188+
# "normalization": normalization,
189+
# "attention": attention,
190+
# },
191+
# **encoder_params if encoder_params is not None else {},
192+
# )
180193
self.encoder = Encoder(
181-
enc_name,
182-
depth=depth,
183-
pretrained=enc_pretrain,
184-
checkpoint_path=kwargs.get("checkpoint_path", None),
185-
unettr_kwargs={ # Only used for transformer encoders, ignored otherwise
186-
"convolution": convolution,
187-
"activation": activation,
188-
"normalization": normalization,
189-
"attention": attention,
190-
},
191-
**encoder_params if encoder_params is not None else {},
194+
timm_encoder_name=enc_name,
195+
timm_encoder_out_indices=tuple(range(depth)),
196+
pixel_decoder_out_channels=out_channels,
197+
timm_encoder_pretrained=enc_pretrain,
198+
timm_extra_kwargs=encoder_params,
192199
)
193200

194201
# get the reduction factors for the encoder
@@ -315,7 +322,7 @@ def forward(
315322
returns also the encoder features in a list, decoder features as a dict
316323
mapping decoder names to outputs and the final head outputs dict.
317324
"""
318-
feats, dec_feats = self.forward_features(x)
325+
_, feats, dec_feats = self.forward_features(x)
319326

320327
if return_feats:
321328
ret_dec_feats = dec_feats.copy()
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
## Breaking changes
2+
- Drop support for Python 3.9.
3+
4+
## Chore
5+
- Update the timm dependency to a version above 1.0.0.
6+
7+
## Features
8+
- Image encoders are now provided exclusively by timm models.
9+
10+
## Removed
11+
- The original SAM and DINOv2 image-encoder implementations have been removed from this repo. Both are now available as timm models.
12+
13+
## Examples
14+
- Updated example notebooks.
15+
- Added new example notebooks that use the UNI and CONCH encoders from the Hugging Face model hub.

0 commit comments

Comments
 (0)