Skip to content

Commit 8ea3556

Browse files
committed
feat: add CellVIT-SAM model
1 parent fb2cee7 commit 8ea3556

File tree

5 files changed

+515
-5
lines changed

5 files changed

+515
-5
lines changed

cellseg_models_pytorch/models/__init__.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@
88
omnipose_base,
99
omnipose_plus,
1010
)
11+
from .cellvit.cellvit import (
12+
CellVitSAM,
13+
cellvit_sam_base,
14+
cellvit_sam_plus,
15+
cellvit_sam_small,
16+
cellvit_sam_small_plus,
17+
)
1118
from .hovernet.hovernet import (
1219
HoverNet,
1320
hovernet_base,
@@ -34,6 +41,10 @@
3441
"stardist_base": stardist_base,
3542
"stardist_plus": stardist_plus,
3643
"stardist_base_multiclass": stardist_base_multiclass,
44+
"cellvit_sam_base": cellvit_sam_base,
45+
"cellvit_sam_plus": cellvit_sam_plus,
46+
"cellvit_sam_small": cellvit_sam_small,
47+
"cellvit_sam_small_plus": cellvit_sam_small_plus,
3748
}
3849

3950

@@ -62,11 +73,11 @@ def get_model(
6273
if name == "stardist":
6374
if type == "base":
6475
model = MODEL_LOOKUP["stardist_base_multiclass"](
65-
n_rays=32, type_classes=ntypes, **kwargs
76+
type_classes=ntypes, **kwargs
6677
)
6778
elif type == "plus":
6879
model = MODEL_LOOKUP["stardist_plus"](
69-
n_rays=32, type_classes=ntypes, sem_classes=ntissues, **kwargs
80+
type_classes=ntypes, sem_classes=ntissues, **kwargs
7081
)
7182
elif name == "cellpose":
7283
if type == "base":
@@ -95,6 +106,19 @@ def get_model(
95106
model = MODEL_LOOKUP["hovernet_small_plus"](
96107
type_classes=ntypes, sem_classes=ntissues, **kwargs
97108
)
109+
elif name == "cellvit":
110+
if type == "base":
111+
model = MODEL_LOOKUP["cellvit_sam_base"](type_classes=ntypes, **kwargs)
112+
elif type == "small":
113+
model = MODEL_LOOKUP["cellvit_sam_small"](type_classes=ntypes, **kwargs)
114+
elif type == "plus":
115+
model = MODEL_LOOKUP["cellvit_sam_plus"](
116+
type_classes=ntypes, sem_classes=ntissues, **kwargs
117+
)
118+
elif type == "small_plus":
119+
model = MODEL_LOOKUP["cellvit_sam_small_plus"](
120+
type_classes=ntypes, sem_classes=ntissues, **kwargs
121+
)
98122
else:
99123
raise ValueError("Unknown model type or name.")
100124

@@ -119,4 +143,9 @@ def get_model(
119143
"stardist_base_multiclass",
120144
"MODEL_LOOKUP",
121145
"get_model",
146+
"CellVitSAM",
147+
"cellvit_sam_base",
148+
"cellvit_sam_plus",
149+
"cellvit_sam_small",
150+
"cellvit_sam_small_plus",
122151
]
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import Any, Dict, Tuple
2+
3+
__all__ = ["_create_cellvit_args"]
4+
5+
6+
def _create_cellvit_args(
7+
layer_depth: Tuple[int, ...],
8+
norm: str,
9+
act: str,
10+
conv: str,
11+
att: str,
12+
preact: bool,
13+
preattend: bool,
14+
short_skip: str,
15+
use_style: bool,
16+
merge_policy: str,
17+
skip_params: Dict[str, Any],
18+
) -> Tuple[Dict[str, Any], ...]:
19+
"""Create the args to build CellVit-Unet decoders."""
20+
skip_params = skip_params if skip_params is not None else {"k": None}
21+
22+
return tuple(
23+
{
24+
"layer_residual": False,
25+
"upsampling": "conv_transpose",
26+
"merge_policy": merge_policy,
27+
"short_skips": (short_skip,),
28+
"block_types": (("basic",) * ld,),
29+
"kernel_sizes": ((3,) * ld,),
30+
"expand_ratios": ((1.0,) * ld,),
31+
"groups": ((1,) * ld,),
32+
"biases": ((False,) * ld,),
33+
"normalizations": ((norm,) * ld,),
34+
"activations": ((act,) * ld,),
35+
"convolutions": ((conv,) * ld,),
36+
"attentions": ((att,) + (None,) * (ld - 1),),
37+
"preactivates": ((preact,) * ld,),
38+
"preattends": ((preattend,) * ld,),
39+
"use_styles": ((use_style,) * (ld - 1) + (False,),),
40+
"skip_params": {
41+
"short_skips": (short_skip,),
42+
"block_types": (("basic",),),
43+
"convolutions": ((conv,),),
44+
"normalizations": ((norm,),),
45+
"activations": ((act,),),
46+
**skip_params,
47+
},
48+
}
49+
for ld in layer_depth
50+
)

0 commit comments

Comments
 (0)