1
+ from dataclasses import dataclass , field
1
2
from itertools import chain
2
- from typing import Any , Dict , List , Tuple
3
+ from typing import Any , Dict , List , Optional , Tuple
3
4
4
5
import torch
5
6
import torch .nn as nn
15
16
from cellseg_models_pytorch .models .base ._seg_head import SegHead
16
17
from cellseg_models_pytorch .modules .misc_modules import StyleReshape
17
18
18
- ALLOWED_HEADS = [
19
- "inst" ,
19
+ __all__ = [
20
+ "MultiTaskDecoder" ,
21
+ "DecoderSoftOutput" ,
22
+ "SoftInstanceOutput" ,
23
+ "SoftSemanticOutput" ,
24
+ ]
25
+
26
+
27
+ INST_SEG_PREFIX = [
28
+ "nuc" ,
29
+ "cyto" ,
30
+ ]
31
+
32
+ SEM_SEG_PREFIX = [
33
+ "tissue" ,
34
+ ]
35
+
36
+ MODEL_SEG_OUT_TYPES = [
37
+ "binary" ,
20
38
"type" ,
21
- "sem" ,
39
+ ]
40
+
41
+ MODEL_AUX_OUT_TYPES = [
22
42
"cellpose" ,
23
43
"omnipose" ,
24
44
"stardist" ,
28
48
"dran" ,
29
49
]
30
50
31
- __all__ = ["MultiTaskDecoder" ]
51
+ AUX_COMBOS = [
52
+ f"{ prefix } _{ aux } " for prefix in INST_SEG_PREFIX for aux in MODEL_AUX_OUT_TYPES
53
+ ]
54
+ INST_SEG_COMBOS = [
55
+ f"{ prefix } _{ seg } " for prefix in INST_SEG_PREFIX for seg in MODEL_SEG_OUT_TYPES
56
+ ]
57
+ SEM_SEG_COMBOS = [
58
+ f"{ prefix } _{ seg } " for prefix in SEM_SEG_PREFIX for seg in MODEL_SEG_OUT_TYPES
59
+ ]
60
+
61
+
62
+ @dataclass
63
+ class SoftInstanceOutput :
64
+ type_map : torch .Tensor
65
+ aux_map : torch .Tensor
66
+ binary_map : Optional [torch .Tensor ] = field (default = None )
67
+ parents : Optional [Dict [str , List [str ]]] = field (default = None )
68
+
69
+
70
+ @dataclass
71
+ class SoftSemanticOutput :
72
+ type_map : torch .Tensor
73
+ binary_map : Optional [torch .Tensor ] = field (default = None )
74
+ parents : Optional [Dict [str , List [str ]]] = field (default = None )
75
+
76
+
77
+ @dataclass
78
+ class DecoderSoftOutput :
79
+ nuc_map : SoftInstanceOutput
80
+ tissue_map : Optional [SoftSemanticOutput ] = field (default = None )
81
+ cyto_map : Optional [SoftInstanceOutput ] = field (default = None )
82
+
83
+ dec_feats : Optional [List [torch .Tensor ]] = field (default = None )
84
+ enc_feats : Optional [List [torch .Tensor ]] = field (default = None )
32
85
33
86
34
87
class MultiTaskDecoder (nn .ModuleDict ):
@@ -82,6 +135,7 @@ def __init__(
82
135
"""
83
136
super ().__init__ ()
84
137
self .out_size = out_size
138
+ self .out_keys = []
85
139
self ._check_head_args (heads , decoders )
86
140
self ._check_decoder_args (decoders )
87
141
self ._check_depth (
@@ -92,6 +146,8 @@ def __init__(
92
146
"enc_feature_info" : enc_feature_info ,
93
147
},
94
148
)
149
+ self .decoders = decoders
150
+ self .heads = heads
95
151
96
152
# get the reduction factors and out channels of the encoder
97
153
self .enc_feature_info = enc_feature_info [::- 1 ] # bottleneck first
@@ -136,48 +192,45 @@ def __init__(
136
192
self .add_module (f"{ decoder_name } _stem_skip" , stem_skip )
137
193
138
194
# set heads
139
- for decoder_name in heads . keys () :
140
- for output_name , n_classes in heads [decoder_name ].items ():
195
+ for decoder_name in decoders :
196
+ for head_name , n_classes in heads [decoder_name ].items ():
141
197
seg_head = SegHead (
142
198
in_channels = decoder .out_channels ,
143
199
out_channels = n_classes ,
144
200
kernel_size = 1 ,
145
201
excitation_channels = head_excitation_channels ,
146
202
)
147
- self .add_module (f"{ decoder_name } -{ output_name } _head" , seg_head )
203
+ self .add_module (f"{ decoder_name } -{ head_name } _head" , seg_head )
204
+ self .out_keys .append (f"{ decoder_name } -{ head_name } " )
148
205
149
206
def forward_features (
150
207
self , feats : List [torch .Tensor ], style : torch .Tensor = None
151
208
) -> Dict [str , List [torch .Tensor ]]:
152
209
"""Forward all the decoders and return multi-res feature-lists per branch."""
153
210
res = {}
154
- decoders = [k for k in self .keys () if "decoder" in k ]
155
211
156
- for dec in decoders :
157
- featlist = self [dec ](* feats , style = style )
158
- branch = "_" .join (dec .split ("_" )[:- 1 ])
159
- res [branch ] = featlist
212
+ for decoder_name in self .decoders :
213
+ featlist = self [f"{ decoder_name } _decoder" ](* feats , style = style )
214
+ res [decoder_name ] = featlist
160
215
161
216
return res
162
217
163
218
def forward_heads (
164
219
self , dec_feats : Dict [str , torch .Tensor ]
165
- ) -> Dict [str , torch .Tensor ]:
220
+ ) -> Dict [str , Dict [ str , torch .Tensor ] ]:
166
221
"""Forward pass all the seg heads."""
167
222
res = {}
168
- heads = [k for k in self .keys () if "head" in k ]
169
- for head in heads :
170
- branch_head = head .split ("-" )
171
- branch = branch_head [0 ] # branch name
172
- head_name = "_" .join (branch_head [1 ].split ("_" )[:- 1 ]) # head name
173
- x = self [head ](dec_feats [branch ][- 1 ]) # the last decoder stage feat map
174
-
175
- if self .out_size is not None :
176
- x = F .interpolate (
177
- x , size = self .out_size , mode = "bilinear" , align_corners = False
178
- )
179
-
180
- res [f"{ branch } -{ head_name } " ] = x
223
+ for decoder_name in self .decoders :
224
+ for head_name in self .heads [decoder_name ]:
225
+ x = self [f"{ decoder_name } -{ head_name } _head" ](
226
+ dec_feats [decoder_name ][- 1 ]
227
+ ) # the last decoder stage feat map
228
+
229
+ if self .out_size is not None :
230
+ x = F .interpolate (
231
+ x , size = self .out_size , mode = "bilinear" , align_corners = False
232
+ )
233
+ res [f"{ decoder_name } -{ head_name } " ] = x
181
234
182
235
return res
183
236
@@ -229,7 +282,54 @@ def forward(
229
282
230
283
out = self .forward_heads (dec_feats )
231
284
232
- return enc_feats , dec_feats , out
285
+ nuc_out = SoftInstanceOutput (
286
+ type_map = out [self .nuc_type_key ],
287
+ aux_map = out [self .nuc_aux_key ],
288
+ binary_map = out .get (self .nuc_binary_key , None ),
289
+ parents = {
290
+ "aux_map" : self .nuc_aux_key .split ("-" ),
291
+ "type_map" : self .nuc_type_key .split ("-" ),
292
+ "binary_map" : self .nuc_binary_key .split ("-" )
293
+ if out .get (self .nuc_binary_key , None ) is not None
294
+ else None ,
295
+ },
296
+ )
297
+
298
+ cyto_out = None
299
+ if self .cyto_aux_key is not None :
300
+ cyto_out = SoftInstanceOutput (
301
+ type_map = out [self .cyto_type_key ],
302
+ aux_map = out [self .cyto_aux_key ],
303
+ binary_map = out .get (self .cyto_binary_key , None ),
304
+ parents = {
305
+ "aux_map" : self .cyto_aux_key .split ("-" ),
306
+ "type_map" : self .cyto_type_key .split ("-" ),
307
+ "binary_map" : self .cyto_binary_key .split ("-" )
308
+ if out .get (self .cyto_binary_key , None ) is not None
309
+ else None ,
310
+ },
311
+ )
312
+
313
+ tissue_out = None
314
+ if self .tissue_type_key is not None :
315
+ tissue_out = SoftSemanticOutput (
316
+ type_map = out [self .tissue_type_key ],
317
+ binary_map = out .get (self .tissue_binary_key , None ),
318
+ parents = {
319
+ "type_map" : self .tissue_type_key .split ("-" ),
320
+ "binary_map" : self .tissue_binary_key .split ("-" )
321
+ if out .get (self .tissue_binary_key , None ) is not None
322
+ else None ,
323
+ },
324
+ )
325
+
326
+ return DecoderSoftOutput (
327
+ nuc_map = nuc_out ,
328
+ tissue_map = tissue_out ,
329
+ cyto_map = cyto_out ,
330
+ enc_feats = enc_feats ,
331
+ dec_feats = dec_feats ,
332
+ )
233
333
234
334
def initialize (self ) -> None :
235
335
"""Initialize the decoders and segmentation heads."""
@@ -266,19 +366,80 @@ def _check_head_args(
266
366
self , heads : Dict [str , int ], decoders : Tuple [str , ...]
267
367
) -> None :
268
368
"""Check `heads` arg."""
369
+ if not set (decoders ) == set (heads .keys ()):
370
+ raise ValueError (
371
+ "The decoder names need match exactly to the keys of `heads`. "
372
+ f"Got decoders: { decoders } and heads: { list (heads .keys ())} ."
373
+ )
374
+
269
375
for head in heads .keys ():
270
376
self ._check_string_arg (head )
271
377
378
+ allowed = AUX_COMBOS + INST_SEG_COMBOS + SEM_SEG_COMBOS
272
379
for head in self ._get_inner_keys (heads ):
273
- if head not in ALLOWED_HEADS :
380
+ if head not in allowed :
274
381
raise ValueError (
275
- f"Unknown head type: '{ head } '. Allowed: { ALLOWED_HEADS } ."
382
+ f"Invalid head name '{ head } '. Allowed names are : { allowed } ."
276
383
)
277
384
278
- if not set (decoders ) == set (heads .keys ()):
385
+ self .nuc_aux_key = None
386
+ self .cyto_aux_key = None
387
+ self .nuc_type_key = None
388
+ self .nuc_binary_key = None
389
+ self .cyto_type_key = None
390
+ self .cyto_binary_key = None
391
+ self .tissue_type_key = None
392
+ self .tissue_binary_key = None
393
+ for decoder_name in heads .keys ():
394
+ for head in heads [decoder_name ].keys ():
395
+ val = f"{ decoder_name } -{ head } "
396
+ if head in AUX_COMBOS and head .startswith ("nuc_" ):
397
+ self .nuc_aux_key = val
398
+ elif head in AUX_COMBOS and head .startswith ("cyto_" ):
399
+ self .cyto_aux_key = val
400
+ elif (
401
+ head in INST_SEG_COMBOS
402
+ and head .startswith ("nuc_" )
403
+ and head .endswith ("binary" )
404
+ ):
405
+ self .nuc_binary_key = val
406
+ elif (
407
+ head in INST_SEG_COMBOS
408
+ and head .startswith ("cyto_" )
409
+ and head .endswith ("binary" )
410
+ ):
411
+ self .cyto_binary_key = val
412
+ elif (
413
+ head in INST_SEG_COMBOS
414
+ and head .startswith ("nuc_" )
415
+ and head .endswith ("type" )
416
+ ):
417
+ self .nuc_type_key = val
418
+ elif (
419
+ head in INST_SEG_COMBOS
420
+ and head .startswith ("cyto_" )
421
+ and head .endswith ("type" )
422
+ ):
423
+ self .cyto_type_key = val
424
+ elif (
425
+ head in SEM_SEG_COMBOS
426
+ and head .startswith ("tissue_" )
427
+ and head .endswith ("type" )
428
+ ):
429
+ self .tissue_type_key = val
430
+ elif (
431
+ head in SEM_SEG_COMBOS
432
+ and head .startswith ("tissue_" )
433
+ and head .endswith ("binary" )
434
+ ):
435
+ self .tissue_binary_key = val
436
+
437
+ if self .nuc_aux_key is None or (
438
+ self .nuc_type_key is None and self .nuc_binary_key is None
439
+ ):
279
440
raise ValueError (
280
- "The decoder names need match exactly to the keys of `heads`. "
281
- f"Got decoders: { decoders } and heads : { list ( heads . keys ()) } . "
441
+ "The model must have either 'nuc_type' or 'nuc_binary' keys "
442
+ f"and one of : { AUX_COMBOS } "
282
443
)
283
444
284
445
def _check_depth (self , depth : int , arrs : Dict [str , Tuple [Any , ...]]) -> None :
0 commit comments