1
+ import warnings
1
2
from abc import ABC , abstractmethod
2
3
from collections import OrderedDict
3
4
from itertools import chain
8
9
import torch
9
10
import torch .nn as nn
10
11
import yaml
11
- from pathos .multiprocessing import ThreadPool as Pool
12
12
from torch .utils .data import DataLoader
13
13
from tqdm import tqdm
14
14
15
- from ..utils import tensor_to_ndarray
16
- from ..utils .save_utils import mask2mat
15
+ from ..utils import FileHandler , tensor_to_ndarray
17
16
from .folder_dataset import FolderDataset
18
17
from .post_processor import PostProcessor
19
18
from .predictor import Predictor
@@ -33,14 +32,14 @@ def __init__(
33
32
normalization : str = None ,
34
33
device : str = "cuda" ,
35
34
n_devices : int = 1 ,
36
- save_masks : bool = True ,
37
35
save_intermediate : bool = False ,
38
36
save_dir : Union [Path , str ] = None ,
37
+ save_format : str = ".mat" ,
39
38
checkpoint_path : Union [Path , str ] = None ,
40
39
n_images : int = None ,
41
40
type_post_proc : Callable = None ,
42
41
sem_post_proc : Callable = None ,
43
- ** postproc_kwargs ,
42
+ ** kwargs ,
44
43
) -> None :
45
44
"""Inference for an image folder.
46
45
@@ -77,16 +76,14 @@ def __init__(
77
76
n_devices : int, default=1
78
77
Number of devices (cpus/gpus) used for inference.
79
78
The model will be copied into these devices.
80
- save_masks : bool, default=False
81
- If True, the resulting segmentation masks will be saved into `out_masks`
82
- variable.
83
- save_intermediate : bool, default=False
84
- If True, intermediate soft masks will be saved into `soft_masks` var.
85
79
save_dir : bool, optional
86
80
Path to save directory. If None, no masks will be saved to disk as .mat
87
- files. If not None, overrides `save_masks`, thus for every batch the
88
- segmentation results are saved into disk and the intermediate results
89
- are flushed.
81
+ or .json files. Instead the masks will be saved in `self.out_masks`.
82
+ save_intermediate : bool, default=False
83
+ If True, intermediate soft masks will be saved into `soft_masks` var.
84
+ save_format : str, default=".mat"
85
+ The file format for the saved output masks. One of (".mat", ".json").
86
+ The ".json" option will save masks into geojson format.
90
87
checkpoint_path : Path | str, optional
91
88
Path to the model weight checkpoints.
92
89
n_images : int, optional
@@ -97,8 +94,8 @@ def __init__(
97
94
sem_post_proc : Callable, optional
98
95
A post-processing function for the semantc seg maps. If not None,
99
96
overrides the default.
100
- **postproc_kwargs :
101
- Arbitrary keyword arguments for the post-processing.
97
+ **kwargs :
98
+ Arbitrary keyword arguments expecially for post-processing and saving .
102
99
"""
103
100
# basic inits
104
101
self .model = model
@@ -109,14 +106,25 @@ def __init__(
109
106
self .out_activations = out_activations
110
107
self .out_boundary_weights = out_boundary_weights
111
108
self .head_kwargs = self ._check_and_set_head_args ()
109
+ self .kwargs = kwargs
112
110
113
111
self .save_dir = Path (save_dir ) if save_dir is not None else None
114
- self .save_masks = save_masks
115
112
self .save_intermediate = save_intermediate
113
+ self .save_format = save_format
116
114
117
115
# dataloader
118
116
self .path = Path (input_folder )
117
+
119
118
folder_ds = FolderDataset (self .path , n_images = n_images )
119
+ if self .save_dir is None and len (folder_ds .fnames ) > 40 :
120
+ warnings .warn (
121
+ "`save_dir` is None. Thus, the outputs are be saved in `out_masks` "
122
+ "class variable. If the input folder contains many images, running "
123
+ "inference will likely flood the memory depending on the size and "
124
+ "number of the images. Consider saving outputs to disk by providing "
125
+ "`save_dir` argument."
126
+ )
127
+
120
128
self .dataloader = DataLoader (
121
129
folder_ds , batch_size = batch_size , shuffle = False , pin_memory = True
122
130
)
@@ -128,7 +136,7 @@ def __init__(
128
136
aux_key = self .model .aux_key ,
129
137
type_post_proc = type_post_proc ,
130
138
sem_post_proc = sem_post_proc ,
131
- ** postproc_kwargs ,
139
+ ** kwargs ,
132
140
)
133
141
134
142
# load weights and set devices
@@ -188,10 +196,16 @@ def _infer_batch(self):
188
196
def infer (self ) -> None :
189
197
"""Run inference and post-processing for the images.
190
198
191
- NOTE: Saves outputs in `self.out_masks` or to disk (.mat) files.
192
-
193
- `self.out_masks` is a nested dict: E.g.
194
- {"image1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
199
+ NOTE:
200
+ - Saves outputs in `self.out_masks` or to disk (.mat/.json) files.
201
+ - If `save_intermediate` is set to True, also intermiediate model outputs are
202
+ saved to `self.soft_masks`
203
+ - `self.out_masks` and `self.soft_masks` are nested dicts: E.g.
204
+ {"sample1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
205
+ - If masks are saved to geojson .json files, more key word arguments
206
+ need to be given at class initialization. Namely: `geo_format`,
207
+ `classes_type`, `classes_sem`, `offsets`. See more in the
208
+ `FileHandler.save_masks` docs.
195
209
"""
196
210
self .soft_masks = {}
197
211
self .out_masks = {}
@@ -223,89 +237,25 @@ def infer(self) -> None:
223
237
self .soft_masks [n ] = m
224
238
225
239
if self .save_dir is None :
226
- if self .save_masks :
227
- for n , m in zip (names , seg_results ):
228
- self .out_masks [n ] = m
240
+ for n , m in zip (names , seg_results ):
241
+ self .out_masks [n ] = m
229
242
else :
230
243
loader .set_postfix_str ("Saving results to disk" )
231
244
if self .batch_size > 1 :
232
- self .save_parallel (seg_results , names , self .save_dir )
245
+ fnames = [Path (self .save_dir ) / n for n in names ]
246
+ FileHandler .save_masks_parallel (
247
+ maps = seg_results ,
248
+ fnames = fnames ,
249
+ ** {** self .kwargs , "format" : self .save_format },
250
+ )
233
251
else :
234
252
for n , m in zip (names , seg_results ):
235
- self .save_mask (m , n , self .save_dir )
236
-
237
- @staticmethod
238
- def save_mask (
239
- maps : Dict [str , np .ndarray ],
240
- fname : str ,
241
- save_dir : Union [str , Path ],
242
- format : str = ".mat" ,
243
- ) -> None :
244
- """Save model outputs to .mat or geojson.
245
-
246
- Parameters
247
- ----------
248
- maps : Dict[str, np.ndarray]
249
- model output names mapped to model outputs.
250
- E.g. {"sem": np.ndarray, "type": np.ndarray, "inst": np.ndarray}
251
- fname : str
252
- Name for the output-file.
253
- save_dir : Path or str
254
- Path to the save directory.
255
- format : str
256
- One of ".mat" or "geojson"
257
- """
258
- allowed = (".mat" , ".json" )
259
- if format not in allowed :
260
- raise ValueError (
261
- f"Illegal file-format. Got: { format } . Allowed formats: { allowed } "
262
- )
263
-
264
- if format == ".mat" :
265
- mask2mat (fname , save_dir , ** maps )
266
- else :
267
- pass
268
-
269
- return True
270
-
271
- @staticmethod
272
- def save_parallel (
273
- maps : List [Dict [str , np .ndarray ]],
274
- fnames : List [str ],
275
- save_dir : Union [Path , str ],
276
- format : str = ".mat" ,
277
- progress_bar : bool = False ,
278
- ) -> None :
279
- """Save the model output masks to a folder. (multi-threaded).
280
-
281
- Parameters
282
- ----------
283
- maps : List[Dict[str, np.ndarray]]
284
- The model output map dictionaries in a list.
285
- fnames : List[str]
286
- Name for the output-files. (In the same order with `maps`)
287
- save_dir : Path or str
288
- Path to the save directory.
289
- format : str
290
- One of ".mat" or "geojson"
291
- progress_bar : bool, default=False
292
- If True, a tqdm progress bar is shown.
293
- """
294
- args = tuple (zip (maps , fnames , [save_dir ] * len (maps ), [format ] * len (maps )))
295
-
296
- with Pool () as pool :
297
- if progress_bar :
298
- it = tqdm (pool .imap (BaseInferer ._save_mask , args ), total = len (maps ))
299
- else :
300
- it = pool .imap (BaseInferer ._save_mask , args )
301
-
302
- for _ in it :
303
- pass
304
-
305
- @staticmethod
306
- def _save_mask (args : Tuple [Dict [str , np .ndarray ], str , str ]) -> None :
307
- """Unpacks the args for `save_mask` to enable multi-threading."""
308
- return BaseInferer .save_mask (* args )
253
+ fname = Path (self .save_dir ) / n
254
+ FileHandler .save_masks (
255
+ fname = fname ,
256
+ maps = m ,
257
+ ** {** self .kwargs , "format" : self .save_format },
258
+ )
309
259
310
260
def _strip_state_dict (self , ckpt : Dict ) -> OrderedDict :
311
261
"""Strip te first 'model.' (generated by lightning) from the state dict keys."""
0 commit comments