Skip to content

Commit 357178c

Browse files
committed
fix: fix base model from_pretrained
1 parent a9310a3 commit 357178c

File tree

13 files changed

+210
-77
lines changed

13 files changed

+210
-77
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
__all__ = ["PRETRAINED"]
2+
3+
PRETRAINED = {
4+
"hovernet": {
5+
"hgsc_v1_efficientnet_b5": {
6+
"repo_id": "csmp-hub/hovernet-histo-hgsc-nuc-v1",
7+
"filename": "hovernet_hgsc_v1_efficientnet_b5.safetensors",
8+
},
9+
},
10+
"cellpose": {
11+
"hgsc_v1_efficientnet_b5": {
12+
"repo_id": "csmp-hub/cellpose-histo-hgsc-nuc-v1",
13+
"filename": "cellpose_hgsc_v1_efficientnet_b5.safetensors",
14+
},
15+
},
16+
"cellvit": {
17+
"hgsc_v1_efficientnet_b5": {
18+
"repo_id": "csmp-hub/cellvit-histo-hgsc-nuc-v1",
19+
"filename": "cellvit_hgsc_v1_efficientnet_b5.safetensors",
20+
},
21+
},
22+
"stardist": {
23+
"hgsc_v1_efficientnet_b5": {
24+
"repo_id": "csmp-hub/stardist-histo-hgsc-nuc-v1",
25+
"filename": "stardist_hgsc_v1_efficientnet_b5.safetensors",
26+
},
27+
},
28+
"cppnet": {
29+
"hgsc_v1_efficientnet_b5": {
30+
"repo_id": "csmp-hub/cppnet-histo-hgsc-nuc-v1",
31+
"filename": "cppnet_hgsc_v1_efficientnet_b5.safetensors",
32+
},
33+
},
34+
}

cellseg_models_pytorch/models/base/_base_model_inst.py

Lines changed: 78 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33

44
import numpy as np
55
import torch
6+
from huggingface_hub import hf_hub_download
67
from PIL.Image import Image
78

89
from cellseg_models_pytorch.decoders.multitask_decoder import (
910
SoftInstanceOutput,
1011
SoftSemanticOutput,
1112
)
13+
from cellseg_models_pytorch.models.base import PRETRAINED
1214

1315
__all__ = ["BaseModelInst"]
1416

@@ -24,24 +26,60 @@ def set_inference_mode(self) -> None:
2426
@classmethod
2527
def from_pretrained(
2628
cls,
27-
weights_path: Union[str, Path],
28-
n_nuc_classes: int,
29-
enc_name: str = "efficientnet_b5",
30-
enc_freeze: bool = False,
29+
weights: Union[str, Path],
3130
device: torch.device = torch.device("cuda"),
3231
model_kwargs: Dict[str, Any] = {},
33-
) -> None:
34-
"""Load the model from pretrained weights."""
32+
) -> "BaseModelInst":
33+
"""Load the model from pretrained weights.
34+
35+
Parameters:
36+
model_name (str):
37+
Name of the pretrained model.
38+
device (torch.device, default=torch.device("cuda")):
39+
Device to run the model on. Default is "cuda".
40+
model_kwargs (Dict[str, Any], default={}):
41+
Additional arguments for the model.
42+
"""
43+
weights_path = Path(weights)
44+
if not weights_path.is_file():
45+
if weights_path.as_posix() in PRETRAINED[cls.model_name].keys():
46+
weights_path = Path(
47+
hf_hub_download(
48+
repo_id=PRETRAINED[cls.model_name][weights]["repo_id"],
49+
filename=PRETRAINED[cls.model_name][weights]["filename"],
50+
)
51+
)
52+
53+
else:
54+
raise ValueError(
55+
"Please provide a valid path. or a pre-trained model downloaded from the"
56+
f" csmp-hub. One of {list(PRETRAINED[cls.model_name].keys())}."
57+
)
58+
59+
try:
60+
from safetensors.torch import load_model
61+
except ImportError:
62+
raise ImportError(
63+
"Please install `safetensors` package to load .safetensors files."
64+
)
65+
66+
enc_name, n_nuc_classes, state_dict = cls._get_state_dict(
67+
weights_path, device=device
68+
)
69+
3570
model_inst = cls(
3671
n_nuc_classes=n_nuc_classes,
3772
enc_name=enc_name,
3873
enc_pretrain=False,
39-
enc_freeze=enc_freeze,
74+
enc_freeze=False,
4075
device=device,
4176
model_kwargs=model_kwargs,
4277
)
43-
state_dict = torch.load(weights_path, map_location=device)
44-
model_inst.model.load_state_dict(state_dict, strict=True)
78+
79+
if weights_path.suffix == ".safetensors":
80+
load_model(model_inst.model, weights_path, device.type)
81+
else:
82+
model_inst.model.load_state_dict(state_dict, strict=True)
4583

4684
return model_inst
4785

@@ -174,3 +212,34 @@ def post_process(
174212
)
175213

176214
return x
215+
216+
@staticmethod
217+
def _get_state_dict(
218+
weights_path: Union[str, Path], device: torch.device = torch.device("cuda")
219+
) -> None:
220+
"""Load the model from pretrained weights."""
221+
weights_path = Path(weights_path)
222+
if not weights_path.exists():
223+
raise ValueError(f"Model weights not found at {weights_path}")
224+
if weights_path.suffix == ".safetensors":
225+
try:
226+
from safetensors.torch import load_file
227+
except ImportError:
228+
raise ImportError(
229+
"Please install `safetensors` package to load .safetensors files."
230+
)
231+
state_dict = load_file(weights_path, device=device.type)
232+
else:
233+
state_dict = torch.load(weights_path, map_location=device)
234+
235+
# infer encoder name and number of classes from state_dict
236+
enc_keys = [key for key in state_dict.keys() if "encoder." in key]
237+
enc_name = enc_keys[0].split(".")[0] if enc_keys else None
238+
nuc_type_head_key = next(
239+
key
240+
for key in state_dict.keys()
241+
if "nuc_type_head.head" in key and "weight" in key
242+
)
243+
n_nuc_classes = state_dict[nuc_type_head_key].shape[0]
244+
245+
return enc_name, n_nuc_classes, state_dict

cellseg_models_pytorch/models/cellpose/cellpose.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212

1313
class CellPose(BaseModelInst):
14+
model_name = "cellpose"
15+
1416
def __init__(
1517
self,
1618
n_nuc_classes: int,

cellseg_models_pytorch/models/cellpose/cellpose_unet.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
]
2020

2121

22-
class CellPoseUnet(nn.Module):
22+
class CellPoseUnet(nn.ModuleDict):
2323
def __init__(
2424
self,
2525
decoders: Tuple[str, ...],
@@ -127,6 +127,7 @@ def __init__(
127127
super().__init__()
128128
self.inst_key = inst_key
129129
self.aux_key = "cellpose"
130+
self.enc_name = enc_name
130131

131132
if enc_out_indices is None:
132133
enc_out_indices = tuple(range(depth))
@@ -155,18 +156,21 @@ def __init__(
155156
)
156157

157158
# set encoder
158-
self.encoder = Encoder(
159-
timm_encoder_name=enc_name,
160-
timm_encoder_out_indices=enc_out_indices,
161-
timm_encoder_pretrained=enc_pretrain,
162-
timm_extra_kwargs=encoder_kws,
159+
self.add_module(
160+
self.enc_name,
161+
Encoder(
162+
timm_encoder_name=enc_name,
163+
timm_encoder_out_indices=enc_out_indices,
164+
timm_encoder_pretrained=enc_pretrain,
165+
timm_extra_kwargs=encoder_kws,
166+
),
163167
)
164168

165169
self.decoder = MultiTaskDecoder(
166170
decoders=decoders,
167171
heads=heads,
168172
out_channels=out_channels,
169-
enc_feature_info=self.encoder.feature_info,
173+
enc_feature_info=self[self.enc_name].feature_info,
170174
n_layers=n_layers,
171175
n_blocks=n_blocks,
172176
stage_kws=stage_kws,
@@ -181,7 +185,7 @@ def __init__(
181185

182186
# freeze encoder if specified
183187
if enc_freeze:
184-
self.encoder.freeze_encoder()
188+
self[self.enc_name].freeze_encoder()
185189

186190
self.name = f"CellPoseUnet-{enc_name}"
187191

@@ -204,7 +208,7 @@ def forward(self, x: torch.Tensor, return_pred_only: bool = True) -> Dict[str, A
204208
- "dec_feats": Dict[str, List[torch.Tensor]].
205209
- "enc_out": torch.Tensor.
206210
"""
207-
enc_output, feats = self.encoder(x)
211+
enc_output, feats = self[self.enc_name](x)
208212
dec_out: DecoderSoftOutput = self.decoder(feats, x)
209213

210214
res = {

cellseg_models_pytorch/models/cellvit/cellvit.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212

1313
class CellVit(BaseModelInst):
14+
model_name = "cellvit"
15+
1416
def __init__(
1517
self,
1618
n_nuc_classes: int,

cellseg_models_pytorch/models/cellvit/cellvit_unet.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
]
1818

1919

20-
class CellVitSamUnet(nn.Module):
20+
class CellVitSamUnet(nn.ModuleDict):
2121
def __init__(
2222
self,
2323
decoders: Tuple[str, ...],
@@ -119,9 +119,22 @@ def __init__(
119119
segmentation post-processing pipeline as the binary segmentation result.
120120
"""
121121
super().__init__()
122+
allowed = (
123+
"samvit_base_patch16",
124+
"samvit_base_patch16_224",
125+
"samvit_huge_patch16",
126+
"samvit_large_patch16",
127+
)
128+
if enc_name not in allowed:
129+
raise ValueError(
130+
f"Wrong encoder name. Got: {enc_name}. "
131+
f"Allowed encoder for CellVit: {allowed}"
132+
)
133+
122134
self.inst_key = inst_key
123135
self.aux_key = "hovernet"
124136
self.depth = len(layer_depths)
137+
self.enc_name = enc_name
125138

126139
if enc_out_indices is None:
127140
enc_out_indices = tuple(range(self.depth))
@@ -151,31 +164,22 @@ def __init__(
151164
skip_kws,
152165
)
153166

154-
allowed = (
155-
"samvit_base_patch16",
156-
"samvit_base_patch16_224",
157-
"samvit_huge_patch16",
158-
"samvit_large_patch16",
159-
)
160-
if enc_name not in allowed:
161-
raise ValueError(
162-
f"Wrong encoder name. Got: {enc_name}. "
163-
f"Allowed encoder for CellVit: {allowed}"
164-
)
165-
166167
# set encoders
167-
self.encoder = Encoder(
168-
timm_encoder_name=enc_name,
169-
timm_encoder_out_indices=enc_out_indices,
170-
timm_encoder_pretrained=enc_pretrain,
171-
timm_extra_kwargs=encoder_kws,
168+
self.add_module(
169+
self.enc_name,
170+
Encoder(
171+
timm_encoder_name=enc_name,
172+
timm_encoder_out_indices=enc_out_indices,
173+
timm_encoder_pretrained=enc_pretrain,
174+
timm_extra_kwargs=encoder_kws,
175+
),
172176
)
173177

174178
self.decoder = MultiTaskDecoder(
175179
decoders=decoders,
176180
heads=heads,
177181
out_channels=out_channels,
178-
enc_feature_info=self.encoder.feature_info,
182+
enc_feature_info=self[self.enc_name].feature_info,
179183
n_layers=n_layers,
180184
n_blocks=n_blocks,
181185
stage_kws=stage_kws,
@@ -190,7 +194,7 @@ def __init__(
190194

191195
# freeze encoder if specified
192196
if enc_freeze:
193-
self.encoder.freeze_encoder()
197+
self[self.enc_name].freeze_encoder()
194198

195199
self.name = f"CellVit-{enc_name}"
196200

@@ -213,7 +217,7 @@ def forward(self, x: torch.Tensor, return_pred_only: bool = True) -> Dict[str, A
213217
- "dec_feats": Dict[str, List[torch.Tensor]].
214218
- "enc_out": torch.Tensor.
215219
"""
216-
enc_output, feats = self.encoder(x)
220+
enc_output, feats = self[self.enc_name](x)
217221
dec_out: DecoderSoftOutput = self.decoder(feats, x)
218222

219223
res = {

cellseg_models_pytorch/models/cppnet/cppnet.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212

1313
class CPPNet(BaseModelInst):
14+
model_name = "cppnet"
15+
1416
def __init__(
1517
self,
1618
n_nuc_classes: int,

cellseg_models_pytorch/models/cppnet/cppnet_unet.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def forward(
8787
return ray_refined, confidence_refined
8888

8989

90-
class CPPNetUnet(nn.Module):
90+
class CPPNetUnet(nn.ModuleDict):
9191
def __init__(
9292
self,
9393
decoders: Tuple[str, ...],
@@ -194,6 +194,7 @@ def __init__(
194194
self.inst_key = inst_key
195195
self.aux_key = "stardist"
196196
self.n_rays = n_rays
197+
self.enc_name = enc_name
197198

198199
if enc_out_indices is None:
199200
enc_out_indices = tuple(range(depth))
@@ -221,18 +222,21 @@ def __init__(
221222
)
222223

223224
# set encoder
224-
self.encoder = Encoder(
225-
timm_encoder_name=enc_name,
226-
timm_encoder_out_indices=enc_out_indices,
227-
timm_encoder_pretrained=enc_pretrain,
228-
timm_extra_kwargs=encoder_kws,
225+
self.add_module(
226+
self.enc_name,
227+
Encoder(
228+
timm_encoder_name=enc_name,
229+
timm_encoder_out_indices=enc_out_indices,
230+
timm_encoder_pretrained=enc_pretrain,
231+
timm_extra_kwargs=encoder_kws,
232+
),
229233
)
230234

231235
self.decoder = MultiTaskDecoder(
232236
decoders=decoders,
233237
heads=heads,
234238
out_channels=out_channels,
235-
enc_feature_info=self.encoder.feature_info,
239+
enc_feature_info=self[self.enc_name].feature_info,
236240
n_layers=n_layers,
237241
n_blocks=n_blocks,
238242
stage_kws=stage_kws,
@@ -255,7 +259,7 @@ def __init__(
255259

256260
# freeze encoder if specified
257261
if enc_freeze:
258-
self.encoder.freeze_encoder()
262+
self[self.enc_name].freeze_encoder()
259263

260264
def forward(self, x: torch.Tensor, return_pred_only: bool = True) -> Dict[str, Any]:
261265
"""Forward pass of Cellpose U-net.
@@ -276,7 +280,7 @@ def forward(self, x: torch.Tensor, return_pred_only: bool = True) -> Dict[str, A
276280
- "dec_feats": Dict[str, List[torch.Tensor]].
277281
- "enc_out": torch.Tensor.
278282
"""
279-
enc_output, feats = self.encoder(x)
283+
enc_output, feats = self[self.enc_name](x)
280284
dec_out: DecoderSoftOutput = self.decoder(feats, x)
281285
if dec_out.nuc_map is not None:
282286
dec_name = dec_out.nuc_map.parents["aux_map"][0]

0 commit comments

Comments
 (0)