Skip to content

Commit 00c0440

Browse files
committed
feat(models): return featmaps at every dec stage
1 parent 01e463a commit 00c0440

File tree

9 files changed

+62
-48
lines changed

9 files changed

+62
-48
lines changed

cellseg_models_pytorch/decoders/decoder.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional, Tuple
1+
from typing import Dict, List, Optional, Tuple
22

33
import torch
44
import torch.nn as nn
@@ -77,13 +77,14 @@ def __init__(
7777

7878
self.out_channels = decoder_block.out_channels
7979

80-
def forward(
81-
self, *features: Tuple[torch.Tensor], style: torch.Tensor = None
82-
) -> torch.Tensor:
83-
"""Forward pass of the decoder."""
80+
def forward_features(
81+
self, features: Tuple[torch.Tensor], style: torch.Tensor = None
82+
) -> List[torch.Tensor]:
83+
"""Forward pass of the decoder. Returns all the decoder stage feats."""
8484
head = features[0]
8585
skips = features[1:]
8686
extra_skips = [head] if self.long_skip == "unet3p" else []
87+
ret_feats = []
8788

8889
x = head
8990
for _, decoder_stage in enumerate(self.values()):
@@ -96,4 +97,14 @@ def forward(
9697
elif self.long_skip == "unet3p":
9798
extra_skips.append(x)
9899

99-
return x
100+
ret_feats.append(x)
101+
102+
return ret_feats
103+
104+
def forward(
105+
self, *features: Tuple[torch.Tensor], style: torch.Tensor = None
106+
) -> torch.Tensor:
107+
"""Forward pass of the decoder."""
108+
dec_feats = self.forward_features(features, style)
109+
110+
return dec_feats

cellseg_models_pytorch/decoders/tests/test_decoders.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ def test_decoder_fwdbwd(long_skip, merge_policy):
3636
x = [torch.rand([1, enc_channels[i], out_dims[i], out_dims[i]]) for i in range(5)]
3737
out = decoder(*x)
3838

39-
out.mean().backward()
39+
out[-1].mean().backward()
4040

41-
assert out.shape[1] == decoder.out_channels
41+
assert out[-1].shape[1] == decoder.out_channels
4242

4343

4444
@pytest.mark.slow
@@ -163,6 +163,6 @@ def test_decoder_fwdbwd_all(
163163
x = [torch.rand([1, enc_channels[i], out_dims[i], out_dims[i]]) for i in range(5)]
164164
out = decoder(*x)
165165

166-
out.mean().backward()
166+
out[-1].mean().backward()
167167

168-
assert out.shape[1] == decoder.out_channels
168+
assert out[-1].shape[1] == decoder.out_channels

cellseg_models_pytorch/models/base/_base_model.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,38 @@
1010

1111

1212
class BaseMultiTaskSegModel(nn.ModuleDict):
13+
def forward_encoder(self, x: torch.Tensor) -> List[torch.Tensor]:
14+
"""Forward the model encoder."""
15+
self._check_input_shape(x)
16+
feats = self.encoder(x)
17+
18+
return feats
19+
20+
def forward_style(self, feat: torch.Tensor) -> torch.Tensor:
21+
"""Forward the style domain adaptation layer.
22+
23+
NOTE: returns None if style channels are not given at model init.
24+
"""
25+
style = None
26+
if self.make_style is not None:
27+
style = self.make_style(feat)
28+
29+
return style
30+
1331
def forward_dec_features(
1432
self, feats: List[torch.Tensor], style: torch.Tensor = None
15-
) -> Dict[str, torch.Tensor]:
16-
"""Forward pass of the decoders in a multi-task seg model."""
33+
) -> Dict[str, List[torch.Tensor]]:
34+
"""Forward pass of all the decoder features mappings in the model.
35+
36+
NOTE: returns all the features from diff decoder stages in a list.
37+
"""
1738
res = {}
1839
decoders = [k for k in self.keys() if "decoder" in k]
1940

2041
for dec in decoders:
21-
x = self[dec](*feats, style=style)
42+
featlist = self[dec](*feats, style=style)
2243
branch = dec.split("_")[0]
23-
res[branch] = x
44+
res[branch] = featlist
2445

2546
return res
2647

@@ -30,10 +51,9 @@ def forward_heads(
3051
"""Forward pass of the seg heads in a multi-task seg model."""
3152
res = {}
3253
heads = [k for k in self.keys() if "head" in k]
33-
3454
for head in heads:
3555
branch = head.split("_")[0]
36-
x = self[head](dec_feats[branch])
56+
x = self[head](dec_feats[branch][-1]) # the last decoder stage feat map
3757
res[branch] = x
3858

3959
return res

cellseg_models_pytorch/models/base/_multitask_unet.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,8 @@ def __init__(
125125

126126
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
127127
"""Forward pass of Multi-task U-net."""
128-
self._check_input_shape(x)
129-
130-
feats = self.encoder(x)
131-
132-
style = None
133-
if self.make_style is not None:
134-
style = self.make_style(feats[0])
135-
128+
feats = self.forward_encoder(x)
129+
style = self.forward_style(feats[0])
136130
dec_feats = self.forward_dec_features(feats, style)
137131

138132
for decoder_name in self.heads.keys():

cellseg_models_pytorch/models/cellpose/cellpose.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,8 @@ def __init__(
201201

202202
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
203203
"""Forward pass of Cellpose U-net."""
204-
self._check_input_shape(x)
205-
206-
feats = self.encoder(x)
207-
208-
style = None
209-
if self.make_style is not None:
210-
style = self.make_style(feats[0])
211-
204+
feats = self.forward_encoder(x)
205+
style = self.forward_style(feats[0])
212206
dec_feats = self.forward_dec_features(feats, style)
213207

214208
for decoder_name in self.heads.keys():

cellseg_models_pytorch/models/hovernet/hovernet.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -194,13 +194,8 @@ def __init__(
194194

195195
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
196196
"""Forward pass of HoVer-Net."""
197-
self._check_input_shape(x)
198-
199-
feats = self.encoder(x)
200-
201-
style = None
202-
if self.make_style is not None:
203-
style = self.make_style(feats[0])
197+
feats = self.forward_encoder(x)
198+
style = self.forward_style(feats[0])
204199

205200
dec_feats = self.forward_dec_features(feats, style)
206201
out = self.forward_heads(dec_feats)

cellseg_models_pytorch/models/stardist/stardist.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,20 +209,17 @@ def __init__(
209209

210210
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
211211
"""Forward pass of Stardist."""
212-
self._check_input_shape(x)
213-
214-
feats = self.encoder(x)
215-
216-
style = None
217-
if self.make_style is not None:
218-
style = self.make_style(feats[0])
212+
feats = self.forward_encoder(x)
213+
style = self.forward_style(feats[0])
219214

220215
dec_feats = self.forward_dec_features(feats, style)
221216
# Extra convs after decoders
222217
for e in self.extra_convs.keys():
223218
for extra_conv in self.extra_convs[e].keys():
224219
k = self.aux_key if extra_conv not in dec_feats.keys() else extra_conv
225-
dec_feats[extra_conv] = self[f"{extra_conv}_features"](dec_feats[k])
220+
dec_feats[extra_conv] = [
221+
self[f"{extra_conv}_features"](dec_feats[k][-1])
222+
] # use last decoder feat
226223

227224
# seg heads
228225
for decoder_name in self.heads.keys():

cellseg_models_pytorch/training/lit/lightning_experiment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(
7474
Params dict for the scheduler. Refer to torch lr_scheduler docs
7575
for the possible scheduler arguments.
7676
log_freq : int, default=100
77-
Return soft masks every every n batches for callbacks and logging.
77+
Return logs every n batches in logging callbacks.
7878
7979
Raises
8080
------
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
## Features
2+
3+
- Support to return all of the feature maps from each decoder stage.

0 commit comments

Comments
 (0)