Skip to content

Commit 372db65

Browse files
committed
feat(models): add param enc_out_indices to all model classes. Enables selecting the specific encoder features by their indices.
1 parent bc946c7 commit 372db65

File tree

5 files changed

+80
-80
lines changed

5 files changed

+80
-80
lines changed

cellseg_models_pytorch/models/cellpose/cellpose.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
enc_name: str = "resnet50",
3535
enc_pretrain: bool = True,
3636
enc_freeze: bool = False,
37+
enc_out_indices: Tuple[int, ...] = None,
3738
upsampling: str = "fixed-unpool",
3839
long_skip: str = "unet",
3940
merge_policy: str = "sum",
@@ -95,6 +96,9 @@ def __init__(
9596
Whether to use imagenet pretrained weights in the encoder.
9697
enc_freeze : bool, default=False
9798
Freeze encoder weights for training.
99+
enc_out_indices : Tuple[int, ...], optional
100+
Indices of the output features from the encoder. If None, indices are
101+
set to `range(depth)`
98102
upsampling : str, default="fixed-unpool"
99103
The upsampling method. One of: "fixed-unpool", "bilinear", "nearest",
100104
"conv_transpose", "bicubic"
@@ -147,8 +151,17 @@ def __init__(
147151
self.aux_key = self._check_decoder_args(decoders, ("omnipose", "cellpose"))
148152
self.inst_key = inst_key
149153
self._check_head_args(heads, decoders)
154+
155+
if enc_out_indices is None:
156+
enc_out_indices = tuple(range(depth))
157+
150158
self._check_depth(
151-
depth, {"out_channels": out_channels, "layer_depths": layer_depths}
159+
depth,
160+
{
161+
"out_channels": out_channels,
162+
"layer_depths": layer_depths,
163+
"enc_out_indices": enc_out_indices,
164+
},
152165
)
153166

154167
self.enc_freeze = enc_freeze
@@ -177,23 +190,9 @@ def __init__(
177190
for d in decoders
178191
}
179192

180-
# set encoder
181-
# self.encoder = Encoder(
182-
# enc_name,
183-
# depth=depth,
184-
# pretrained=enc_pretrain,
185-
# checkpoint_path=kwargs.get("checkpoint_path", None),
186-
# unettr_kwargs={ # Only used for transformer encoders
187-
# "convolution": convolution,
188-
# "activation": activation,
189-
# "normalization": normalization,
190-
# "attention": attention,
191-
# },
192-
# **encoder_params if encoder_params is not None else {},
193-
# )
194193
self.encoder = Encoder(
195194
timm_encoder_name=enc_name,
196-
timm_encoder_out_indices=tuple(range(depth)),
195+
timm_encoder_out_indices=enc_out_indices,
197196
pixel_decoder_out_channels=out_channels,
198197
timm_encoder_pretrained=enc_pretrain,
199198
timm_extra_kwargs=encoder_params,

cellseg_models_pytorch/models/cellvit/cellvit.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@ def __init__(
2828
heads: Dict[str, Dict[str, int]],
2929
inst_key: str = "inst",
3030
out_channels: Tuple[int, ...] = (512, 256, 128, 64),
31-
encoder_out_channels: Tuple[int, ...] = (512, 512, 256, 128),
3231
layer_depths: Tuple[int, ...] = (3, 2, 2, 2),
3332
style_channels: int = None,
3433
enc_name: str = "sam_vit_b",
3534
enc_pretrain: bool = True,
3635
enc_freeze: bool = False,
36+
enc_out_channels: Tuple[int, ...] = None,
37+
enc_out_indices: Tuple[int, ...] = None,
3738
long_skip: str = "unet",
3839
merge_policy: str = "cat",
3940
short_skip: str = "basic",
@@ -74,8 +75,6 @@ def __init__(
7475
inst_key : str, default="inst"
7576
The key for the model output that will be used in the instance
7677
segmentation post-processing pipeline as the binary segmentation result.
77-
encoder_out_channels : Tuple[int, ...], default=(512, 512, 256, 128)
78-
Out channels for each SAM-UnetTR encoder stage.
7978
out_channels : Tuple[int, ...], default=(256, 256, 64, 64)
8079
Out channels for each decoder stage.
8180
layer_depths : Tuple[int, ...], default=(4, 4, 4, 4)
@@ -88,6 +87,11 @@ def __init__(
8887
Whether to use imagenet pretrained weights in the encoder.
8988
enc_freeze : bool, default=False
9089
Freeze encoder weights for training.
90+
enc_out_channels : Tuple[int, ...], default=None
91+
Out channels for each SAM-UnetTR encoder stage.
92+
enc_out_indices : Tuple[int, ...], default=None
93+
Indices of the output features from the encoder. If None,
94+
indices are set to `range(len(layer_depths))`.
9195
long_skip : str, default="unet"
9296
long skip method to be used. One of: "unet", "unetpp", "unet3p",
9397
"unet3p-lite", None
@@ -133,9 +137,23 @@ def __init__(
133137
self.out_size = out_size
134138
self.aux_key = self._check_decoder_args(decoders, ("hovernet",))
135139
self.inst_key = inst_key
136-
self.depth = 4
140+
self.depth = len(layer_depths)
137141
self._check_head_args(heads, decoders)
138-
self._check_depth(self.depth, {"out_channels": out_channels})
142+
143+
if enc_out_indices is None:
144+
enc_out_indices = tuple(range(self.depth))
145+
146+
if enc_out_channels is None:
147+
enc_out_channels = out_channels
148+
149+
self._check_depth(
150+
self.depth,
151+
{
152+
"out_channels": out_channels,
153+
"enc_out_indices": enc_out_indices,
154+
"enc_out_channels": enc_out_channels,
155+
},
156+
)
139157

140158
self.add_stem_skip = add_stem_skip
141159
self.enc_freeze = enc_freeze
@@ -175,21 +193,11 @@ def __init__(
175193
f"Allowed encoder for CellVit: {allowed}"
176194
)
177195

178-
# set encoder
179-
# self.encoder = EncoderUnetTR(
180-
# backbone=build_sam_encoder(name=enc_name, pretrained=enc_pretrain),
181-
# out_channels=encoder_out_channels,
182-
# up_method="conv_transpose",
183-
# convolution=convolution,
184-
# activation=activation,
185-
# normalization=normalization,
186-
# attention=attention,
187-
# )
188-
196+
# set encoders
189197
self.encoder = Encoder(
190198
timm_encoder_name=enc_name,
191-
timm_encoder_out_indices=tuple(range(len(encoder_out_channels))),
192-
pixel_decoder_out_channels=encoder_out_channels,
199+
timm_encoder_out_indices=enc_out_indices,
200+
pixel_decoder_out_channels=enc_out_channels,
193201
timm_encoder_pretrained=enc_pretrain,
194202
timm_extra_kwargs=encoder_params,
195203
)

cellseg_models_pytorch/models/cppnet/cppnet.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
enc_name: str = "resnet50",
3636
enc_pretrain: bool = True,
3737
enc_freeze: bool = False,
38+
enc_out_indices: Tuple[int, ...] = None,
3839
upsampling: str = "conv_transpose",
3940
long_skip: str = "unet",
4041
merge_policy: str = "cat",
@@ -94,6 +95,9 @@ def __init__(
9495
Whether to use imagenet pretrained weights in the encoder.
9596
enc_freeze : bool, default=False
9697
Freeze encoder weights for training.
98+
enc_out_indices : Optional[Tuple[int]], default=None
99+
Indices of the encoder output features. If None, these are set to
100+
`range(depth)`.
97101
upsampling : str, default="fixed-unpool"
98102
The upsampling method to be used. One of: "fixed-unpool", "nearest",
99103
"bilinear", "bicubic", "conv_transpose"
@@ -150,7 +154,14 @@ def __init__(
150154
self.aux_key = "stardist_refined"
151155
self.inst_key = inst_key
152156
self._check_head_args(heads, decoders)
153-
self._check_depth(depth, {"out_channels": out_channels})
157+
158+
if enc_out_indices is None:
159+
enc_out_indices = tuple(range(depth))
160+
161+
self._check_depth(
162+
depth,
163+
{"out_channels": out_channels, "enc_out_indices": enc_out_indices},
164+
)
154165

155166
self.add_stem_skip = add_stem_skip
156167
self.enc_freeze = enc_freeze
@@ -179,22 +190,9 @@ def __init__(
179190
}
180191

181192
# set encoder
182-
# self.encoder = Encoder(
183-
# enc_name,
184-
# depth=depth,
185-
# pretrained=enc_pretrain,
186-
# checkpoint_path=kwargs.get("checkpoint_path", None),
187-
# unettr_kwargs={ # Only used for transformer encoders
188-
# "convolution": convolution,
189-
# "activation": activation,
190-
# "normalization": normalization,
191-
# "attention": attention,
192-
# },
193-
# **encoder_params if encoder_params is not None else {},
194-
# )
195193
self.encoder = Encoder(
196194
timm_encoder_name=enc_name,
197-
timm_encoder_out_indices=tuple(range(depth)),
195+
timm_encoder_out_indices=enc_out_indices,
198196
pixel_decoder_out_channels=out_channels,
199197
timm_encoder_pretrained=enc_pretrain,
200198
timm_extra_kwargs=encoder_params,

cellseg_models_pytorch/models/hovernet/hovernet.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def __init__(
3333
enc_name: str = "resnet50",
3434
enc_pretrain: bool = True,
3535
enc_freeze: bool = False,
36+
enc_out_indices: Tuple[int, ...] = None,
3637
upsampling: str = "fixed-unpool",
3738
long_skip: str = "unet",
3839
merge_policy: str = "sum",
@@ -91,6 +92,9 @@ def __init__(
9192
Whether to use imagenet pretrained weights in the encoder.
9293
enc_freeze : bool, default=False
9394
Freeze encoder weights for training.
95+
enc_out_indices : Tuple[int, ...], optional
96+
Indices of the encoder output features. If None, indices are set to
97+
`range(depth)`.
9498
upsampling : str, default="fixed-unpool"
9599
The upsampling method to be used. One of: "fixed-unpool", "nearest",
96100
"bilinear", "bicubic", "conv_transpose"
@@ -143,7 +147,14 @@ def __init__(
143147
self.aux_key = self._check_decoder_args(decoders, ("hovernet",))
144148
self.inst_key = inst_key
145149
self._check_head_args(heads, decoders)
146-
self._check_depth(depth, {"out_channels": out_channels})
150+
151+
if enc_out_indices is None:
152+
enc_out_indices = tuple(range(depth))
153+
154+
self._check_depth(
155+
depth,
156+
{"out_channels": out_channels, "enc_out_indices": enc_out_indices},
157+
)
147158

148159
self.add_stem_skip = add_stem_skip
149160
self.enc_freeze = enc_freeze
@@ -172,23 +183,9 @@ def __init__(
172183
}
173184

174185
# set encoder
175-
# self.encoder = Encoder(
176-
# enc_name,
177-
# depth=depth,
178-
# pretrained=enc_pretrain,
179-
# checkpoint_path=kwargs.get("checkpoint_path", None),
180-
# unettr_kwargs={ # Only used for transformer encoders
181-
# "convolution": convolution,
182-
# "activation": activation,
183-
# "normalization": normalization,
184-
# "attention": attention,
185-
# },
186-
# **encoder_params if encoder_params is not None else {},
187-
# )
188-
189186
self.encoder = Encoder(
190187
timm_encoder_name=enc_name,
191-
timm_encoder_out_indices=tuple(range(depth)),
188+
timm_encoder_out_indices=enc_out_indices,
192189
pixel_decoder_out_channels=out_channels,
193190
timm_encoder_pretrained=enc_pretrain,
194191
timm_extra_kwargs=encoder_params,

cellseg_models_pytorch/models/stardist/stardist.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
enc_name: str = "resnet50",
2929
enc_pretrain: bool = True,
3030
enc_freeze: bool = False,
31+
enc_out_indices: Tuple[int, ...] = None,
3132
upsampling: str = "fixed-unpool",
3233
long_skip: str = "unet",
3334
merge_policy: str = "cat",
@@ -91,6 +92,9 @@ def __init__(
9192
Whether to use imagenet pretrained weights in the encoder.
9293
enc_freeze : bool, default=False
9394
Freeze encoder weights for training.
95+
enc_out_indices : Tuple[int, ...], optional
96+
Indices of the encoder output features. If None, indices are set to
97+
`range(depth)`.
9498
upsampling : str, default="fixed-unpool"
9599
The upsampling method. One of: "fixed-unpool", "nearest", "bilinear",
96100
"bicubic", "conv_transpose"
@@ -147,7 +151,14 @@ def __init__(
147151
self.inst_key = inst_key
148152
self._check_head_args(extra_convs, decoders)
149153
self._check_head_args(heads, self._get_inner_keys(extra_convs))
150-
self._check_depth(depth, {"out_channels": out_channels})
154+
155+
if enc_out_indices is None:
156+
enc_out_indices = tuple(range(depth))
157+
158+
self._check_depth(
159+
depth,
160+
{"out_channels": out_channels, "enc_out_indices": enc_out_indices},
161+
)
151162

152163
self.enc_freeze = enc_freeze
153164
use_style = style_channels is not None
@@ -177,22 +188,9 @@ def __init__(
177188
}
178189

179190
# set encoder
180-
# self.encoder = Encoder(
181-
# enc_name,
182-
# depth=depth,
183-
# pretrained=enc_pretrain,
184-
# checkpoint_path=kwargs.get("checkpoint_path", None),
185-
# unettr_kwargs={ # Only used for transformer encoders, ignored otherwise
186-
# "convolution": convolution,
187-
# "activation": activation,
188-
# "normalization": normalization,
189-
# "attention": attention,
190-
# },
191-
# **encoder_params if encoder_params is not None else {},
192-
# )
193191
self.encoder = Encoder(
194192
timm_encoder_name=enc_name,
195-
timm_encoder_out_indices=tuple(range(depth)),
193+
timm_encoder_out_indices=enc_out_indices,
196194
pixel_decoder_out_channels=out_channels,
197195
timm_encoder_pretrained=enc_pretrain,
198196
timm_extra_kwargs=encoder_params,

0 commit comments

Comments
 (0)