1
- import warnings
2
1
from abc import ABC , abstractmethod
3
2
from collections import OrderedDict
4
3
from itertools import chain
@@ -33,9 +32,6 @@ def __init__(
33
32
normalization : str = None ,
34
33
device : str = "cuda" ,
35
34
n_devices : int = 1 ,
36
- save_intermediate : bool = False ,
37
- save_dir : Union [Path , str ] = None ,
38
- save_format : str = ".mat" ,
39
35
checkpoint_path : Union [Path , str ] = None ,
40
36
n_images : int = None ,
41
37
type_post_proc : Callable = None ,
@@ -46,57 +42,49 @@ def __init__(
46
42
47
43
Parameters
48
44
----------
49
- model : nn.Module
50
- A segmentation model.
51
- input_path : Path | str
52
- Path to a folder of images or to hdf5 db.
53
- out_activations : Dict[str, str]
54
- Dictionary of head names mapped to a string value that specifies the
55
- activation applied at the head. E.g. {"type": "tanh", "cellpose": None}
56
- Allowed values: "softmax", "sigmoid", "tanh", None.
57
- out_boundary_weights : Dict[str, bool]
58
- Dictionary of head names mapped to a boolean value. If the value is
59
- True, after a prediction, a weight matrix is applied that assigns bigger
60
- weight on pixels in the center and less weight to pixels on the tile
61
- boundaries. helps dealing with prediction artefacts on the boundaries.
62
- E.g. {"type": False, "cellpose": True}
63
- patch_size : Tuple[int, int]:
64
- The size of the input patches that are fed to the segmentation model.
65
- instance_postproc : str
66
- The post-processing method for the instance segmentation mask. One of:
67
- "cellpose", "omnipose", "stardist", "hovernet", "dcan", "drfns", "dran"
68
- padding : int, optional
69
- The amount of reflection padding for the input images.
70
- batch_size : int, default=8
71
- Number of images loaded from the folder at every batch.
72
- normalization : str, optional
73
- Apply img normalization at forward pass (Same as during training).
74
- One of: "dataset", "minmax", "norm", "percentile", None.
75
- device : str, default="cuda"
76
- The device of the input and model. One of: "cuda", "cpu"
77
- n_devices : int, default=1
78
- Number of devices (cpus/gpus) used for inference.
79
- The model will be copied into these devices.
80
- save_dir : bool, optional
81
- Path to save directory. If None, no masks will be saved to disk as .mat
82
- or .json files. Instead the masks will be saved in `self.out_masks`.
83
- save_intermediate : bool, default=False
84
- If True, intermediate soft masks will be saved into `soft_masks` var.
85
- save_format : str, default=".mat"
86
- The file format for the saved output masks. One of (".mat", ".json").
87
- The ".json" option will save masks into geojson format.
88
- checkpoint_path : Path | str, optional
89
- Path to the model weight checkpoints.
90
- n_images : int, optional
91
- First n-number of images used from the `input_path`.
92
- type_post_proc : Callable, optional
93
- A post-processing function for the type maps. If not None, overrides
94
- the default.
95
- sem_post_proc : Callable, optional
96
- A post-processing function for the semantc seg maps. If not None,
97
- overrides the default.
98
- **kwargs:
99
- Arbitrary keyword arguments expecially for post-processing and saving.
45
+ model : nn.Module
46
+ A segmentation model.
47
+ input_path : Path | str
48
+ Path to a folder of images or to hdf5 db.
49
+ out_activations : Dict[str, str]
50
+ Dictionary of head names mapped to a string value that specifies the
51
+ activation applied at the head. E.g. {"type": "tanh", "cellpose": None}
52
+ Allowed values: "softmax", "sigmoid", "tanh", None.
53
+ out_boundary_weights : Dict[str, bool]
54
+ Dictionary of head names mapped to a boolean value. If the value is
55
+ True, after a prediction, a weight matrix is applied that assigns bigger
56
+ weight on pixels in the center and less weight to pixels on the tile
57
+ boundaries. helps dealing with prediction artefacts on the boundaries.
58
+ E.g. {"type": False, "cellpose": True}
59
+ patch_size : Tuple[int, int]:
60
+ The size of the input patches that are fed to the segmentation model.
61
+ instance_postproc : str
62
+ The post-processing method for the instance segmentation mask. One of:
63
+ "cellpose", "omnipose", "stardist", "hovernet", "dcan", "drfns", "dran"
64
+ padding : int, optional
65
+ The amount of reflection padding for the input images.
66
+ batch_size : int, default=8
67
+ Number of images loaded from the folder at every batch.
68
+ normalization : str, optional
69
+ Apply img normalization at forward pass (Same as during training).
70
+ One of: "dataset", "minmax", "norm", "percentile", None.
71
+ device : str, default="cuda"
72
+ The device of the input and model. One of: "cuda", "cpu"
73
+ n_devices : int, default=1
74
+ Number of devices (cpus/gpus) used for inference.
75
+ The model will be copied into these devices.
76
+ checkpoint_path : Path | str, optional
77
+ Path to the model weight checkpoints.
78
+ n_images : int, optional
79
+ First n-number of images used from the `input_path`.
80
+ type_post_proc : Callable, optional
81
+ A post-processing function for the type maps. If not None, overrides
82
+ the default.
83
+ sem_post_proc : Callable, optional
84
+ A post-processing function for the semantc seg maps. If not None,
85
+ overrides the default.
86
+ **kwargs:
87
+ Arbitrary keyword arguments for post-processing.
100
88
"""
101
89
# basic inits
102
90
self .model = model
@@ -109,22 +97,10 @@ def __init__(
109
97
self .head_kwargs = self ._check_and_set_head_args ()
110
98
self .kwargs = kwargs
111
99
112
- self .save_dir = Path (save_dir ) if save_dir is not None else None
113
- self .save_intermediate = save_intermediate
114
- self .save_format = save_format
115
-
116
100
# dataset & dataloader
117
101
self .path = Path (input_path )
118
102
if self .path .is_dir ():
119
103
ds = FolderDatasetInfer (self .path , n_images = n_images )
120
- if self .save_dir is None and len (ds .fnames ) > 40 and n_images is None :
121
- warnings .warn (
122
- "`save_dir` is None. Thus, the outputs are be saved in `out_masks` "
123
- "class attribute. If the input folder contains many images, running"
124
- " inference will likely flood the memory depending on the size and "
125
- "number of the images. Consider saving outputs to disk by providing"
126
- " `save_dir` argument."
127
- )
128
104
elif self .path .is_file () and self .path .suffix in (".h5" , ".hdf5" ):
129
105
from .hdf5_dataset_infer import HDF5DatasetInfer
130
106
@@ -167,10 +143,10 @@ def __init__(
167
143
168
144
# try loading the weights to the model
169
145
try :
170
- msg = self .model .load_state_dict (state_dict , strict = True )
146
+ msg = self .model .load_state_dict (state_dict , strict = False )
171
147
except RuntimeError :
172
148
new_ckpt = self ._strip_state_dict (state_dict )
173
- msg = self .model .load_state_dict (new_ckpt , strict = True )
149
+ msg = self .model .load_state_dict (new_ckpt , strict = False )
174
150
except BaseException as e :
175
151
raise RuntimeError (f"Error when loading checkpoint: { e } " )
176
152
@@ -218,34 +194,74 @@ def from_yaml(cls, model: nn.Module, yaml_path: str):
218
194
def _infer_batch (self ):
219
195
raise NotImplementedError
220
196
221
- def infer (self , mixed_precision : bool = False ) -> None :
222
- """Run inference and post-processing for the images.
197
+ def infer (
198
+ self ,
199
+ save_dir : Union [Path , str ] = None ,
200
+ save_format : str = ".mat" ,
201
+ save_intermediate : bool = False ,
202
+ classes_type : Dict [str , int ] = None ,
203
+ classes_sem : Dict [str , int ] = None ,
204
+ offsets : bool = False ,
205
+ mixed_precision : bool = False ,
206
+ ) -> None :
207
+ """Run inference and post-processing for the image(s) inside `input_path`.
223
208
224
- NOTE:
225
- - Saves outputs in class attributes or to disk (.mat/.json) files.
226
- - If masks are saved to .json (geojson) files, more key word arguments
227
- need to be given at class initialization. Namely: `geo_format`,
228
- `classes_type`, `classes_sem`, `offsets`. See more in the
229
- `FileHandler.save_masks` docs.
209
+ NOTE: If `save_dir` is None, the output masks will be cached in a class
210
+ attribute `self.out_masks`. Otherwise the masks will be saved to disk.
230
211
212
+ WARNING: Running inference without setting `save_dir` can take a lot of memory
213
+ if the input directory contains many images.
231
214
232
215
Parameters
233
216
----------
234
- mixed_precision : bool, default=False
235
- If True, inference is performed with mixed precision.
217
+ save_dir : bool, optional
218
+ Path to save directory. If None, no masks will be saved to disk.
219
+ Instead the masks will be cached in a class attribute `self.out_masks`.
220
+ save_format : str, default=".mat"
221
+ The file format for the saved output masks. One of ".mat", ".geojson",
222
+ "feather" "parquet".
223
+ save_intermediate : bool, default=False
224
+ If True, intermediate soft masks will be saved into `self.soft_masks`
225
+ class attribute. WARNING: This can take a lot of memory if the input
226
+ directory contains many images.
227
+ classes_type : Dict[str, str], optional
228
+ Cell type dictionary. e.g. {"inflam":1, "epithelial":2, "connec":3}.
229
+ This is required only if `save_format` is one of the following formats:
230
+ ".geojson", ".parquet", ".feather".
231
+ classes_sem : Dict[str, str], otional
232
+ Tissue type dictionary. e.g. {"tissue1":1, "tissue2":2, "tissue3":3}
233
+ This is required only if `save_format` is one of the following formats:
234
+ ".geojson", ".parquet", ".feather".
235
+ offsets : bool, default=False
236
+ If True, geojson coords are shifted by the offsets that are encoded in
237
+ the filenames (e.g. "x-1000_y-4000.png"). Ignored if `format` == `.mat`.
238
+ mixed_precision : bool, default=False
239
+ If True, inference is performed with mixed precision.
236
240
237
241
Attributes
238
242
----------
239
- - out_masks : Dict[str, Dict[str, np.ndarray]]
240
- The output masks for each image. The keys are the image names and the
241
- values are dictionaries of the masks. E.g.
242
- {"sample1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
243
- - soft_masks : Dict[str, Dict[str, np.ndarray]]
244
- NOTE: This attribute is set only if `save_intermediate = True`.
245
- The soft masks for each image. I.e. the soft predictions of the trained
246
- model The keys are the image names and the values are dictionaries of
247
- the soft masks. E.g. {"sample1": {"type": [H, W], "aux": [C, H, W]}}
243
+ - out_masks : Dict[str, Dict[str, np.ndarray]]
244
+ The output masks for each image. The keys are the image names and the
245
+ values are dictionaries of the masks. E.g.
246
+ {"sample1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
247
+ - soft_masks : Dict[str, Dict[str, np.ndarray]]
248
+ NOTE: This attribute is set only if `save_intermediate = True`.
249
+ The soft masks for each image. I.e. the soft predictions of the trained
250
+ model The keys are the image names and the values are dictionaries of
251
+ the soft masks. E.g. {"sample1": {"type": [H, W], "aux": [C, H, W]}}
248
252
"""
253
+ # check save_dir and save_format
254
+ save_dir = Path (save_dir ) if save_dir is not None else None
255
+ save_intermediate = save_intermediate
256
+ save_format = save_format
257
+ if save_dir is not None :
258
+ allowed_formats = (".mat" , ".geojson" , ".feather" , ".parquet" )
259
+ if save_format not in allowed_formats :
260
+ raise ValueError (
261
+ f"Given `save_format`: { save_format } is not one of the allowed "
262
+ f"formats: { allowed_formats } "
263
+ )
264
+
249
265
self .soft_masks = {}
250
266
self .out_masks = {}
251
267
self .elapsed = []
@@ -271,7 +287,7 @@ def infer(self, mixed_precision: bool = False) -> None:
271
287
self .elapsed .append (loader .format_dict ["elapsed" ])
272
288
self .rate .append (loader .format_dict ["rate" ])
273
289
274
- if self . save_intermediate :
290
+ if save_intermediate :
275
291
for n , m in zip (names , soft_masks ):
276
292
self .soft_masks [n ] = m
277
293
@@ -283,25 +299,33 @@ def infer(self, mixed_precision: bool = False) -> None:
283
299
seg ["soft_sem" ] = soft ["sem" ]
284
300
285
301
# save to cache or disk
286
- if self . save_dir is None :
302
+ if save_dir is None :
287
303
for n , m in zip (names , seg_results ):
288
304
self .out_masks [n ] = m
289
305
else :
290
306
loader .set_postfix_str ("Saving results to disk" )
291
307
if self .batch_size > 1 :
292
- fnames = [Path (self . save_dir ) / n for n in names ]
308
+ fnames = [Path (save_dir ) / n for n in names ]
293
309
FileHandler .save_masks_parallel (
294
310
maps = seg_results ,
295
311
fnames = fnames ,
296
- ** {** self .kwargs , "format" : self .save_format },
312
+ format = save_format ,
313
+ classes_type = classes_type ,
314
+ classes_sem = classes_sem ,
315
+ offsets = offsets ,
316
+ pooltype = "thread" ,
317
+ maptype = "amap" ,
297
318
)
298
319
else :
299
320
for n , m in zip (names , seg_results ):
300
- fname = Path (self . save_dir ) / n
321
+ fname = Path (save_dir ) / n
301
322
FileHandler .save_masks (
302
323
fname = fname ,
303
324
maps = m ,
304
- ** {** self .kwargs , "format" : self .save_format },
325
+ format = save_format ,
326
+ classes_type = classes_type ,
327
+ classes_sem = classes_sem ,
328
+ offsets = offsets ,
305
329
)
306
330
307
331
def _strip_state_dict (self , ckpt : Dict ) -> OrderedDict :
0 commit comments