
Commit eb52ee4

fix(models): Enable Dippa cellseg model conversion

1 parent: a91fae7

File tree

7 files changed: +157 −2 lines changed


cellseg_models_pytorch/inference/post_processor.py

Lines changed: 1 addition & 0 deletions
@@ -161,6 +161,7 @@ def post_proc_pipeline(
 
         if "type" in maps.keys():
             res["type"] = self._get_type_map(maps["type"], res["inst"], **self.kwargs)
+            res["inst"] *= res["type"] > 0
 
         return res
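The added line zeroes out instance labels that fall on type-map background. A minimal standalone sketch of the effect, assuming numpy-style integer label maps (the arrays below are illustrative, not from the repo):

import numpy as np

# hypothetical instance and cell-type label maps
inst = np.array([[1, 1, 2, 2]])
type_map = np.array([[1, 1, 0, 0]])  # 0 = background

# the new post-processing step: drop instances that land on background
inst *= type_map > 0
# inst is now [[1 1 0 0]]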

cellseg_models_pytorch/models/base/_base_model.py

Lines changed: 5 additions & 1 deletion
@@ -86,7 +86,11 @@ def _check_decoder_args(
 
     def _get_inner_keys(self, d: Dict[str, Dict[str, Any]]) -> List[str]:
         """Get the inner dict keys from a nested dict."""
-        return chain.from_iterable(list(d[k].keys()) for k in d.keys())
+        return list(chain.from_iterable(list(d[k].keys()) for k in d.keys()))
+
+    def _flatten_inner_dicts(self, d: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
+        """Get the inner dicts as one dict from a nested dict."""
+        return dict(chain.from_iterable(list(d[k].items()) for k in d.keys()))
 
     def _check_head_args(
         self, heads: Dict[str, int], decoders: Tuple[str, ...]
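For reference, a standalone sketch of what the two helpers compute on a hypothetical nested decoder-argument dict (the keys and values are illustrative):

from itertools import chain

d = {"inst": {"n_layers": 1, "n_blocks": 2}, "type": {"n_layers": 2}}

# _get_inner_keys: all inner keys, duplicates included
list(chain.from_iterable(list(d[k].keys()) for k in d.keys()))
# -> ['n_layers', 'n_blocks', 'n_layers']

# _flatten_inner_dicts: inner dicts merged into one; later keys overwrite earlier ones
dict(chain.from_iterable(list(d[k].items()) for k in d.keys()))
# -> {'n_layers': 2, 'n_blocks': 2}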

cellseg_models_pytorch/models/base/_multitask_unet.py

Lines changed: 10 additions & 0 deletions
@@ -26,6 +26,8 @@ def __init__(
         enc_name: str = "resnet50",
         enc_pretrain: bool = True,
         enc_freeze: bool = False,
+        inst_key: str = None,
+        aux_key: str = None,
     ) -> None:
         """Create a universal multi-task (2D) unet.
 
@@ -63,11 +65,19 @@ def __init__(
             Whether to use imagenet pretrained weights in the encoder.
         enc_freeze : bool, default=False
             Freeze encoder weights for training.
+        inst_key : str, optional
+            The key for the model output that will be used in the instance
+            segmentation post-processing pipeline as the binary segmentation result.
+        aux_key : str, optional
+            The key for the model output that will be used in the instance
+            segmentation post-processing pipeline as the auxiliary map.
         """
         super().__init__()
         self.enc_freeze = enc_freeze
         use_style = style_channels is not None
         self.heads = heads
+        self.inst_key = inst_key
+        self.aux_key = aux_key
 
         # set timm encoder
         self.encoder = TimmEncoder(
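A minimal sketch of what the two keys are for, assuming the model returns a dict of named output maps (the head names, the `soft_masks` dict, and the tensor shapes below are illustrative, not the library's API):

import torch

# hypothetical soft output maps keyed by decoder/head name
soft_masks = {
    "type": torch.rand(1, 4, 256, 256),
    "hovernet": torch.rand(1, 2, 256, 256),
}

inst_key, aux_key = "type", "hovernet"

# post-processing reads the instance/binary branch and the auxiliary
# regression branch through the configured keys
binary_input = soft_masks[inst_key]
aux_input = soft_masks[aux_key]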

cellseg_models_pytorch/modules/conv/ws_conv.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def __init__(
         )
         self.eps = eps
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Weight standardized convolution forward pass."""
         weight = self.weight
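For context, weight standardization normalizes each kernel before the convolution is applied. A minimal sketch of the idea (not the module's exact code; a 3x3 kernel with same-padding is assumed):

import torch
import torch.nn.functional as F

def ws_conv2d(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # standardize each output filter to zero mean / unit variance
    mean = weight.mean(dim=(1, 2, 3), keepdim=True)
    std = weight.std(dim=(1, 2, 3), keepdim=True)
    return F.conv2d(x, (weight - mean) / (std + eps), padding=1)

out = ws_conv2d(torch.rand(1, 3, 32, 32), torch.rand(8, 3, 3, 3))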

cellseg_models_pytorch/modules/conv_base.py

Lines changed: 133 additions & 0 deletions
@@ -12,6 +12,7 @@
     "InvertedBottleneckConv",
     "FusedMobileInvertedConv",
     "HoverNetDenseConv",
+    "BasicConvOld",
 ]
 
 
@@ -829,3 +830,135 @@ def forward_features_preact(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv(x)
 
         return x
+
+
+class BasicConvOld(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        same_padding: bool = True,
+        normalization: str = "bn",
+        activation: str = "relu",
+        convolution: str = "conv",
+        preactivate: bool = False,
+        kernel_size: int = 3,
+        groups: int = 1,
+        bias: bool = False,
+        attention: str = None,
+        preattend: bool = False,
+        **kwargs
+    ) -> None:
+        """Conv-block (basic) parent class.
+
+        Parameters
+        ----------
+        in_channels : int
+            Number of input channels.
+        out_channels : int
+            Number of output channels.
+        same_padding : bool, default=True
+            If True, performs same-convolution.
+        normalization : str, default="bn"
+            Normalization method.
+            One of: "bn", "bcn", "gn", "in", "ln", None
+        activation : str, default="relu"
+            Activation method.
+            One of: "mish", "swish", "relu", "relu6", "rrelu", "selu",
+            "celu", "gelu", "glu", "tanh", "sigmoid", "silu", "prelu",
+            "leaky-relu", "elu", "hardshrink", "tanhshrink", "hardsigmoid"
+        convolution : str, default="conv"
+            The convolution method. One of: "conv", "wsconv", "scaled_wsconv"
+        preactivate : bool, default=False
+            If True, normalization will be applied before convolution.
+        kernel_size : int, default=3
+            The size of the convolution kernel.
+        groups : int, default=1
+            Number of groups the kernels are divided into. If `groups == 1`,
+            normal convolution is applied. If `groups == in_channels`,
+            depthwise convolution is applied.
+        bias : bool, default=False
+            Include bias term in the convolution.
+        attention : str, default=None
+            Attention method. One of: "se", "scse", "gc", "eca", None
+        preattend : bool, default=False
+            If True, attention is applied at the beginning of the forward pass.
+        """
+        super().__init__()
+        self.conv_choice = convolution
+        self.out_channels = out_channels
+        self.preattend = preattend
+        self.preactivate = preactivate
+
+        # set norm channel number for preactivation or normal
+        norm_channels = in_channels if preactivate else self.out_channels
+
+        # set padding. Works if dilation or stride are not adjusted
+        padding = (kernel_size - 1) // 2 if same_padding else 0
+
+        self.conv = Conv(
+            name=self.conv_choice,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            groups=groups,
+            padding=padding,
+            bias=bias,
+        )
+
+        self.norm = Norm(normalization, num_features=norm_channels)
+        self.act = Activation(activation)
+
+        # set attention channels
+        att_channels = in_channels if preattend else self.out_channels
+        self.att = Attention(attention, in_channels=att_channels)
+
+        self.downsample = None
+        if in_channels != out_channels:
+            self.downsample = nn.Sequential(
+                Conv(
+                    self.conv_choice,
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    bias=False,
+                    kernel_size=1,
+                    padding=0,
+                ),
+                Norm(normalization, num_features=out_channels),
+            )
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass."""
+        identity = x
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        x = self.att(x)
+
+        # residual
+        x = self.conv(x)
+        x = self.norm(x)
+
+        x += identity
+        x = self.act(x)
+
+        return x
+
+    def forward_features_preact(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass with pre-activation."""
+        identity = x
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        # pre-attention
+        x = self.att(x)
+
+        # preact residual
+        x = self.norm(x)
+        x = self.act(x)
+        x = self.conv(x)
+
+        x += identity
+        x = self.act(x)
+
+        return x
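A minimal usage sketch of the new block, calling `forward_features` directly as defined above (channel counts and input size are illustrative):

import torch
from cellseg_models_pytorch.modules.conv_base import BasicConvOld

# old (Dippa-style) residual ordering: attention -> conv -> norm -> add -> act
block = BasicConvOld(in_channels=32, out_channels=64, normalization="bn", activation="relu")
out = block.forward_features(torch.rand(1, 32, 64, 64))
# out.shape == torch.Size([1, 64, 64, 64])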

cellseg_models_pytorch/modules/conv_block.py

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,7 @@
 
 from .conv_base import (
     BasicConv,
+    BasicConvOld,
     BottleneckConv,
     DepthWiseSeparableConv,
     FusedMobileInvertedConv,
 
@@ -20,6 +21,7 @@
     "mbconv": InvertedBottleneckConv,
     "fmbconv": FusedMobileInvertedConv,
     "hover_dense": HoverNetDenseConv,
+    "basic_old": BasicConvOld,
 }
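A sketch of how the string-to-class mapping is typically consumed, resolving a block class from its short name (the `CONV_BLOCKS` registry and `make_block` helper are hypothetical stand-ins; the real dict's name is not visible in this diff):

from cellseg_models_pytorch.modules.conv_base import BasicConv, BasicConvOld

CONV_BLOCKS = {
    "basic": BasicConv,
    "basic_old": BasicConvOld,
}

def make_block(name: str, **kwargs):
    # look up the conv-block class by name, as a model builder would
    return CONV_BLOCKS[name](**kwargs)

block = make_block("basic_old", in_channels=32, out_channels=32)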

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+## Fixes
+
+- Add a conv block `BasicConvOld` to enable `Dippa`-to-cellseg model conversion.
+- Fix the `inst_key`/`aux_key` bug in `MultiTaskUnet`.
+- Add `type_map > 0` masking for the `inst_map`s in post-processing.
