Skip to content

Commit 8363785

Browse files
committed
fix: use dataclasses for decoder outputs to ease the downstream processing
1 parent bc5fd06 commit 8363785

File tree

1 file changed

+194
-33
lines changed

1 file changed

+194
-33
lines changed

cellseg_models_pytorch/decoders/multitask_decoder.py

Lines changed: 194 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from dataclasses import dataclass, field
12
from itertools import chain
2-
from typing import Any, Dict, List, Tuple
3+
from typing import Any, Dict, List, Optional, Tuple
34

45
import torch
56
import torch.nn as nn
@@ -15,10 +16,29 @@
1516
from cellseg_models_pytorch.models.base._seg_head import SegHead
1617
from cellseg_models_pytorch.modules.misc_modules import StyleReshape
1718

18-
ALLOWED_HEADS = [
19-
"inst",
19+
__all__ = [
20+
"MultiTaskDecoder",
21+
"DecoderSoftOutput",
22+
"SoftInstanceOutput",
23+
"SoftSemanticOutput",
24+
]
25+
26+
27+
INST_SEG_PREFIX = [
28+
"nuc",
29+
"cyto",
30+
]
31+
32+
SEM_SEG_PREFIX = [
33+
"tissue",
34+
]
35+
36+
MODEL_SEG_OUT_TYPES = [
37+
"binary",
2038
"type",
21-
"sem",
39+
]
40+
41+
MODEL_AUX_OUT_TYPES = [
2242
"cellpose",
2343
"omnipose",
2444
"stardist",
@@ -28,7 +48,40 @@
2848
"dran",
2949
]
3050

31-
__all__ = ["MultiTaskDecoder"]
51+
AUX_COMBOS = [
52+
f"{prefix}_{aux}" for prefix in INST_SEG_PREFIX for aux in MODEL_AUX_OUT_TYPES
53+
]
54+
INST_SEG_COMBOS = [
55+
f"{prefix}_{seg}" for prefix in INST_SEG_PREFIX for seg in MODEL_SEG_OUT_TYPES
56+
]
57+
SEM_SEG_COMBOS = [
58+
f"{prefix}_{seg}" for prefix in SEM_SEG_PREFIX for seg in MODEL_SEG_OUT_TYPES
59+
]
60+
61+
62+
@dataclass
63+
class SoftInstanceOutput:
64+
type_map: torch.Tensor
65+
aux_map: torch.Tensor
66+
binary_map: Optional[torch.Tensor] = field(default=None)
67+
parents: Optional[Dict[str, List[str]]] = field(default=None)
68+
69+
70+
@dataclass
71+
class SoftSemanticOutput:
72+
type_map: torch.Tensor
73+
binary_map: Optional[torch.Tensor] = field(default=None)
74+
parents: Optional[Dict[str, List[str]]] = field(default=None)
75+
76+
77+
@dataclass
78+
class DecoderSoftOutput:
79+
nuc_map: SoftInstanceOutput
80+
tissue_map: Optional[SoftSemanticOutput] = field(default=None)
81+
cyto_map: Optional[SoftInstanceOutput] = field(default=None)
82+
83+
dec_feats: Optional[List[torch.Tensor]] = field(default=None)
84+
enc_feats: Optional[List[torch.Tensor]] = field(default=None)
3285

3386

3487
class MultiTaskDecoder(nn.ModuleDict):
@@ -82,6 +135,7 @@ def __init__(
82135
"""
83136
super().__init__()
84137
self.out_size = out_size
138+
self.out_keys = []
85139
self._check_head_args(heads, decoders)
86140
self._check_decoder_args(decoders)
87141
self._check_depth(
@@ -92,6 +146,8 @@ def __init__(
92146
"enc_feature_info": enc_feature_info,
93147
},
94148
)
149+
self.decoders = decoders
150+
self.heads = heads
95151

96152
# get the reduction factors and out channels of the encoder
97153
self.enc_feature_info = enc_feature_info[::-1] # bottleneck first
@@ -136,48 +192,45 @@ def __init__(
136192
self.add_module(f"{decoder_name}_stem_skip", stem_skip)
137193

138194
# set heads
139-
for decoder_name in heads.keys():
140-
for output_name, n_classes in heads[decoder_name].items():
195+
for decoder_name in decoders:
196+
for head_name, n_classes in heads[decoder_name].items():
141197
seg_head = SegHead(
142198
in_channels=decoder.out_channels,
143199
out_channels=n_classes,
144200
kernel_size=1,
145201
excitation_channels=head_excitation_channels,
146202
)
147-
self.add_module(f"{decoder_name}-{output_name}_head", seg_head)
203+
self.add_module(f"{decoder_name}-{head_name}_head", seg_head)
204+
self.out_keys.append(f"{decoder_name}-{head_name}")
148205

149206
def forward_features(
150207
self, feats: List[torch.Tensor], style: torch.Tensor = None
151208
) -> Dict[str, List[torch.Tensor]]:
152209
"""Forward all the decoders and return multi-res feature-lists per branch."""
153210
res = {}
154-
decoders = [k for k in self.keys() if "decoder" in k]
155211

156-
for dec in decoders:
157-
featlist = self[dec](*feats, style=style)
158-
branch = "_".join(dec.split("_")[:-1])
159-
res[branch] = featlist
212+
for decoder_name in self.decoders:
213+
featlist = self[f"{decoder_name}_decoder"](*feats, style=style)
214+
res[decoder_name] = featlist
160215

161216
return res
162217

163218
def forward_heads(
164219
self, dec_feats: Dict[str, torch.Tensor]
165-
) -> Dict[str, torch.Tensor]:
220+
) -> Dict[str, Dict[str, torch.Tensor]]:
166221
"""Forward pass all the seg heads."""
167222
res = {}
168-
heads = [k for k in self.keys() if "head" in k]
169-
for head in heads:
170-
branch_head = head.split("-")
171-
branch = branch_head[0] # branch name
172-
head_name = "_".join(branch_head[1].split("_")[:-1]) # head name
173-
x = self[head](dec_feats[branch][-1]) # the last decoder stage feat map
174-
175-
if self.out_size is not None:
176-
x = F.interpolate(
177-
x, size=self.out_size, mode="bilinear", align_corners=False
178-
)
179-
180-
res[f"{branch}-{head_name}"] = x
223+
for decoder_name in self.decoders:
224+
for head_name in self.heads[decoder_name]:
225+
x = self[f"{decoder_name}-{head_name}_head"](
226+
dec_feats[decoder_name][-1]
227+
) # the last decoder stage feat map
228+
229+
if self.out_size is not None:
230+
x = F.interpolate(
231+
x, size=self.out_size, mode="bilinear", align_corners=False
232+
)
233+
res[f"{decoder_name}-{head_name}"] = x
181234

182235
return res
183236

@@ -229,7 +282,54 @@ def forward(
229282

230283
out = self.forward_heads(dec_feats)
231284

232-
return enc_feats, dec_feats, out
285+
nuc_out = SoftInstanceOutput(
286+
type_map=out[self.nuc_type_key],
287+
aux_map=out[self.nuc_aux_key],
288+
binary_map=out.get(self.nuc_binary_key, None),
289+
parents={
290+
"aux_map": self.nuc_aux_key.split("-"),
291+
"type_map": self.nuc_type_key.split("-"),
292+
"binary_map": self.nuc_binary_key.split("-")
293+
if out.get(self.nuc_binary_key, None) is not None
294+
else None,
295+
},
296+
)
297+
298+
cyto_out = None
299+
if self.cyto_aux_key is not None:
300+
cyto_out = SoftInstanceOutput(
301+
type_map=out[self.cyto_type_key],
302+
aux_map=out[self.cyto_aux_key],
303+
binary_map=out.get(self.cyto_binary_key, None),
304+
parents={
305+
"aux_map": self.cyto_aux_key.split("-"),
306+
"type_map": self.cyto_type_key.split("-"),
307+
"binary_map": self.cyto_binary_key.split("-")
308+
if out.get(self.cyto_binary_key, None) is not None
309+
else None,
310+
},
311+
)
312+
313+
tissue_out = None
314+
if self.tissue_type_key is not None:
315+
tissue_out = SoftSemanticOutput(
316+
type_map=out[self.tissue_type_key],
317+
binary_map=out.get(self.tissue_binary_key, None),
318+
parents={
319+
"type_map": self.tissue_type_key.split("-"),
320+
"binary_map": self.tissue_binary_key.split("-")
321+
if out.get(self.tissue_binary_key, None) is not None
322+
else None,
323+
},
324+
)
325+
326+
return DecoderSoftOutput(
327+
nuc_map=nuc_out,
328+
tissue_map=tissue_out,
329+
cyto_map=cyto_out,
330+
enc_feats=enc_feats,
331+
dec_feats=dec_feats,
332+
)
233333

234334
def initialize(self) -> None:
235335
"""Initialize the decoders and segmentation heads."""
@@ -266,19 +366,80 @@ def _check_head_args(
266366
self, heads: Dict[str, int], decoders: Tuple[str, ...]
267367
) -> None:
268368
"""Check `heads` arg."""
369+
if not set(decoders) == set(heads.keys()):
370+
raise ValueError(
371+
"The decoder names need match exactly to the keys of `heads`. "
372+
f"Got decoders: {decoders} and heads: {list(heads.keys())}."
373+
)
374+
269375
for head in heads.keys():
270376
self._check_string_arg(head)
271377

378+
allowed = AUX_COMBOS + INST_SEG_COMBOS + SEM_SEG_COMBOS
272379
for head in self._get_inner_keys(heads):
273-
if head not in ALLOWED_HEADS:
380+
if head not in allowed:
274381
raise ValueError(
275-
f"Unknown head type: '{head}'. Allowed: {ALLOWED_HEADS}."
382+
f"Invalid head name '{head}'. Allowed names are: {allowed}."
276383
)
277384

278-
if not set(decoders) == set(heads.keys()):
385+
self.nuc_aux_key = None
386+
self.cyto_aux_key = None
387+
self.nuc_type_key = None
388+
self.nuc_binary_key = None
389+
self.cyto_type_key = None
390+
self.cyto_binary_key = None
391+
self.tissue_type_key = None
392+
self.tissue_binary_key = None
393+
for decoder_name in heads.keys():
394+
for head in heads[decoder_name].keys():
395+
val = f"{decoder_name}-{head}"
396+
if head in AUX_COMBOS and head.startswith("nuc_"):
397+
self.nuc_aux_key = val
398+
elif head in AUX_COMBOS and head.startswith("cyto_"):
399+
self.cyto_aux_key = val
400+
elif (
401+
head in INST_SEG_COMBOS
402+
and head.startswith("nuc_")
403+
and head.endswith("binary")
404+
):
405+
self.nuc_binary_key = val
406+
elif (
407+
head in INST_SEG_COMBOS
408+
and head.startswith("cyto_")
409+
and head.endswith("binary")
410+
):
411+
self.cyto_binary_key = val
412+
elif (
413+
head in INST_SEG_COMBOS
414+
and head.startswith("nuc_")
415+
and head.endswith("type")
416+
):
417+
self.nuc_type_key = val
418+
elif (
419+
head in INST_SEG_COMBOS
420+
and head.startswith("cyto_")
421+
and head.endswith("type")
422+
):
423+
self.cyto_type_key = val
424+
elif (
425+
head in SEM_SEG_COMBOS
426+
and head.startswith("tissue_")
427+
and head.endswith("type")
428+
):
429+
self.tissue_type_key = val
430+
elif (
431+
head in SEM_SEG_COMBOS
432+
and head.startswith("tissue_")
433+
and head.endswith("binary")
434+
):
435+
self.tissue_binary_key = val
436+
437+
if self.nuc_aux_key is None or (
438+
self.nuc_type_key is None and self.nuc_binary_key is None
439+
):
279440
raise ValueError(
280-
"The decoder names need match exactly to the keys of `heads`. "
281-
f"Got decoders: {decoders} and heads: {list(heads.keys())}."
441+
"The model must have either 'nuc_type' or 'nuc_binary' keys "
442+
f"and one of: {AUX_COMBOS}"
282443
)
283444

284445
def _check_depth(self, depth: int, arrs: Dict[str, Tuple[Any, ...]]) -> None:

0 commit comments

Comments
 (0)