Skip to content

Commit 351c232

Browse files
committed
feat: add result seg file formats from geopandas
1 parent 30a7510 commit 351c232

File tree

8 files changed

+2143
-1482
lines changed

8 files changed

+2143
-1482
lines changed

cellseg_models_pytorch/inference/_base_inferer.py

Lines changed: 118 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import warnings
21
from abc import ABC, abstractmethod
32
from collections import OrderedDict
43
from itertools import chain
@@ -33,9 +32,6 @@ def __init__(
3332
normalization: str = None,
3433
device: str = "cuda",
3534
n_devices: int = 1,
36-
save_intermediate: bool = False,
37-
save_dir: Union[Path, str] = None,
38-
save_format: str = ".mat",
3935
checkpoint_path: Union[Path, str] = None,
4036
n_images: int = None,
4137
type_post_proc: Callable = None,
@@ -46,57 +42,49 @@ def __init__(
4642
4743
Parameters
4844
----------
49-
model : nn.Module
50-
A segmentation model.
51-
input_path : Path | str
52-
Path to a folder of images or to hdf5 db.
53-
out_activations : Dict[str, str]
54-
Dictionary of head names mapped to a string value that specifies the
55-
activation applied at the head. E.g. {"type": "tanh", "cellpose": None}
56-
Allowed values: "softmax", "sigmoid", "tanh", None.
57-
out_boundary_weights : Dict[str, bool]
58-
Dictionary of head names mapped to a boolean value. If the value is
59-
True, after a prediction, a weight matrix is applied that assigns bigger
60-
weight on pixels in the center and less weight to pixels on the tile
61-
boundaries. helps dealing with prediction artefacts on the boundaries.
62-
E.g. {"type": False, "cellpose": True}
63-
patch_size : Tuple[int, int]:
64-
The size of the input patches that are fed to the segmentation model.
65-
instance_postproc : str
66-
The post-processing method for the instance segmentation mask. One of:
67-
"cellpose", "omnipose", "stardist", "hovernet", "dcan", "drfns", "dran"
68-
padding : int, optional
69-
The amount of reflection padding for the input images.
70-
batch_size : int, default=8
71-
Number of images loaded from the folder at every batch.
72-
normalization : str, optional
73-
Apply img normalization at forward pass (Same as during training).
74-
One of: "dataset", "minmax", "norm", "percentile", None.
75-
device : str, default="cuda"
76-
The device of the input and model. One of: "cuda", "cpu"
77-
n_devices : int, default=1
78-
Number of devices (cpus/gpus) used for inference.
79-
The model will be copied into these devices.
80-
save_dir : bool, optional
81-
Path to save directory. If None, no masks will be saved to disk as .mat
82-
or .json files. Instead the masks will be saved in `self.out_masks`.
83-
save_intermediate : bool, default=False
84-
If True, intermediate soft masks will be saved into `soft_masks` var.
85-
save_format : str, default=".mat"
86-
The file format for the saved output masks. One of (".mat", ".json").
87-
The ".json" option will save masks into geojson format.
88-
checkpoint_path : Path | str, optional
89-
Path to the model weight checkpoints.
90-
n_images : int, optional
91-
First n-number of images used from the `input_path`.
92-
type_post_proc : Callable, optional
93-
A post-processing function for the type maps. If not None, overrides
94-
the default.
95-
sem_post_proc : Callable, optional
96-
A post-processing function for the semantc seg maps. If not None,
97-
overrides the default.
98-
**kwargs:
99-
Arbitrary keyword arguments expecially for post-processing and saving.
45+
model : nn.Module
46+
A segmentation model.
47+
input_path : Path | str
48+
Path to a folder of images or to hdf5 db.
49+
out_activations : Dict[str, str]
50+
Dictionary of head names mapped to a string value that specifies the
51+
activation applied at the head. E.g. {"type": "tanh", "cellpose": None}
52+
Allowed values: "softmax", "sigmoid", "tanh", None.
53+
out_boundary_weights : Dict[str, bool]
54+
Dictionary of head names mapped to a boolean value. If the value is
55+
True, after a prediction, a weight matrix is applied that assigns bigger
56+
weight on pixels in the center and less weight to pixels on the tile
57+
boundaries. Helps in dealing with prediction artefacts on the boundaries.
58+
E.g. {"type": False, "cellpose": True}
59+
patch_size : Tuple[int, int]
60+
The size of the input patches that are fed to the segmentation model.
61+
instance_postproc : str
62+
The post-processing method for the instance segmentation mask. One of:
63+
"cellpose", "omnipose", "stardist", "hovernet", "dcan", "drfns", "dran"
64+
padding : int, optional
65+
The amount of reflection padding for the input images.
66+
batch_size : int, default=8
67+
Number of images loaded from the folder at every batch.
68+
normalization : str, optional
69+
Apply img normalization at forward pass (Same as during training).
70+
One of: "dataset", "minmax", "norm", "percentile", None.
71+
device : str, default="cuda"
72+
The device of the input and model. One of: "cuda", "cpu"
73+
n_devices : int, default=1
74+
Number of devices (cpus/gpus) used for inference.
75+
The model will be copied into these devices.
76+
checkpoint_path : Path | str, optional
77+
Path to the model weight checkpoints.
78+
n_images : int, optional
79+
First n-number of images used from the `input_path`.
80+
type_post_proc : Callable, optional
81+
A post-processing function for the type maps. If not None, overrides
82+
the default.
83+
sem_post_proc : Callable, optional
84+
A post-processing function for the semantic seg maps. If not None,
85+
overrides the default.
86+
**kwargs:
87+
Arbitrary keyword arguments for post-processing.
10088
"""
10189
# basic inits
10290
self.model = model
@@ -109,22 +97,10 @@ def __init__(
10997
self.head_kwargs = self._check_and_set_head_args()
11098
self.kwargs = kwargs
11199

112-
self.save_dir = Path(save_dir) if save_dir is not None else None
113-
self.save_intermediate = save_intermediate
114-
self.save_format = save_format
115-
116100
# dataset & dataloader
117101
self.path = Path(input_path)
118102
if self.path.is_dir():
119103
ds = FolderDatasetInfer(self.path, n_images=n_images)
120-
if self.save_dir is None and len(ds.fnames) > 40 and n_images is None:
121-
warnings.warn(
122-
"`save_dir` is None. Thus, the outputs are be saved in `out_masks` "
123-
"class attribute. If the input folder contains many images, running"
124-
" inference will likely flood the memory depending on the size and "
125-
"number of the images. Consider saving outputs to disk by providing"
126-
" `save_dir` argument."
127-
)
128104
elif self.path.is_file() and self.path.suffix in (".h5", ".hdf5"):
129105
from .hdf5_dataset_infer import HDF5DatasetInfer
130106

@@ -167,10 +143,10 @@ def __init__(
167143

168144
# try loading the weights to the model
169145
try:
170-
msg = self.model.load_state_dict(state_dict, strict=True)
146+
msg = self.model.load_state_dict(state_dict, strict=False)
171147
except RuntimeError:
172148
new_ckpt = self._strip_state_dict(state_dict)
173-
msg = self.model.load_state_dict(new_ckpt, strict=True)
149+
msg = self.model.load_state_dict(new_ckpt, strict=False)
174150
except BaseException as e:
175151
raise RuntimeError(f"Error when loading checkpoint: {e}")
176152

@@ -218,34 +194,74 @@ def from_yaml(cls, model: nn.Module, yaml_path: str):
218194
def _infer_batch(self):
219195
raise NotImplementedError
220196

221-
def infer(self, mixed_precision: bool = False) -> None:
222-
"""Run inference and post-processing for the images.
197+
def infer(
198+
self,
199+
save_dir: Union[Path, str] = None,
200+
save_format: str = ".mat",
201+
save_intermediate: bool = False,
202+
classes_type: Dict[str, int] = None,
203+
classes_sem: Dict[str, int] = None,
204+
offsets: bool = False,
205+
mixed_precision: bool = False,
206+
) -> None:
207+
"""Run inference and post-processing for the image(s) inside `input_path`.
223208
224-
NOTE:
225-
- Saves outputs in class attributes or to disk (.mat/.json) files.
226-
- If masks are saved to .json (geojson) files, more key word arguments
227-
need to be given at class initialization. Namely: `geo_format`,
228-
`classes_type`, `classes_sem`, `offsets`. See more in the
229-
`FileHandler.save_masks` docs.
209+
NOTE: If `save_dir` is None, the output masks will be cached in a class
210+
attribute `self.out_masks`. Otherwise the masks will be saved to disk.
230211
212+
WARNING: Running inference without setting `save_dir` can take a lot of memory
213+
if the input directory contains many images.
231214
232215
Parameters
233216
----------
234-
mixed_precision : bool, default=False
235-
If True, inference is performed with mixed precision.
217+
save_dir : Path | str, optional
218+
Path to save directory. If None, no masks will be saved to disk.
219+
Instead the masks will be cached in a class attribute `self.out_masks`.
220+
save_format : str, default=".mat"
221+
The file format for the saved output masks. One of ".mat", ".geojson",
222+
".feather", ".parquet".
223+
save_intermediate : bool, default=False
224+
If True, intermediate soft masks will be saved into `self.soft_masks`
225+
class attribute. WARNING: This can take a lot of memory if the input
226+
directory contains many images.
227+
classes_type : Dict[str, int], optional
228+
Cell type dictionary. e.g. {"inflam":1, "epithelial":2, "connec":3}.
229+
This is required only if `save_format` is one of the following formats:
230+
".geojson", ".parquet", ".feather".
231+
classes_sem : Dict[str, int], optional
232+
Tissue type dictionary. e.g. {"tissue1":1, "tissue2":2, "tissue3":3}
233+
This is required only if `save_format` is one of the following formats:
234+
".geojson", ".parquet", ".feather".
235+
offsets : bool, default=False
236+
If True, geojson coords are shifted by the offsets that are encoded in
237+
the filenames (e.g. "x-1000_y-4000.png"). Ignored if `save_format` is ".mat".
238+
mixed_precision : bool, default=False
239+
If True, inference is performed with mixed precision.
236240
237241
Attributes
238242
----------
239-
- out_masks : Dict[str, Dict[str, np.ndarray]]
240-
The output masks for each image. The keys are the image names and the
241-
values are dictionaries of the masks. E.g.
242-
{"sample1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
243-
- soft_masks : Dict[str, Dict[str, np.ndarray]]
244-
NOTE: This attribute is set only if `save_intermediate = True`.
245-
The soft masks for each image. I.e. the soft predictions of the trained
246-
model The keys are the image names and the values are dictionaries of
247-
the soft masks. E.g. {"sample1": {"type": [H, W], "aux": [C, H, W]}}
243+
- out_masks : Dict[str, Dict[str, np.ndarray]]
244+
The output masks for each image. The keys are the image names and the
245+
values are dictionaries of the masks. E.g.
246+
{"sample1": {"inst": [H, W], "type": [H, W], "sem": [H, W]}}
247+
- soft_masks : Dict[str, Dict[str, np.ndarray]]
248+
NOTE: This attribute is set only if `save_intermediate = True`.
249+
The soft masks for each image. I.e. the soft predictions of the trained
250+
model. The keys are the image names and the values are dictionaries of
251+
the soft masks. E.g. {"sample1": {"type": [H, W], "aux": [C, H, W]}}
248252
"""
253+
# check save_dir and save_format
254+
save_dir = Path(save_dir) if save_dir is not None else None
255+
save_intermediate = save_intermediate
256+
save_format = save_format
257+
if save_dir is not None:
258+
allowed_formats = (".mat", ".geojson", ".feather", ".parquet")
259+
if save_format not in allowed_formats:
260+
raise ValueError(
261+
f"Given `save_format`: {save_format} is not one of the allowed "
262+
f"formats: {allowed_formats}"
263+
)
264+
249265
self.soft_masks = {}
250266
self.out_masks = {}
251267
self.elapsed = []
@@ -271,7 +287,7 @@ def infer(self, mixed_precision: bool = False) -> None:
271287
self.elapsed.append(loader.format_dict["elapsed"])
272288
self.rate.append(loader.format_dict["rate"])
273289

274-
if self.save_intermediate:
290+
if save_intermediate:
275291
for n, m in zip(names, soft_masks):
276292
self.soft_masks[n] = m
277293

@@ -283,25 +299,33 @@ def infer(self, mixed_precision: bool = False) -> None:
283299
seg["soft_sem"] = soft["sem"]
284300

285301
# save to cache or disk
286-
if self.save_dir is None:
302+
if save_dir is None:
287303
for n, m in zip(names, seg_results):
288304
self.out_masks[n] = m
289305
else:
290306
loader.set_postfix_str("Saving results to disk")
291307
if self.batch_size > 1:
292-
fnames = [Path(self.save_dir) / n for n in names]
308+
fnames = [Path(save_dir) / n for n in names]
293309
FileHandler.save_masks_parallel(
294310
maps=seg_results,
295311
fnames=fnames,
296-
**{**self.kwargs, "format": self.save_format},
312+
format=save_format,
313+
classes_type=classes_type,
314+
classes_sem=classes_sem,
315+
offsets=offsets,
316+
pooltype="thread",
317+
maptype="amap",
297318
)
298319
else:
299320
for n, m in zip(names, seg_results):
300-
fname = Path(self.save_dir) / n
321+
fname = Path(save_dir) / n
301322
FileHandler.save_masks(
302323
fname=fname,
303324
maps=m,
304-
**{**self.kwargs, "format": self.save_format},
325+
format=save_format,
326+
classes_type=classes_type,
327+
classes_sem=classes_sem,
328+
offsets=offsets,
305329
)
306330

307331
def _strip_state_dict(self, ckpt: Dict) -> OrderedDict:

0 commit comments

Comments
 (0)