Skip to content

Commit 84a383a

Browse files
committed
feat(models); add a generic model builder function
1 parent d047451 commit 84a383a

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

cellseg_models_pytorch/models/__init__.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,64 @@
2020
stardist_plus,
2121
)
2222

23+
MODEL_LOOKUP = {
24+
"cellpose_base": cellpose_base,
25+
"cellpose_plus": cellpose_plus,
26+
"omnipose_base": omnipose_base,
27+
"omnipose_plus": omnipose_plus,
28+
"hovernet_base": hovernet_base,
29+
"hovernet_small": hovernet_small,
30+
"hovernet_small_plus": hovernet_small_plus,
31+
"stardist_base": stardist_base,
32+
"stardist_plus": stardist_plus,
33+
"stardist_base_multiclass": stardist_base_multiclass,
34+
}
35+
36+
37+
def get_model(name: str, type: str, ntypes: int = None, ntissues: int = None):
38+
"""Get the corect model at hand given name and type."""
39+
if name == "stardist":
40+
if type == "base":
41+
model = MODEL_LOOKUP["stardist_base_multiclass"](
42+
n_rays=32, type_classes=ntypes
43+
)
44+
elif type == "plus":
45+
model = MODEL_LOOKUP["stardist_plus"](
46+
n_rays=32, type_classes=ntypes, sem_classes=ntissues
47+
)
48+
elif name == "cellpose":
49+
if type == "base":
50+
model = MODEL_LOOKUP["cellpose_base"](type_classes=ntypes)
51+
elif type == "plus":
52+
model = MODEL_LOOKUP["cellpose_plus"](
53+
type_classes=ntypes, sem_classes=ntissues
54+
)
55+
elif name == "omnipose":
56+
if type == "base":
57+
model = MODEL_LOOKUP["omnipose_base"](type_classes=ntypes)
58+
elif type == "plus":
59+
model = MODEL_LOOKUP["omnipose_plus"](
60+
type_classes=ntypes, sem_classes=ntissues
61+
)
62+
elif name == "hovernet":
63+
if type == "base":
64+
model = MODEL_LOOKUP["hovernet_base"](type_classes=ntypes)
65+
elif type == "small":
66+
model = MODEL_LOOKUP["hovernet_small"](type_classes=ntypes)
67+
elif type == "plus":
68+
model = MODEL_LOOKUP["hovernet_plus"](
69+
type_classes=ntypes, sem_classes=ntissues
70+
)
71+
elif type == "small_plus":
72+
model = MODEL_LOOKUP["hovernet_small_plus"](
73+
type_classes=ntypes, sem_classes=ntissues
74+
)
75+
else:
76+
raise ValueError("Unknown model type or name.")
77+
78+
return model
79+
80+
2381
__all__ = [
2482
"MultiTaskUnet",
2583
"HoverNet",
@@ -36,4 +94,6 @@
3694
"stardist_base",
3795
"stardist_plus",
3896
"stardist_base_multiclass",
97+
"MODEL_LOOKUP",
98+
"get_model",
3999
]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
## Features
2+
3+
- Add a generic model builder function `get_model` to `models.__init__.py`

0 commit comments

Comments
 (0)