Skip to content

Commit 18ddb1e

Browse files
committed
feat(models): Add multi-task U-net for experiments
1 parent 07f31e8 commit 18ddb1e

File tree

10 files changed

+174
-8
lines changed

10 files changed

+174
-8
lines changed

cellseg_models_pytorch/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .base._multitask_unet import MultiTaskUnet
12
from .cellpose.cellpose import (
23
CellPoseUnet,
34
cellpose_base,
@@ -20,6 +21,7 @@
2021
)
2122

2223
__all__ = [
24+
"MultiTaskUnet",
2325
"HoverNet",
2426
"hovernet_base",
2527
"hovernet_plus",
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
from typing import Any, Dict, Tuple
2+
3+
import torch
4+
5+
from ...decoders import Decoder
6+
from ...modules.misc_modules import StyleReshape
7+
from ._base_model import BaseMultiTaskSegModel
8+
from ._seg_head import SegHead
9+
from ._timm_encoder import TimmEncoder
10+
11+
__all__ = ["MultiTaskUnet"]
12+
13+
14+
class MultiTaskUnet(BaseMultiTaskSegModel):
15+
def __init__(
16+
self,
17+
decoders: Tuple[str, ...],
18+
heads: Dict[str, Dict[str, int]],
19+
n_layers: Dict[str, Tuple[int, ...]],
20+
n_blocks: Dict[str, Tuple[Tuple[int, ...], ...]],
21+
out_channels: Dict[str, Tuple[int, ...]],
22+
long_skips: Dict[str, str],
23+
dec_params: Dict[str, Tuple[Dict[str, Any], ...]],
24+
depth: int = 4,
25+
style_channels: int = 256,
26+
enc_name: str = "resnet50",
27+
enc_pretrain: bool = True,
28+
enc_freeze: bool = False,
29+
) -> None:
30+
"""Create a universal multi-task (2D) unet.
31+
32+
NOTE: For experimental purposes.
33+
34+
Parameters
35+
----------
36+
decoders : Tuple[str, ...]
37+
Names of the decoder branches of this network. E.g. ("cellpose", "sem")
38+
heads : Dict[str, Dict[str, int]]
39+
Names of the decoder branches (has to match `decoders`) mapped to dicts
40+
of output name - number of output classes. E.g.
41+
{"cellpose": {"type": 4, "cellpose": 2}, "sem": {"sem": 5}}
42+
n_layers : Dict[str, Tuple[int, ...]]
43+
The number of conv layers inside each of the decoder stages.
44+
n_blocks : Dict[str, Tuple[Tuple[int, ...], ...]]
45+
The number of blocks inside each conv-layer in each decoder stage.
46+
out_channels : Tuple[int, ...]
47+
Out channels for each decoder stage.
48+
long_skips : Dict[str, str]
49+
long skip method to be used. One of: "unet", "unetpp", "unet3p",
50+
"unet3p-lite", None
51+
dec_params : Dict[str, Tuple[Dict[str, Any], ...]])
52+
The keyword args for each of the distinct decoder stages. Incudes the
53+
parameters for the long skip connections and convolutional layers of the
54+
decoder itself. See the `DecoderStage` documentation for more info.
55+
depth : int, default=4
56+
The depth of the encoder. I.e. Number of returned feature maps from
57+
the encoder. Maximum depth = 5.
58+
style_channels : int, default=256
59+
Number of style vector channels. If None, style vectors are ignored.
60+
enc_name : str, default="resnet50"
61+
Name of the encoder. See timm docs for more info.
62+
enc_pretrain : bool, default=True
63+
Whether to use imagenet pretrained weights in the encoder.
64+
enc_freeze : bool, default=False
65+
Freeze encoder weights for training.
66+
"""
67+
super().__init__()
68+
self.enc_freeze = enc_freeze
69+
use_style = style_channels is not None
70+
self.heads = heads
71+
72+
# set timm encoder
73+
self.encoder = TimmEncoder(
74+
enc_name,
75+
depth=depth,
76+
pretrained=enc_pretrain,
77+
)
78+
79+
# style
80+
self.make_style = None
81+
if use_style:
82+
self.make_style = StyleReshape(self.encoder.out_channels[0], style_channels)
83+
84+
# set decoders
85+
for decoder_name in decoders:
86+
decoder = Decoder(
87+
enc_channels=list(self.encoder.out_channels),
88+
style_channels=style_channels,
89+
out_channels=out_channels[decoder_name],
90+
long_skip=long_skips[decoder_name],
91+
n_layers=n_layers[decoder_name],
92+
n_blocks=n_blocks[decoder_name],
93+
stage_params=dec_params[decoder_name],
94+
)
95+
self.add_module(f"{decoder_name}_decoder", decoder)
96+
97+
# set heads
98+
for decoder_name in heads.keys():
99+
for output_name, n_classes in heads[decoder_name].items():
100+
seg_head = SegHead(
101+
in_channels=decoder.out_channels,
102+
out_channels=n_classes,
103+
kernel_size=1,
104+
)
105+
self.add_module(f"{output_name}_seg_head", seg_head)
106+
107+
self.name = f"MultiTaskUnet-{enc_name}"
108+
109+
# init decoder weights
110+
self.initialize()
111+
112+
# freeze encoder if specified
113+
if enc_freeze:
114+
self.freeze_encoder()
115+
116+
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
117+
"""Forward pass of Multi-task U-net."""
118+
self._check_input_shape(x)
119+
120+
feats = self.encoder(x)
121+
122+
style = None
123+
if self.make_style is not None:
124+
style = self.make_style(feats[0])
125+
126+
dec_feats = self.forward_dec_features(feats, style)
127+
128+
for decoder_name in self.heads.keys():
129+
for head_name in self.heads[decoder_name].keys():
130+
k = self.aux_key if head_name not in dec_feats.keys() else head_name
131+
dec_feats[head_name] = dec_feats[k]
132+
133+
out = self.forward_heads(dec_feats)
134+
135+
return out

cellseg_models_pytorch/models/cellpose/_conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def _create_cellpose_args(
1515
use_style: bool,
1616
merge_policy: str,
1717
skip_params: Dict[str, Any],
18-
) -> Dict[str, Any]:
18+
) -> Tuple[Dict[str, Any], ...]:
1919
"""Create the args to build CellPose-Unet architecture."""
2020
skip_params = skip_params if skip_params is not None else {"k": None}
2121

cellseg_models_pytorch/models/cellpose/cellpose.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
class CellPoseUnet(BaseMultiTaskSegModel):
2323
def __init__(
2424
self,
25-
decoders: Tuple[str],
25+
decoders: Tuple[str, ...],
2626
heads: Dict[str, Dict[str, int]],
2727
inst_key: str = "type",
2828
depth: int = 4,
@@ -65,7 +65,7 @@ def __init__(
6565
6666
Parameters
6767
----------
68-
decoders : Tuple[str]
68+
decoders : Tuple[str, ...]
6969
Names of the decoder branches of this network. E.g. ("cellpose", "sem")
7070
heads : Dict[str, Dict[str, int]]
7171
Names of the decoder branches (has to match `decoders`) mapped to dicts

cellseg_models_pytorch/models/hovernet/_conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def _create_hovernet_args(
1515
use_style: bool,
1616
merge_policy: str,
1717
skip_params: Dict[str, Any],
18-
) -> Dict[str, Any]:
18+
) -> Tuple[Dict[str, Any], ...]:
1919
"""Create the correct args to build HoVerNet architecture."""
2020
skip_params = skip_params if skip_params is not None else {"k": None}
2121

cellseg_models_pytorch/models/hovernet/hovernet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __init__(
6262
6363
Parameters
6464
----------
65-
decoders : Tuple[str]
65+
decoders : Tuple[str, ...]
6666
Names of the decoder branches of this network. E.g. ("hovernet", "sem")
6767
heads : Dict[str, Dict[str, int]]
6868
The segmentation heads of the architecture. I.e. Names of the decoder

cellseg_models_pytorch/models/stardist/_conf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict
1+
from typing import Any, Dict, Tuple
22

33
__all__ = ["_create_stardist_args"]
44

@@ -16,7 +16,7 @@ def _create_stardist_args(
1616
block_type: str,
1717
merge_policy: str,
1818
skip_params: Dict[str, Any],
19-
) -> Dict[str, Any]:
19+
) -> Tuple[Dict[str, Any], ...]:
2020
"""Create the args to build CellPose-Unet architecture."""
2121
skip_params = skip_params if skip_params is not None else {"k": None}
2222

cellseg_models_pytorch/models/stardist/stardist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(
5757
5858
Parameters
5959
----------
60-
decoders : Tuple[str]
60+
decoders : Tuple[str, ...]
6161
Names of the decoder branches of this network. E.g. ("stardist", "sem")
6262
extra_convs : Dict[str, Dict[str, int]]
6363
The extra conv blocks before segmentation heads of the architecture.

cellseg_models_pytorch/models/tests/test_models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33

44
from cellseg_models_pytorch.models import (
5+
MultiTaskUnet,
56
cellpose_base,
67
cellpose_plus,
78
hovernet_base,
@@ -76,3 +77,20 @@ def test_omnipose_fwdbwd(model):
7677
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
7778

7879
assert y["omnipose"].shape == torch.Size([1, 2, 64, 64])
80+
81+
82+
def test_multitaskunet_fwdbwd():
83+
x = torch.rand([1, 3, 64, 64])
84+
m = MultiTaskUnet(
85+
decoders=("sem",),
86+
heads={"sem": {"sem": 3}},
87+
n_layers={"sem": (1, 1, 1, 1)},
88+
n_blocks={"sem": ((2,), (2,), (2,), (2,))},
89+
out_channels={"sem": (128, 64, 32, 16)},
90+
long_skips={"sem": "unet"},
91+
dec_params={"sem": None},
92+
)
93+
y = m(x)
94+
y["sem"].mean().backward()
95+
96+
assert y["sem"].shape == torch.Size([1, 3, 64, 64])
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
### Features
2+
3+
- **models**: Add a universal multi-task U-net model builder (experimental)
4+
5+
### Type Hints
6+
7+
- **models**: Fix incorrect type hints.
8+
9+
### Test
10+
11+
- **models**: Update tests for multi-task U-Net

0 commit comments

Comments
 (0)