Skip to content

Commit 0fb050a

Browse files
committed
refactor: rewrite inference saving w FileHandler
1 parent f6a9f06 commit 0fb050a

File tree

3 files changed

+70
-125
lines changed

3 files changed

+70
-125
lines changed

cellseg_models_pytorch/inference/_base_inferer.py

Lines changed: 49 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from abc import ABC, abstractmethod
23
from collections import OrderedDict
34
from itertools import chain
@@ -8,12 +9,10 @@
89
import torch
910
import torch.nn as nn
1011
import yaml
11-
from pathos.multiprocessing import ThreadPool as Pool
1212
from torch.utils.data import DataLoader
1313
from tqdm import tqdm
1414

15-
from ..utils import tensor_to_ndarray
16-
from ..utils.save_utils import mask2mat
15+
from ..utils import FileHandler, tensor_to_ndarray
1716
from .folder_dataset import FolderDataset
1817
from .post_processor import PostProcessor
1918
from .predictor import Predictor
@@ -33,14 +32,14 @@ def __init__(
3332
normalization: str = None,
3433
device: str = "cuda",
3534
n_devices: int = 1,
36-
save_masks: bool = True,
3735
save_intermediate: bool = False,
3836
save_dir: Union[Path, str] = None,
37+
save_format: str = ".mat",
3938
checkpoint_path: Union[Path, str] = None,
4039
n_images: int = None,
4140
type_post_proc: Callable = None,
4241
sem_post_proc: Callable = None,
43-
**postproc_kwargs,
42+
**kwargs,
4443
) -> None:
4544
"""Inference for an image folder.
4645
@@ -77,16 +76,14 @@ def __init__(
7776
n_devices : int, default=1
7877
Number of devices (cpus/gpus) used for inference.
7978
The model will be copied into these devices.
80-
save_masks : bool, default=False
81-
If True, the resulting segmentation masks will be saved into `out_masks`
82-
variable.
83-
save_intermediate : bool, default=False
84-
If True, intermediate soft masks will be saved into `soft_masks` var.
8579
save_dir : bool, optional
8680
Path to save directory. If None, no masks will be saved to disk as .mat
87-
files. If not None, overrides `save_masks`, thus for every batch the
88-
segmentation results are saved into disk and the intermediate results
89-
are flushed.
81+
or .json files. Instead the masks will be saved in `self.out_masks`.
82+
save_intermediate : bool, default=False
83+
If True, intermediate soft masks will be saved into `soft_masks` var.
84+
save_format : str, default=".mat"
85+
The file format for the saved output masks. One of (".mat", ".json").
86+
The ".json" option will save masks into geojson format.
9087
checkpoint_path : Path | str, optional
9188
Path to the model weight checkpoints.
9289
n_images : int, optional
@@ -97,8 +94,8 @@ def __init__(
9794
sem_post_proc : Callable, optional
9895
A post-processing function for the semantc seg maps. If not None,
9996
overrides the default.
100-
**postproc_kwargs:
101-
Arbitrary keyword arguments for the post-processing.
97+
**kwargs:
98+
Arbitrary keyword arguments expecially for post-processing and saving.
10299
"""
103100
# basic inits
104101
self.model = model
@@ -109,14 +106,25 @@ def __init__(
109106
self.out_activations = out_activations
110107
self.out_boundary_weights = out_boundary_weights
111108
self.head_kwargs = self._check_and_set_head_args()
109+
self.kwargs = kwargs
112110

113111
self.save_dir = Path(save_dir) if save_dir is not None else None
114-
self.save_masks = save_masks
115112
self.save_intermediate = save_intermediate
113+
self.save_format = save_format
116114

117115
# dataloader
118116
self.path = Path(input_folder)
117+
119118
folder_ds = FolderDataset(self.path, n_images=n_images)
119+
if self.save_dir is None and len(folder_ds.fnames) > 40:
120+
warnings.warn(
121+
"`save_dir` is None. Thus, the outputs are be saved in `out_masks` "
122+
"class variable. If the input folder contains many images, running "
123+
"inference will likely flood the memory depending on the size and "
124+
"number of the images. Consider saving outputs to disk by providing "
125+
"`save_dir` argument."
126+
)
127+
120128
self.dataloader = DataLoader(
121129
folder_ds, batch_size=batch_size, shuffle=False, pin_memory=True
122130
)
@@ -128,7 +136,7 @@ def __init__(
128136
aux_key=self.model.aux_key,
129137
type_post_proc=type_post_proc,
130138
sem_post_proc=sem_post_proc,
131-
**postproc_kwargs,
139+
**kwargs,
132140
)
133141

134142
# load weights and set devices
@@ -188,10 +196,16 @@ def _infer_batch(self):
188196
def infer(self) -> None:
189197
"""Run inference and post-processing for the images.
190198
191-
NOTE: Saves outputs in `self.out_masks` or to disk (.mat) files.
192-
193-
`self.out_masks` is a nested dict: E.g.
194-
{"image1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
199+
NOTE:
200+
- Saves outputs in `self.out_masks` or to disk (.mat/.json) files.
201+
- If `save_intermediate` is set to True, also intermiediate model outputs are
202+
saved to `self.soft_masks`
203+
- `self.out_masks` and `self.soft_masks` are nested dicts: E.g.
204+
{"sample1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
205+
- If masks are saved to geojson .json files, more key word arguments
206+
need to be given at class initialization. Namely: `geo_format`,
207+
`classes_type`, `classes_sem`, `offsets`. See more in the
208+
`FileHandler.save_masks` docs.
195209
"""
196210
self.soft_masks = {}
197211
self.out_masks = {}
@@ -223,89 +237,25 @@ def infer(self) -> None:
223237
self.soft_masks[n] = m
224238

225239
if self.save_dir is None:
226-
if self.save_masks:
227-
for n, m in zip(names, seg_results):
228-
self.out_masks[n] = m
240+
for n, m in zip(names, seg_results):
241+
self.out_masks[n] = m
229242
else:
230243
loader.set_postfix_str("Saving results to disk")
231244
if self.batch_size > 1:
232-
self.save_parallel(seg_results, names, self.save_dir)
245+
fnames = [Path(self.save_dir) / n for n in names]
246+
FileHandler.save_masks_parallel(
247+
maps=seg_results,
248+
fnames=fnames,
249+
**{**self.kwargs, "format": self.save_format},
250+
)
233251
else:
234252
for n, m in zip(names, seg_results):
235-
self.save_mask(m, n, self.save_dir)
236-
237-
@staticmethod
238-
def save_mask(
239-
maps: Dict[str, np.ndarray],
240-
fname: str,
241-
save_dir: Union[str, Path],
242-
format: str = ".mat",
243-
) -> None:
244-
"""Save model outputs to .mat or geojson.
245-
246-
Parameters
247-
----------
248-
maps : Dict[str, np.ndarray]
249-
model output names mapped to model outputs.
250-
E.g. {"sem": np.ndarray, "type": np.ndarray, "inst": np.ndarray}
251-
fname : str
252-
Name for the output-file.
253-
save_dir : Path or str
254-
Path to the save directory.
255-
format : str
256-
One of ".mat" or "geojson"
257-
"""
258-
allowed = (".mat", ".json")
259-
if format not in allowed:
260-
raise ValueError(
261-
f"Illegal file-format. Got: {format}. Allowed formats: {allowed}"
262-
)
263-
264-
if format == ".mat":
265-
mask2mat(fname, save_dir, **maps)
266-
else:
267-
pass
268-
269-
return True
270-
271-
@staticmethod
272-
def save_parallel(
273-
maps: List[Dict[str, np.ndarray]],
274-
fnames: List[str],
275-
save_dir: Union[Path, str],
276-
format: str = ".mat",
277-
progress_bar: bool = False,
278-
) -> None:
279-
"""Save the model output masks to a folder. (multi-threaded).
280-
281-
Parameters
282-
----------
283-
maps : List[Dict[str, np.ndarray]]
284-
The model output map dictionaries in a list.
285-
fnames : List[str]
286-
Name for the output-files. (In the same order with `maps`)
287-
save_dir : Path or str
288-
Path to the save directory.
289-
format : str
290-
One of ".mat" or "geojson"
291-
progress_bar : bool, default=False
292-
If True, a tqdm progress bar is shown.
293-
"""
294-
args = tuple(zip(maps, fnames, [save_dir] * len(maps), [format] * len(maps)))
295-
296-
with Pool() as pool:
297-
if progress_bar:
298-
it = tqdm(pool.imap(BaseInferer._save_mask, args), total=len(maps))
299-
else:
300-
it = pool.imap(BaseInferer._save_mask, args)
301-
302-
for _ in it:
303-
pass
304-
305-
@staticmethod
306-
def _save_mask(args: Tuple[Dict[str, np.ndarray], str, str]) -> None:
307-
"""Unpacks the args for `save_mask` to enable multi-threading."""
308-
return BaseInferer.save_mask(*args)
253+
fname = Path(self.save_dir) / n
254+
FileHandler.save_masks(
255+
fname=fname,
256+
maps=m,
257+
**{**self.kwargs, "format": self.save_format},
258+
)
309259

310260
def _strip_state_dict(self, ckpt: Dict) -> OrderedDict:
311261
"""Strip te first 'model.' (generated by lightning) from the state dict keys."""

cellseg_models_pytorch/inference/resize_inferer.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ def __init__(
2222
normalization: str = None,
2323
device: str = "cuda",
2424
n_devices: int = 1,
25-
save_masks: bool = True,
2625
save_intermediate: bool = False,
2726
save_dir: Union[Path, str] = None,
27+
save_format: str = ".mat",
2828
checkpoint_path: Union[Path, str] = None,
2929
n_images: int = None,
3030
type_post_proc: Callable = None,
3131
sem_post_proc: Callable = None,
32-
**postproc_kwargs,
32+
**kwargs,
3333
) -> None:
3434
"""Resize inference for a folder of images.
3535
@@ -73,16 +73,13 @@ def __init__(
7373
n_devices : int, default=1
7474
Number of devices (cpus/gpus) used for inference.
7575
The model will be copied into these devices.
76-
save_masks : bool, default=False
77-
If True, the resulting segmentation masks will be saved into `out_masks`
78-
variable.
7976
save_intermediate : bool, default=False
8077
If True, intermediate soft masks will be saved into `soft_masks` var.
81-
save_dir : bool, optional
82-
Path to save directory. If None, no masks will be saved to disk as .mat
83-
files. If not None, overrides `save_masks`, thus for every batch the
84-
segmentation results are saved into disk and the intermediate results
85-
are flushed.
78+
save_intermediate : bool, default=False
79+
If True, intermediate soft masks will be saved into `soft_masks` var.
80+
save_format : str, default=".mat"
81+
The file format for the saved output masks. One of (".mat", ".json").
82+
The ".json" option will save masks into geojson format.
8683
checkpoint_path : Path | str, optional
8784
Path to the model weight checkpoints.
8885
n_images : int, optional
@@ -93,8 +90,8 @@ def __init__(
9390
sem_post_proc : Callable, optional
9491
A post-processing function for the semantc seg maps. If not None,
9592
overrides the default.
96-
**postproc_kwargs:
97-
Arbitrary keyword arguments for the post-processing.
93+
**kwargs:
94+
Arbitrary keyword arguments expecially for post-processing and saving.
9895
"""
9996
super().__init__(
10097
model=model,
@@ -108,14 +105,14 @@ def __init__(
108105
instance_postproc=instance_postproc,
109106
device=device,
110107
n_devices=n_devices,
111-
save_masks=save_masks,
112108
save_intermediate=save_intermediate,
113109
save_dir=save_dir,
110+
save_format=save_format,
114111
checkpoint_path=checkpoint_path,
115112
n_images=n_images,
116113
type_post_proc=type_post_proc,
117114
sem_post_proc=sem_post_proc,
118-
**postproc_kwargs,
115+
**kwargs,
119116
)
120117

121118
def _infer_batch(self, input_batch: torch.Tensor) -> Dict[str, torch.Tensor]:

cellseg_models_pytorch/inference/sliding_window_inferer.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,14 @@ def __init__(
2525
normalization: str = None,
2626
device: str = "cuda",
2727
n_devices: int = 1,
28-
save_masks: bool = True,
2928
save_intermediate: bool = False,
3029
save_dir: Union[Path, str] = None,
30+
save_format: str = ".mat",
3131
checkpoint_path: Union[Path, str] = None,
3232
n_images: int = None,
3333
type_post_proc: Callable = None,
3434
sem_post_proc: Callable = None,
35-
**postproc_kwargs,
35+
**kwargs,
3636
) -> None:
3737
"""Sliding window inference for a folder of images.
3838
@@ -75,16 +75,14 @@ def __init__(
7575
n_devices : int, default=1
7676
Number of devices (cpus/gpus) used for inference.
7777
The model will be copied into these devices.
78-
save_masks : bool, default=False
79-
If True, the resulting segmentation masks will be saved into `out_masks`
80-
variable.
8178
save_intermediate : bool, default=False
8279
If True, intermediate soft masks will be saved into `soft_masks` var.
8380
save_dir : bool, optional
8481
Path to save directory. If None, no masks will be saved to disk as .mat
85-
files. If not None, overrides `save_masks`, thus for every batch the
86-
segmentation results are saved into disk and the intermediate results
87-
are flushed.
82+
or .json files. Instead the masks will be saved in `self.out_masks`.
83+
save_format : str, default=".mat"
84+
The file format for the saved output masks. One of (".mat", ".json").
85+
The ".json" option will save masks into geojson format.
8886
checkpoint_path : Path | str, optional
8987
Path to the model weight checkpoints.
9088
n_images : int, optional
@@ -95,8 +93,8 @@ def __init__(
9593
sem_post_proc : Callable, optional
9694
A post-processing function for the semantc seg maps. If not None,
9795
overrides the default.
98-
**postproc_kwargs:
99-
Arbitrary keyword arguments for the post-processing.
96+
**kwargs:
97+
Arbitrary keyword arguments expecially for post-processing and saving.
10098
"""
10199
super().__init__(
102100
model=model,
@@ -109,15 +107,15 @@ def __init__(
109107
normalization=normalization,
110108
instance_postproc=instance_postproc,
111109
device=device,
112-
save_masks=save_masks,
113110
save_intermediate=save_intermediate,
114111
save_dir=save_dir,
112+
save_format=save_format,
115113
checkpoint_path=checkpoint_path,
116114
n_images=n_images,
117115
n_devices=n_devices,
118116
type_post_proc=type_post_proc,
119117
sem_post_proc=sem_post_proc,
120-
**postproc_kwargs,
118+
**kwargs,
121119
)
122120

123121
self.stride = stride

0 commit comments

Comments
 (0)