Skip to content

Commit 30a7510

Browse files
committed
fix: improve inferer ckpt loading
1 parent 0066971 commit 30a7510

File tree

1 file changed

+30
-16
lines changed

1 file changed

+30
-16
lines changed

cellseg_models_pytorch/inference/_base_inferer.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from itertools import chain
55
from pathlib import Path
66
from typing import Callable, Dict, List, Tuple, Union
7+
from urllib.parse import urlparse
78

89
import numpy as np
910
import torch
@@ -150,17 +151,31 @@ def __init__(
150151

151152
# load weights and set devices
152153
if checkpoint_path is not None:
153-
ckpt = torch.load(
154-
checkpoint_path, map_location=lambda storage, loc: storage
155-
)
154+
checkpoint_path = Path(checkpoint_path)
155+
# check if path is url or local and load weights to memory
156+
if urlparse(checkpoint_path.as_posix()).scheme:
157+
state_dict = torch.hub.load_state_dict_from_url(checkpoint_path)
158+
else:
159+
state_dict = torch.load(
160+
checkpoint_path, map_location=lambda storage, loc: storage
161+
)
162+
163+
# if the checkpoint is from lightning, the ckpt file contains a lot of other
164+
# stuff than just the state dict.
165+
if "state_dict" in state_dict.keys():
166+
state_dict = state_dict["state_dict"]
156167

168+
# try loading the weights to the model
157169
try:
158-
self.model.load_state_dict(ckpt["state_dict"], strict=True)
170+
msg = self.model.load_state_dict(state_dict, strict=True)
159171
except RuntimeError:
160-
new_ckpt = self._strip_state_dict(ckpt)
161-
self.model.load_state_dict(new_ckpt["state_dict"], strict=True)
172+
new_ckpt = self._strip_state_dict(state_dict)
173+
msg = self.model.load_state_dict(new_ckpt, strict=True)
162174
except BaseException as e:
163-
print(e)
175+
raise RuntimeError(f"Error when loading checkpoint: {e}")
176+
177+
print(f"Loading weights: {checkpoint_path} for inference.")
178+
print(msg)
164179

165180
assert device in ("cuda", "cpu", "mps")
166181
if device == "cpu":
@@ -213,6 +228,12 @@ def infer(self, mixed_precision: bool = False) -> None:
213228
`classes_type`, `classes_sem`, `offsets`. See more in the
214229
`FileHandler.save_masks` docs.
215230
231+
232+
Parameters
233+
----------
234+
mixed_precision : bool, default=False
235+
If True, inference is performed with mixed precision.
236+
216237
Attributes
217238
----------
218239
- out_masks : Dict[str, Dict[str, np.ndarray]]
@@ -224,11 +245,6 @@ def infer(self, mixed_precision: bool = False) -> None:
224245
The soft masks for each image. I.e. the soft predictions of the trained
225246
model The keys are the image names and the values are dictionaries of
226247
the soft masks. E.g. {"sample1": {"type": [H, W], "aux": [C, H, W]}}
227-
228-
Parameters
229-
----------
230-
mixed_precision : bool, default=False
231-
If True, inference is performed with mixed precision.
232248
"""
233249
self.soft_masks = {}
234250
self.out_masks = {}
@@ -291,15 +307,13 @@ def infer(self, mixed_precision: bool = False) -> None:
291307
def _strip_state_dict(self, ckpt: Dict) -> OrderedDict:
292308
"""Strip the first 'model.' (generated by lightning) from the state dict keys."""
293309
state_dict = OrderedDict()
294-
for k, w in ckpt["state_dict"].items():
310+
for k, w in ckpt.items():
295311
if "num_batches_track" not in k:
296-
# new_key = k.strip("model")[1:]
297312
spl = ["".join(kk) for kk in k.split(".")]
298313
new_key = ".".join(spl[1:])
299314
state_dict[new_key] = w
300-
ckpt["state_dict"] = state_dict
301315

302-
return ckpt
316+
return state_dict
303317

304318
def _check_and_set_head_args(self) -> None:
305319
"""Check the model has matching head names with the head args and set them."""

0 commit comments

Comments
 (0)