Skip to content

Commit a1086c1

Browse files
committed
feat: add enc & dec feat return option to forward
1 parent 74814ab commit a1086c1

File tree

6 files changed

+136
-19
lines changed

6 files changed

+136
-19
lines changed

cellseg_models_pytorch/models/base/_base_model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,19 @@
1010

1111

1212
class BaseMultiTaskSegModel(nn.ModuleDict):
13+
def forward_features(
14+
self, x: torch.Tensor
15+
) -> Tuple[List[torch.Tensor], Dict[str, torch.Tensor]]:
16+
"""Forward pass for encoder, style and decoders.
17+
18+
NOTE: Returns both encoder and decoder features, not style.
19+
"""
20+
feats = self.forward_encoder(x)
21+
style = self.forward_style(feats[0])
22+
dec_feats = self.forward_dec_features(feats, style)
23+
24+
return feats, dec_feats
25+
1326
def forward_encoder(self, x: torch.Tensor) -> List[torch.Tensor]:
1427
"""Forward the model encoder."""
1528
self._check_input_shape(x)

cellseg_models_pytorch/models/cellpose/cellpose.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional, Tuple
1+
from typing import Dict, List, Optional, Tuple, Union
22

33
import torch
44
import torch.nn as nn
@@ -199,11 +199,42 @@ def __init__(
199199
if enc_freeze:
200200
self.freeze_encoder()
201201

202-
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
203-
"""Forward pass of Cellpose U-net."""
204-
feats = self.forward_encoder(x)
205-
style = self.forward_style(feats[0])
206-
dec_feats = self.forward_dec_features(feats, style)
202+
def forward(
203+
self,
204+
x: torch.Tensor,
205+
return_feats: bool = False,
206+
) -> Union[
207+
Dict[str, torch.Tensor],
208+
Tuple[
209+
List[torch.Tensor],
210+
Dict[str, torch.Tensor],
211+
Dict[str, torch.Tensor],
212+
],
213+
]:
214+
"""Forward pass of Cellpose U-net.
215+
216+
Parameters
217+
----------
218+
x : torch.Tensor
219+
Input image batch. Shape: (B, C, H, W).
220+
return_feats : bool, default=False
221+
If True, encoder, decoder, and head outputs will all be returned
222+
223+
Returns
224+
-------
225+
Union[
226+
Dict[str, torch.Tensor],
227+
Tuple[
228+
List[torch.Tensor],
229+
Dict[str, torch.Tensor],
230+
Dict[str, torch.Tensor],
231+
],
232+
]:
233+
Dictionary mapping of output names to outputs or if `return_feats == True`
234+
returns also the encoder features in a list, decoder features as a dict
235+
mapping decoder names to outputs and the final head outputs dict.
236+
"""
237+
feats, dec_feats = self.forward_features(x)
207238

208239
for decoder_name in self.heads.keys():
209240
for head_name in self.heads[decoder_name].keys():
@@ -212,6 +243,9 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
212243

213244
out = self.forward_heads(dec_feats)
214245

246+
if return_feats:
247+
return feats, dec_feats, out
248+
215249
return out
216250

217251

cellseg_models_pytorch/models/hovernet/hovernet.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional, Tuple
1+
from typing import Dict, List, Optional, Tuple, Union
22

33
import torch
44
import torch.nn as nn
@@ -192,14 +192,47 @@ def __init__(
192192
if enc_freeze:
193193
self.freeze_encoder()
194194

195-
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
196-
"""Forward pass of HoVer-Net."""
197-
feats = self.forward_encoder(x)
198-
style = self.forward_style(feats[0])
195+
def forward(
196+
self,
197+
x: torch.Tensor,
198+
return_feats: bool = False,
199+
) -> Union[
200+
Dict[str, torch.Tensor],
201+
Tuple[
202+
List[torch.Tensor],
203+
Dict[str, torch.Tensor],
204+
Dict[str, torch.Tensor],
205+
],
206+
]:
207+
"""Forward pass of HoVer-Net.
199208
200-
dec_feats = self.forward_dec_features(feats, style)
209+
Parameters
210+
----------
211+
x : torch.Tensor
212+
Input image batch. Shape: (B, C, H, W).
213+
return_feats : bool, default=False
214+
If True, encoder, decoder, and head outputs will all be returned
215+
216+
Returns
217+
-------
218+
Union[
219+
Dict[str, torch.Tensor],
220+
Tuple[
221+
List[torch.Tensor],
222+
Dict[str, torch.Tensor],
223+
Dict[str, torch.Tensor],
224+
],
225+
]:
226+
Dictionary mapping of output names to outputs or if `return_feats == True`
227+
returns also the encoder features in a list, decoder features as a dict
228+
mapping decoder names to outputs and the final head outputs dict.
229+
"""
230+
feats, dec_feats = self.forward_features(x)
201231
out = self.forward_heads(dec_feats)
202232

233+
if return_feats:
234+
return feats, dec_feats, out
235+
203236
return out
204237

205238

cellseg_models_pytorch/models/stardist/stardist.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional, Tuple
1+
from typing import Dict, List, Optional, Tuple, Union
22

33
import torch
44
import torch.nn as nn
@@ -207,12 +207,43 @@ def __init__(
207207
if enc_freeze:
208208
self.freeze_encoder()
209209

210-
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
211-
"""Forward pass of Stardist."""
212-
feats = self.forward_encoder(x)
213-
style = self.forward_style(feats[0])
210+
def forward(
211+
self,
212+
x: torch.Tensor,
213+
return_feats: bool = False,
214+
) -> Union[
215+
Dict[str, torch.Tensor],
216+
Tuple[
217+
List[torch.Tensor],
218+
Dict[str, torch.Tensor],
219+
Dict[str, torch.Tensor],
220+
],
221+
]:
222+
"""Forward pass of Stardist.
223+
224+
Parameters
225+
----------
226+
x : torch.Tensor
227+
Input image batch. Shape: (B, C, H, W).
228+
return_feats : bool, default=False
229+
If True, encoder, decoder, and head outputs will all be returned
230+
231+
Returns
232+
-------
233+
Union[
234+
Dict[str, torch.Tensor],
235+
Tuple[
236+
List[torch.Tensor],
237+
Dict[str, torch.Tensor],
238+
Dict[str, torch.Tensor],
239+
],
240+
]:
241+
Dictionary mapping of output names to outputs or if `return_feats == True`
242+
returns also the encoder features in a list, decoder features as a dict
243+
mapping decoder names to outputs and the final head outputs dict.
244+
"""
245+
feats, dec_feats = self.forward_features(x)
214246

215-
dec_feats = self.forward_dec_features(feats, style)
216247
# Extra convs after decoders
217248
for e in self.extra_convs.keys():
218249
for extra_conv in self.extra_convs[e].keys():
@@ -230,6 +261,9 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
230261

231262
out = self.forward_heads(dec_feats)
232263

264+
if return_feats:
265+
return feats, dec_feats, out
266+
233267
return out
234268

235269

cellseg_models_pytorch/postproc/functional/stardist/stardist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ def post_proc_stardist(
204204
nms are accelerated with `numba` and `scipy.spatial.KDtree`.
205205
206206
NOTE:
207-
This implementaiton of the stardist post-processing is actually nearly twice
207+
This implementaiton of the stardist post-processing is actually nearly 2x
208208
faster than the original version if `trim_bboxes` is set to True. The resulting
209209
segmentation is not an exact match but the differences are mostly neglible.
210210
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
## Features
2+
3+
- Add option to return encoder features, and decoder features along the outputs in the forward pass of any model.

0 commit comments

Comments
 (0)