Skip to content

Commit 1bd5d23

Browse files
committed
fix: rewrite model API
1 parent ecadbe5 commit 1bd5d23

File tree

23 files changed

+2205
-1853
lines changed

23 files changed

+2205
-1853
lines changed
Lines changed: 171 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -1,180 +1,180 @@
1-
import torch.nn as nn
1+
# import torch.nn as nn
22

3-
from .cellpose.cellpose import (
4-
CellPoseUnet,
5-
cellpose_base,
6-
cellpose_plus,
7-
omnipose_base,
8-
omnipose_plus,
9-
)
10-
from .cellvit.cellvit import (
11-
CellVitSAM,
12-
cellvit_sam_base,
13-
cellvit_sam_plus,
14-
cellvit_sam_small,
15-
cellvit_sam_small_plus,
16-
)
17-
from .cppnet.cppnet import CPPNet, cppnet_base, cppnet_base_multiclass, cppnet_plus
18-
from .hovernet.hovernet import (
19-
HoverNet,
20-
hovernet_base,
21-
hovernet_plus,
22-
hovernet_small,
23-
hovernet_small_plus,
24-
)
25-
from .stardist.stardist import (
26-
StarDistUnet,
27-
stardist_base,
28-
stardist_base_multiclass,
29-
stardist_plus,
30-
)
3+
# from .cellpose.cellpose import (
4+
# CellPoseUnet,
5+
# cellpose_base,
6+
# cellpose_plus,
7+
# omnipose_base,
8+
# omnipose_plus,
9+
# )
10+
# from .cellvit.cellvit import (
11+
# CellVitSAM,
12+
# cellvit_sam_base,
13+
# cellvit_sam_plus,
14+
# cellvit_sam_small,
15+
# cellvit_sam_small_plus,
16+
# )
17+
# from .cppnet.cppnet import CPPNet, cppnet_base, cppnet_base_multiclass, cppnet_plus
18+
# from .hovernet.hovernet import (
19+
# HoverNet,
20+
# hovernet_base,
21+
# hovernet_plus,
22+
# hovernet_small,
23+
# hovernet_small_plus,
24+
# )
25+
# from .stardist.stardist import (
26+
# StarDistUnet,
27+
# stardist_base,
28+
# stardist_base_multiclass,
29+
# stardist_plus,
30+
# )
3131

32-
MODEL_LOOKUP = {
33-
"cellpose_base": cellpose_base,
34-
"cellpose_plus": cellpose_plus,
35-
"omnipose_base": omnipose_base,
36-
"omnipose_plus": omnipose_plus,
37-
"hovernet_base": hovernet_base,
38-
"hovernet_plus": hovernet_plus,
39-
"hovernet_small": hovernet_small,
40-
"hovernet_small_plus": hovernet_small_plus,
41-
"stardist_base": stardist_base,
42-
"stardist_plus": stardist_plus,
43-
"stardist_base_multiclass": stardist_base_multiclass,
44-
"cellvit_sam_base": cellvit_sam_base,
45-
"cellvit_sam_plus": cellvit_sam_plus,
46-
"cellvit_sam_small": cellvit_sam_small,
47-
"cellvit_sam_small_plus": cellvit_sam_small_plus,
48-
"cppnet_base": cppnet_base,
49-
"cppnet_base_multiclass": cppnet_base_multiclass,
50-
"cppnet_plus": cppnet_plus,
51-
}
32+
# MODEL_LOOKUP = {
33+
# "cellpose_base": cellpose_base,
34+
# "cellpose_plus": cellpose_plus,
35+
# "omnipose_base": omnipose_base,
36+
# "omnipose_plus": omnipose_plus,
37+
# "hovernet_base": hovernet_base,
38+
# "hovernet_plus": hovernet_plus,
39+
# "hovernet_small": hovernet_small,
40+
# "hovernet_small_plus": hovernet_small_plus,
41+
# "stardist_base": stardist_base,
42+
# "stardist_plus": stardist_plus,
43+
# "stardist_base_multiclass": stardist_base_multiclass,
44+
# "cellvit_sam_base": cellvit_sam_base,
45+
# "cellvit_sam_plus": cellvit_sam_plus,
46+
# "cellvit_sam_small": cellvit_sam_small,
47+
# "cellvit_sam_small_plus": cellvit_sam_small_plus,
48+
# "cppnet_base": cppnet_base,
49+
# "cppnet_base_multiclass": cppnet_base_multiclass,
50+
# "cppnet_plus": cppnet_plus,
51+
# }
5252

5353

54-
def get_model(
55-
name: str,
56-
type: str,
57-
n_type_classes: int = None,
58-
n_sem_classes: int = None,
59-
**kwargs,
60-
) -> nn.Module:
61-
"""Get the corect model at hand given name and type.
54+
# def get_model(
55+
# name: str,
56+
# type: str,
57+
# n_type_classes: int = None,
58+
# n_sem_classes: int = None,
59+
# **kwargs,
60+
# ) -> nn.Module:
61+
# """Get the corect model at hand given name and type.
6262

63-
Parameters:
64-
name (str):
65-
Name of the model.
66-
type (str):
67-
Type of the model. One of "base", "plus", "small", "small_plus".
68-
n_type_classes (int):
69-
Number of cell types to segment.
70-
n_sem_classes (int):
71-
Number of tissue types to segment.
72-
**kwargs
73-
Additional keyword arguments.
63+
# Parameters:
64+
# name (str):
65+
# Name of the model.
66+
# type (str):
67+
# Type of the model. One of "base", "plus", "small", "small_plus".
68+
# n_type_classes (int):
69+
# Number of cell types to segment.
70+
# n_sem_classes (int):
71+
# Number of tissue types to segment.
72+
# **kwargs
73+
# Additional keyword arguments.
7474

75-
Returns:
76-
nn.Module: The specified model.
77-
"""
78-
if name == "stardist":
79-
if type == "base":
80-
model = MODEL_LOOKUP["stardist_base_multiclass"](
81-
n_type_classes=n_type_classes, **kwargs
82-
)
83-
elif type == "plus":
84-
model = MODEL_LOOKUP["stardist_plus"](
85-
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
86-
)
87-
elif name == "cppnet":
88-
if type == "base":
89-
model = MODEL_LOOKUP["cppnet_base_multiclass"](
90-
n_type_classes=n_type_classes, **kwargs
91-
)
92-
elif type == "plus":
93-
model = MODEL_LOOKUP["cppnet_plus"](
94-
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
95-
)
96-
elif name == "cellpose":
97-
if type == "base":
98-
model = MODEL_LOOKUP["cellpose_base"](
99-
n_type_classes=n_type_classes, **kwargs
100-
)
101-
elif type == "plus":
102-
model = MODEL_LOOKUP["cellpose_plus"](
103-
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
104-
)
105-
elif name == "omnipose":
106-
if type == "base":
107-
model = MODEL_LOOKUP["omnipose_base"](
108-
n_type_classes=n_type_classes, **kwargs
109-
)
110-
elif type == "plus":
111-
model = MODEL_LOOKUP["omnipose_plus"](
112-
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
113-
)
114-
elif name == "hovernet":
115-
if type == "base":
116-
model = MODEL_LOOKUP["hovernet_base"](
117-
n_type_classes=n_type_classes, **kwargs
118-
)
119-
elif type == "small":
120-
model = MODEL_LOOKUP["hovernet_small"](
121-
n_type_classes=n_type_classes, **kwargs
122-
)
123-
elif type == "plus":
124-
model = MODEL_LOOKUP["hovernet_plus"](
125-
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
126-
)
127-
elif type == "small_plus":
128-
model = MODEL_LOOKUP["hovernet_small_plus"](
129-
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
130-
)
131-
elif name == "cellvit":
132-
if type == "base":
133-
model = MODEL_LOOKUP["cellvit_sam_base"](
134-
n_type_classes=n_type_classes, **kwargs
135-
)
136-
elif type == "small":
137-
model = MODEL_LOOKUP["cellvit_sam_small"](
138-
n_type_classes=n_type_classes, **kwargs
139-
)
140-
elif type == "plus":
141-
model = MODEL_LOOKUP["cellvit_sam_plus"](
142-
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
143-
)
144-
elif type == "small_plus":
145-
model = MODEL_LOOKUP["cellvit_sam_small_plus"](
146-
n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
147-
)
148-
else:
149-
raise ValueError("Unknown model type or name.")
75+
# Returns:
76+
# nn.Module: The specified model.
77+
# """
78+
# if name == "stardist":
79+
# if type == "base":
80+
# model = MODEL_LOOKUP["stardist_base_multiclass"](
81+
# n_type_classes=n_type_classes, **kwargs
82+
# )
83+
# elif type == "plus":
84+
# model = MODEL_LOOKUP["stardist_plus"](
85+
# n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
86+
# )
87+
# elif name == "cppnet":
88+
# if type == "base":
89+
# model = MODEL_LOOKUP["cppnet_base_multiclass"](
90+
# n_type_classes=n_type_classes, **kwargs
91+
# )
92+
# elif type == "plus":
93+
# model = MODEL_LOOKUP["cppnet_plus"](
94+
# n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
95+
# )
96+
# elif name == "cellpose":
97+
# if type == "base":
98+
# model = MODEL_LOOKUP["cellpose_base"](
99+
# n_type_classes=n_type_classes, **kwargs
100+
# )
101+
# elif type == "plus":
102+
# model = MODEL_LOOKUP["cellpose_plus"](
103+
# n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
104+
# )
105+
# elif name == "omnipose":
106+
# if type == "base":
107+
# model = MODEL_LOOKUP["omnipose_base"](
108+
# n_type_classes=n_type_classes, **kwargs
109+
# )
110+
# elif type == "plus":
111+
# model = MODEL_LOOKUP["omnipose_plus"](
112+
# n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
113+
# )
114+
# elif name == "hovernet":
115+
# if type == "base":
116+
# model = MODEL_LOOKUP["hovernet_base"](
117+
# n_type_classes=n_type_classes, **kwargs
118+
# )
119+
# elif type == "small":
120+
# model = MODEL_LOOKUP["hovernet_small"](
121+
# n_type_classes=n_type_classes, **kwargs
122+
# )
123+
# elif type == "plus":
124+
# model = MODEL_LOOKUP["hovernet_plus"](
125+
# n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
126+
# )
127+
# elif type == "small_plus":
128+
# model = MODEL_LOOKUP["hovernet_small_plus"](
129+
# n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
130+
# )
131+
# elif name == "cellvit":
132+
# if type == "base":
133+
# model = MODEL_LOOKUP["cellvit_sam_base"](
134+
# n_type_classes=n_type_classes, **kwargs
135+
# )
136+
# elif type == "small":
137+
# model = MODEL_LOOKUP["cellvit_sam_small"](
138+
# n_type_classes=n_type_classes, **kwargs
139+
# )
140+
# elif type == "plus":
141+
# model = MODEL_LOOKUP["cellvit_sam_plus"](
142+
# n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
143+
# )
144+
# elif type == "small_plus":
145+
# model = MODEL_LOOKUP["cellvit_sam_small_plus"](
146+
# n_type_classes=n_type_classes, n_sem_classes=n_sem_classes, **kwargs
147+
# )
148+
# else:
149+
# raise ValueError("Unknown model type or name.")
150150

151-
return model
151+
# return model
152152

153153

154-
__all__ = [
155-
"HoverNet",
156-
"hovernet_base",
157-
"hovernet_plus",
158-
"hovernet_small",
159-
"hovernet_small_plus",
160-
"CellPoseUnet",
161-
"cellpose_base",
162-
"cellpose_plus",
163-
"omnipose_base",
164-
"omnipose_plus",
165-
"StarDistUnet",
166-
"stardist_base",
167-
"stardist_plus",
168-
"stardist_base_multiclass",
169-
"MODEL_LOOKUP",
170-
"get_model",
171-
"CellVitSAM",
172-
"cellvit_sam_base",
173-
"cellvit_sam_plus",
174-
"cellvit_sam_small",
175-
"cellvit_sam_small_plus",
176-
"cppnet_base",
177-
"cppnet_base_multiclass",
178-
"cppnet_plus",
179-
"CPPNet",
180-
]
154+
# __all__ = [
155+
# "HoverNet",
156+
# "hovernet_base",
157+
# "hovernet_plus",
158+
# "hovernet_small",
159+
# "hovernet_small_plus",
160+
# "CellPoseUnet",
161+
# "cellpose_base",
162+
# "cellpose_plus",
163+
# "omnipose_base",
164+
# "omnipose_plus",
165+
# "StarDistUnet",
166+
# "stardist_base",
167+
# "stardist_plus",
168+
# "stardist_base_multiclass",
169+
# "MODEL_LOOKUP",
170+
# "get_model",
171+
# "CellVitSAM",
172+
# "cellvit_sam_base",
173+
# "cellvit_sam_plus",
174+
# "cellvit_sam_small",
175+
# "cellvit_sam_small_plus",
176+
# "cppnet_base",
177+
# "cppnet_base_multiclass",
178+
# "cppnet_plus",
179+
# "CPPNet",
180+
# ]

0 commit comments

Comments
 (0)