4
4
from itertools import chain
5
5
from pathlib import Path
6
6
from typing import Callable , Dict , List , Tuple , Union
7
+ from urllib .parse import urlparse
7
8
8
9
import numpy as np
9
10
import torch
@@ -150,17 +151,31 @@ def __init__(
150
151
151
152
# load weights and set devices
152
153
if checkpoint_path is not None :
153
- ckpt = torch .load (
154
- checkpoint_path , map_location = lambda storage , loc : storage
155
- )
154
+ checkpoint_path = Path (checkpoint_path )
155
+ # check if path is url or local and load weights to memory
156
+ if urlparse (checkpoint_path .as_posix ()).scheme :
157
+ state_dict = torch .hub .load_state_dict_from_url (checkpoint_path )
158
+ else :
159
+ state_dict = torch .load (
160
+ checkpoint_path , map_location = lambda storage , loc : storage
161
+ )
162
+
163
+ # if the checkpoint is from lightning, the ckpt file contains a lot of other
164
+ # stuff than just the state dict.
165
+ if "state_dict" in state_dict .keys ():
166
+ state_dict = state_dict ["state_dict" ]
156
167
168
+ # try loading the weights to the model
157
169
try :
158
- self .model .load_state_dict (ckpt [ " state_dict" ] , strict = True )
170
+ msg = self .model .load_state_dict (state_dict , strict = True )
159
171
except RuntimeError :
160
- new_ckpt = self ._strip_state_dict (ckpt )
161
- self .model .load_state_dict (new_ckpt [ "state_dict" ] , strict = True )
172
+ new_ckpt = self ._strip_state_dict (state_dict )
173
+ msg = self .model .load_state_dict (new_ckpt , strict = True )
162
174
except BaseException as e :
163
- print (e )
175
+ raise RuntimeError (f"Error when loading checkpoint: { e } " )
176
+
177
+ print (f"Loading weights: { checkpoint_path } for inference." )
178
+ print (msg )
164
179
165
180
assert device in ("cuda" , "cpu" , "mps" )
166
181
if device == "cpu" :
@@ -213,6 +228,12 @@ def infer(self, mixed_precision: bool = False) -> None:
213
228
`classes_type`, `classes_sem`, `offsets`. See more in the
214
229
`FileHandler.save_masks` docs.
215
230
231
+
232
+ Parameters
233
+ ----------
234
+ mixed_precision : bool, default=False
235
+ If True, inference is performed with mixed precision.
236
+
216
237
Attributes
217
238
----------
218
239
- out_masks : Dict[str, Dict[str, np.ndarray]]
@@ -224,11 +245,6 @@ def infer(self, mixed_precision: bool = False) -> None:
224
245
The soft masks for each image. I.e. the soft predictions of the trained
225
246
model. The keys are the image names and the values are dictionaries of
226
247
the soft masks. E.g. {"sample1": {"type": [H, W], "aux": [C, H, W]}}
227
-
228
- Parameters
229
- ----------
230
- mixed_precision : bool, default=False
231
- If True, inference is performed with mixed precision.
232
248
"""
233
249
self .soft_masks = {}
234
250
self .out_masks = {}
@@ -291,15 +307,13 @@ def infer(self, mixed_precision: bool = False) -> None:
291
307
def _strip_state_dict (self , ckpt : Dict ) -> OrderedDict :
292
308
"""Strip te first 'model.' (generated by lightning) from the state dict keys."""
293
309
state_dict = OrderedDict ()
294
- for k , w in ckpt [ "state_dict" ] .items ():
310
+ for k , w in ckpt .items ():
295
311
if "num_batches_track" not in k :
296
- # new_key = k.strip("model")[1:]
297
312
spl = ["" .join (kk ) for kk in k .split ("." )]
298
313
new_key = "." .join (spl [1 :])
299
314
state_dict [new_key ] = w
300
- ckpt ["state_dict" ] = state_dict
301
315
302
- return ckpt
316
+ return state_dict
303
317
304
318
def _check_and_set_head_args (self ) -> None :
305
319
"""Check the model has matching head names with the head args and set them."""
0 commit comments