Skip to content

Commit b3e06e9

Browse files
committed
feat(inference): add hdf5 file reading
1 parent b91b6b0 commit b3e06e9

File tree

7 files changed

+159
-34
lines changed

7 files changed

+159
-34
lines changed

cellseg_models_pytorch/inference/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ._base_inferer import BaseInferer
2-
from .folder_dataset import FolderDataset
2+
from .folder_dataset_infer import FolderDatasetInfer
33
from .post_processor import PostProcessor
44
from .predictor import Predictor
55
from .resize_inferer import ResizeInferer
@@ -11,5 +11,5 @@
1111
"PostProcessor",
1212
"ResizeInferer",
1313
"SlidingWindowInferer",
14-
"FolderDataset",
14+
"FolderDatasetInfer",
1515
]

cellseg_models_pytorch/inference/_base_inferer.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from tqdm import tqdm
1414

1515
from ..utils import FileHandler, tensor_to_ndarray
16-
from .folder_dataset import FolderDataset
16+
from .folder_dataset_infer import FolderDatasetInfer
17+
from .hdf5_dataset_infer import HDF5DatasetInfer
1718
from .post_processor import PostProcessor
1819
from .predictor import Predictor
1920

@@ -22,7 +23,7 @@ class BaseInferer(ABC):
2223
def __init__(
2324
self,
2425
model: nn.Module,
25-
input_folder: Union[Path, str],
26+
input_path: Union[Path, str],
2627
out_activations: Dict[str, str],
2728
out_boundary_weights: Dict[str, bool],
2829
patch_size: Tuple[int, int],
@@ -47,8 +48,8 @@ def __init__(
4748
----------
4849
model : nn.Module
4950
A segmentation model.
50-
input_folder : Path | str
51-
Path to a folder of images.
51+
input_path : Path | str
52+
Path to a folder of images or to hdf5 db.
5253
out_activations : Dict[str, str]
5354
Dictionary of head names mapped to a string value that specifies the
5455
activation applied at the head. E.g. {"type": "tanh", "cellpose": None}
@@ -87,7 +88,7 @@ def __init__(
8788
checkpoint_path : Path | str, optional
8889
Path to the model weight checkpoints.
8990
n_images : int, optional
90-
First n-number of images used from the `input_folder`.
91+
First n-number of images used from the `input_path`.
9192
type_post_proc : Callable, optional
9293
A post-processing function for the type maps. If not None, overrides
9394
the default.
@@ -112,21 +113,28 @@ def __init__(
112113
self.save_intermediate = save_intermediate
113114
self.save_format = save_format
114115

115-
# dataloader
116-
self.path = Path(input_folder)
117-
118-
folder_ds = FolderDataset(self.path, n_images=n_images)
119-
if self.save_dir is None and len(folder_ds.fnames) > 40:
120-
warnings.warn(
121-
"`save_dir` is None. Thus, the outputs are be saved in `out_masks` "
122-
"class variable. If the input folder contains many images, running "
123-
"inference will likely flood the memory depending on the size and "
124-
"number of the images. Consider saving outputs to disk by providing "
125-
"`save_dir` argument."
116+
# dataset & dataloader
117+
self.path = Path(input_path)
118+
if self.path.is_dir():
119+
ds = FolderDatasetInfer(self.path, n_images=n_images)
120+
if self.save_dir is None and len(ds.fnames) > 40:
121+
warnings.warn(
122+
"`save_dir` is None. Thus, the outputs are be saved in `out_masks` "
123+
"class attribute. If the input folder contains many images, running"
124+
" inference will likely flood the memory depending on the size and "
125+
"number of the images. Consider saving outputs to disk by providing"
126+
" `save_dir` argument."
127+
)
128+
elif self.path.is_file() and self.path.suffix in (".h5", ".hdf5"):
129+
ds = HDF5DatasetInfer(self.path, n_images=n_images)
130+
else:
131+
raise ValueError(
132+
f"Given `input_path`: {input_path} is neither an image folder or a h5 "
133+
"database. Allowed suffices for h5 database are ('.h5', '.hdf5')"
126134
)
127135

128136
self.dataloader = DataLoader(
129-
folder_ds, batch_size=batch_size, shuffle=False, pin_memory=True
137+
ds, batch_size=batch_size, shuffle=False, pin_memory=True
130138
)
131139

132140
# Set post processor

cellseg_models_pytorch/inference/folder_dataset.py renamed to cellseg_models_pytorch/inference/folder_dataset_infer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
SUFFIXES = (".jpeg", ".jpg", ".tif", ".tiff", ".png")
1010

1111

12-
__all__ = ["FolderDataset"]
12+
__all__ = ["FolderDatasetInfer"]
1313

1414

15-
class FolderDataset(Dataset, FileHandler):
15+
class FolderDatasetInfer(Dataset, FileHandler):
1616
def __init__(
1717
self, path: Union[str, Path], pattern: str = "*", n_images: int = None
1818
) -> None:
@@ -55,10 +55,10 @@ def __len__(self) -> int:
5555
"""Length of folder."""
5656
return len(self.fnames)
5757

58-
def __getitem__(self, index: int) -> torch.Tensor:
58+
def __getitem__(self, ix: int) -> torch.Tensor:
5959
"""Read image."""
60-
fn = self.fnames[index]
60+
fn = self.fnames[ix]
6161
im = FileHandler.read_img(fn.as_posix())
6262
im = torch.from_numpy(im.transpose(2, 0, 1))
6363

64-
return {"im": im, "file": fn.name[:-4]}
64+
return {"im": im, "file": fn.with_suffix("").name}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from pathlib import Path
2+
from typing import Union
3+
4+
import torch
5+
from torch.utils.data import Dataset
6+
7+
from cellseg_models_pytorch.utils import FileHandler
8+
9+
try:
10+
import tables as tb
11+
except Exception:
12+
raise ImportError(
13+
"`pytables` needed for this class. Install with: `pip install tables`"
14+
)
15+
16+
17+
__all__ = ["HDF5DatasetInfer"]
18+
19+
20+
class HDF5DatasetInfer(Dataset, FileHandler):
21+
def __init__(self, path: Union[str, Path], n_images: int = None, **kwargs) -> None:
22+
"""Folder dataset that can be used during inference for loading images.
23+
24+
NOTE: loads only images.
25+
26+
Parameters
27+
----------
28+
path : str | Path
29+
Path to the folder containing image files.
30+
n_images : int, optional
31+
First n-number of images used from the folder.
32+
33+
Raises
34+
------
35+
ValueError if the input path has incorrect suffix.
36+
"""
37+
super().__init__()
38+
39+
self.path = Path(path)
40+
41+
if self.path.suffix not in (".h5", ".hdf5"):
42+
raise ValueError(
43+
f"The input path has to be a hdf5 db. Got suffix: {self.path.suffix} "
44+
"Allowed suffices: {('.h5', '.hdf5')}"
45+
)
46+
47+
with tb.open_file(self.path) as h5:
48+
if n_images is not None:
49+
self.fnames = h5.root.fnames[:n_images]
50+
else:
51+
self.fnames = h5.root.fnames[:]
52+
53+
def __len__(self) -> int:
54+
"""Return the number of items in the db."""
55+
return len(self.fnames)
56+
57+
def __getitem__(self, ix: int) -> torch.Tensor:
58+
"""Read image."""
59+
fn = self.fnames[ix]
60+
61+
with tb.open_file(self.path.as_posix(), "r") as h5:
62+
im = h5.root.imgs[ix, ...]
63+
64+
im = torch.from_numpy(im.transpose(2, 0, 1))
65+
return {"im": im, "file": Path(fn.decode("UTF-8")).name}

cellseg_models_pytorch/inference/resize_inferer.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class ResizeInferer(BaseInferer):
1212
def __init__(
1313
self,
1414
model: nn.Module,
15-
input_folder: Union[Path, str],
15+
input_path: Union[Path, str],
1616
out_activations: Dict[str, str],
1717
out_boundary_weights: Dict[str, bool],
1818
resize: Tuple[int, int],
@@ -43,8 +43,8 @@ def __init__(
4343
----------
4444
model : nn.Module
4545
A segmentation model.
46-
input_folder : Path | str
47-
Path to a folder of images.
46+
input_path : Path | str
47+
Path to a folder of images or to hdf5 db.
4848
out_activations : Dict[str, str]
4949
Dictionary of head names mapped to a string value that specifies the
5050
activation applied at the head. E.g. {"type": "tanh", "cellpose": None}
@@ -83,7 +83,7 @@ def __init__(
8383
checkpoint_path : Path | str, optional
8484
Path to the model weight checkpoints.
8585
n_images : int, optional
86-
First n-number of images used from the `input_folder`.
86+
First n-number of images used from the `input_path`.
8787
type_post_proc : Callable, optional
8888
A post-processing function for the type maps. If not None, overrides
8989
the default.
@@ -92,10 +92,34 @@ def __init__(
9292
overrides the default.
9393
**kwargs:
9494
Arbitrary keyword arguments expecially for post-processing and saving.
95+
96+
Examples
97+
--------
98+
>>> # initialize model and paths
99+
>>> model = cellpose_base(len(type_classes))
100+
>>> inputs = "/path/to/imgs"
101+
>>> ckpt_path = "/path/to/myweights.ckpt"
102+
103+
>>> # initialize output head args
104+
>>> out_activations={"type": "softmax", "cellpose": None}
105+
>>> out_boundary_weights={"type": None, "cellpose": None}
106+
107+
>>> inferer = ResizeInferer(
108+
model=model,
109+
input_path=inputs,
110+
checkpoint_path=ckpt_path,
111+
out_activations=out_activations,
112+
out_boundary_weights=out_boundary_weights,
113+
resize=(256, 256),
114+
instance_postproc="cellpose",
115+
padding=0,
116+
normalization="minmax" # This needs to be same as during training
117+
)
118+
>>> inferer.infer()
95119
"""
96120
super().__init__(
97121
model=model,
98-
input_folder=input_folder,
122+
input_path=input_path,
99123
out_activations=out_activations,
100124
out_boundary_weights=out_boundary_weights,
101125
patch_size=resize,

cellseg_models_pytorch/inference/sliding_window_inferer.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class SlidingWindowInferer(BaseInferer):
1414
def __init__(
1515
self,
1616
model: nn.Module,
17-
input_folder: Union[Path, str],
17+
input_path: Union[Path, str],
1818
out_activations: Dict[str, str],
1919
out_boundary_weights: Dict[str, bool],
2020
stride: int,
@@ -44,8 +44,8 @@ def __init__(
4444
----------
4545
model : nn.Module
4646
A segmentation model.
47-
input_folder : Path | str
48-
Path to a folder of images.
47+
input_path : Path | str
48+
Path to a folder of images or to hdf5 db.
4949
out_activations : Dict[str, str]
5050
Dictionary of head names mapped to a string value that specifies the
5151
activation applied at the head. E.g. {"type": "tanh", "cellpose": None}
@@ -86,7 +86,7 @@ def __init__(
8686
checkpoint_path : Path | str, optional
8787
Path to the model weight checkpoints.
8888
n_images : int, optional
89-
First n-number of images used from the `ìnput_folder`.
89+
First n-number of images used from the `input_path`.
9090
type_post_proc : Callable, optional
9191
A post-processing function for the type maps. If not None, overrides
9292
the default.
@@ -95,10 +95,35 @@ def __init__(
9595
overrides the default.
9696
**kwargs:
9797
Arbitrary keyword arguments expecially for post-processing and saving.
98+
99+
Examples
100+
--------
101+
>>> # initialize model and paths
102+
>>> model = cellpose_plus(len(type_classes), len(area_classes))
103+
>>> inputs = "/path/to/images"
104+
>>> ckpt_path = "/path/to/my_weights.ckpt"
105+
106+
>>> # initialize output head args
107+
>>> out_activations={"type": "softmax", "cellpose": None, "sem": "softmax"}
108+
>>> out_boundary_weights={"type": False, "cellpose": True, "sem": False}
109+
110+
>>> # Run inference
111+
>>> inferer = SlidingWindowInferer(
112+
model=model,
113+
input_path=inputs,
114+
checkpoint_path=ckpt_path,
115+
out_activations=out_activations,
116+
out_boundary_weights=out_boundary_weights,
117+
stride=256,
118+
patch_size=(320, 320),
119+
instance_postproc="cellpose",
120+
normalization="minmax" # This needs to be same as during training
121+
)
122+
>>> inferer.infer()
98123
"""
99124
super().__init__(
100125
model=model,
101-
input_folder=input_folder,
126+
input_path=input_path,
102127
out_activations=out_activations,
103128
out_boundary_weights=out_boundary_weights,
104129
patch_size=patch_size,
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
## Features
2+
3+
- Add hdf5 input file reading to `Inferer` classes.

0 commit comments

Comments
 (0)