
Commit d047451

feat(datamodules): add h5 writing to pannuke
1 parent b3e06e9 commit d047451

File tree

3 files changed: +70 −27 lines changed


cellseg_models_pytorch/datamodules/lizard_datamodule.py

Lines changed: 17 additions & 15 deletions
@@ -6,8 +6,8 @@
 
 try:
     from ..datasets import SegmentationFolderDataset, SegmentationHDF5Dataset
-    from ..datasets.dataset_writers.folder_writer import SlidingWindowFolderWriter
-    from ..datasets.dataset_writers.hdf5_writer import SlidingWindowHDF5Writer
+    from ..datasets.dataset_writers.folder_writer import FolderWriter
+    from ..datasets.dataset_writers.hdf5_writer import HDF5Writer
     from ._basemodule import BaseDataModule
     from .downloader import SimpleDownloader
 except ModuleNotFoundError:
@@ -50,7 +50,7 @@ def __init__(
         Parameters
         ----------
         save_dir : str
-            Path to directory where the pannuke data will be saved.
+            Path to directory where the lizard data will be saved.
         fold_split : Dict[str, int]
             Defines how the folds are split into train, valid, and test sets.
             E.g. {"train": 1, "valid": 2, "test": 3}
@@ -82,18 +82,20 @@ def __init__(
         >>> from cellseg_models_pytorch.datamodules import LizardDataModule
 
         >>> fold_split = {"train": 1, "valid": 2, "test": 3}
-        >>> save_dir = Path.home() / "pannuke"
+        >>> save_dir = Path.home() / "lizard"
         >>> lizard_module = LizardDataModule(
                 save_dir=save_dir,
                 fold_split=fold_split,
                 inst_transforms=["dist", "stardist"],
                 img_transforms=["blur", "hue_sat"],
                 normalization="percentile",
-                dataset_type="hdf5"
+                dataset_type="hdf5",
+                patch_size=(320, 320),
+                stride=128
             )
 
         >>> # lizard_module.download(save_dir) # just the downloading
-        >>> lizard_module.prepare_data(do_patching=True) # downloading & processing
+        >>> lizard_module.prepare_data(tiling=True) # downloading & processing
         """
         super().__init__(batch_size, num_workers)
         self.save_dir = Path(save_dir)
@@ -115,7 +117,7 @@ def __init__(
 
     @property
     def type_classes(self) -> Dict[str, int]:
-        """Pannuke cell type classes."""
+        """Lizard cell type classes."""
         return {
             "bg": 0,
             "neutrophil": 1,
@@ -138,7 +140,7 @@ def download(root: str) -> None:
             SimpleDownloader.download(url, root)
         LizardDataModule.extract_zips(root, rm=True)
 
-    def prepare_data(self, rm_orig: bool = False, do_patching: bool = True) -> None:
+    def prepare_data(self, rm_orig: bool = False, tiling: bool = True) -> None:
         """Prepare the lizard datasets.
 
         1. Download lizard folds from:
@@ -151,9 +153,9 @@ def prepare_data(self, rm_orig: bool = False, do_patching: bool = True) -> None:
         rm_orig : bool, default=False
             After processing all the files, If True, removes the original
             un-processed files.
-        do_patching : bool, default=True
-            Flag, whether to do patching at all. Can be used if you only want to
-            download and split the data and then work it out on your own.
+        tiling : bool, default=True
+            Flag, whether to cut images into tiles. Can be set to False if you only
+            want to download and split the data and then work it out on your own.
         """
         folders_found = [
             d.name
@@ -222,7 +224,7 @@ def prepare_data(self, rm_orig: bool = False, do_patching: bool = True) -> None:
                 if "lizard" in d.name.lower() or "macosx" in d.name.lower():
                     shutil.rmtree(d)
 
-        if do_patching and not patches_found:
+        if tiling and not patches_found:
             print("Patch the data... This will take a while...")
             for phase in self.fold_split.keys():
                 save_im_dir = self.save_dir / phase / "images"
@@ -231,7 +233,7 @@ def prepare_data(self, rm_orig: bool = False, do_patching: bool = True) -> None:
                 if self.dataset_type == "hdf5":
                     sdir = self.save_dir / phase / f"{phase}_patches"
                     sdir.mkdir(parents=True, exist_ok=True)
-                    writer = SlidingWindowHDF5Writer(
+                    writer = HDF5Writer(
                         in_dir_im=save_im_dir,
                         in_dir_mask=save_mask_dir,
                         save_dir=sdir,
@@ -245,7 +247,7 @@ def prepare_data(self, rm_orig: bool = False, do_patching: bool = True) -> None:
                     sdir_mask = self.save_dir / phase / f"{phase}_mask_patches"
                     sdir_im.mkdir(parents=True, exist_ok=True)
                     sdir_mask.mkdir(parents=True, exist_ok=True)
-                    writer = SlidingWindowFolderWriter(
+                    writer = FolderWriter(
                         in_dir_im=save_im_dir,
                         in_dir_mask=save_mask_dir,
                         save_dir_im=sdir_im,
@@ -254,7 +256,7 @@ def prepare_data(self, rm_orig: bool = False, do_patching: bool = True) -> None:
                         stride=self.stride,
                         transforms=["rigid"],
                     )
-                writer.write(pre_proc=self._process_label, msg=phase)
+                writer.write(tiling=True, pre_proc=self._process_label, msg=phase)
             else:
                 print(
                     "Found processed Lizard data. "

cellseg_models_pytorch/datamodules/pannuke_datamodule.py

Lines changed: 50 additions & 12 deletions
@@ -8,7 +8,8 @@
 from ..utils import FileHandler, fix_duplicates
 
 try:
-    from ..datasets import SegmentationFolderDataset
+    from ..datasets import SegmentationFolderDataset, SegmentationHDF5Dataset
+    from ..datasets.dataset_writers.hdf5_writer import HDF5Writer
     from ._basemodule import BaseDataModule
     from .downloader import SimpleDownloader
 except ModuleNotFoundError:
@@ -26,6 +27,7 @@ def __init__(
         fold_split: Dict[str, int],
         img_transforms: List[str],
         inst_transforms: List[str],
+        dataset_type: str = "folder",
         normalization: str = None,
         batch_size: int = 8,
         num_workers: int = 8,
@@ -65,6 +67,8 @@ def __init__(
             A list containg all the transformations that are applied to only the
             instance labelled masks. Allowed ones: "cellpose", "contour", "dist",
             "edgeweight", "hovernet", "omnipose", "smooth_dist", "binarize"
+        dataset_type : str, default="folder"
+            The dataset type. One of "folder", "hdf5".
         normalization : str, optional
             Apply img normalization after all the transformations. One of "minmax",
             "norm", "percentile", None.
@@ -107,6 +111,14 @@ def __init__(
         self.normalization = normalization
         self.kwargs = kwargs if kwargs is not None else {}
 
+        if dataset_type not in ("folder", "hdf5"):
+            raise ValueError(
+                f"Illegal `dataset_type` arg. Got {dataset_type}. "
+                f"Allowed: {('folder', 'hdf5')}"
+            )
+
+        self.dataset_type = dataset_type
+
     @property
     def type_classes(self) -> Dict[str, int]:
         """Pannuke cell type classes."""
@@ -127,7 +139,7 @@ def download(root: str) -> None:
             SimpleDownloader.download(url, root)
         PannukeDataModule.extract_zips(root, rm=True)
 
-    def prepare_data(self, rm_orig: bool = True) -> None:
+    def prepare_data(self, rm_orig: bool = False) -> None:
         """Prepare the pannuke datasets.
 
         1. Download pannuke folds from:
@@ -167,6 +179,18 @@ def prepare_data(self, rm_orig: bool = True) -> None:
                 self._process_pannuke_fold(
                     fold_paths, save_im_dir, save_mask_dir, fold_ix, phase
                 )
+
+                if self.dataset_type == "hdf5":
+                    writer = HDF5Writer(
+                        in_dir_im=save_im_dir,
+                        in_dir_mask=save_mask_dir,
+                        save_dir=self.save_dir / phase,
+                        file_name=f"pannuke_{phase}.h5",
+                        patch_size=None,
+                        stride=None,
+                        transforms=None,
+                    )
+                    writer.write(tiling=False, msg=phase)
         else:
             print(
                 "Found processed pannuke data. "
@@ -178,31 +202,45 @@ def prepare_data(self, rm_orig: bool = False) -> None:
                 if "fold" in d.name.lower():
                     shutil.rmtree(d)
 
+    def _get_path(self, phase: str, dstype: str, is_mask: bool = False) -> Path:
+        if dstype == "hdf5":
+            p = self.save_dir / phase / f"pannuke_{phase}.h5"
+        else:
+            dtype = "labels" if is_mask else "images"
+            p = self.save_dir / phase / dtype
+
+        return p
+
     def setup(self, stage: Optional[str] = None) -> None:
         """Set up the train, valid, and test datasets."""
-        self.trainset = SegmentationFolderDataset(
-            path=self.save_dir / "train" / "images",
-            mask_path=self.save_dir / "train" / "labels",
+        if self.dataset_type == "hdf5":
+            DS = SegmentationHDF5Dataset
+        else:
+            DS = SegmentationFolderDataset
+
+        self.trainset = DS(
+            path=self._get_path("train", self.dataset_type, is_mask=False),
+            mask_path=self._get_path("train", self.dataset_type, is_mask=True),
             img_transforms=self.img_transforms,
             inst_transforms=self.inst_transforms,
             return_sem=False,
             normalization=self.normalization,
             **self.kwargs,
         )
 
-        self.validset = SegmentationFolderDataset(
-            path=self.save_dir / "valid" / "images",
-            mask_path=self.save_dir / "valid" / "labels",
+        self.validset = DS(
+            path=self._get_path("valid", self.dataset_type, is_mask=False),
+            mask_path=self._get_path("valid", self.dataset_type, is_mask=True),
            img_transforms=self.img_transforms,
            inst_transforms=self.inst_transforms,
            return_sem=False,
            normalization=self.normalization,
            **self.kwargs,
        )
 
-        self.testset = SegmentationFolderDataset(
-            path=self.save_dir / "test" / "images",
-            mask_path=self.save_dir / "test" / "labels",
+        self.testset = DS(
+            path=self._get_path("test", self.dataset_type, is_mask=False),
+            mask_path=self._get_path("test", self.dataset_type, is_mask=True),
             img_transforms=self.img_transforms,
             inst_transforms=self.inst_transforms,
             return_sem=False,
@@ -256,7 +294,7 @@ def _process_pannuke_fold(
             inst_map = self._get_inst_map(temp_mask[..., 0:5])
 
             fn_mask = Path(save_mask_dir / name).with_suffix(".mat")
-            FileHandler.write_mask(fn_mask, inst_map, type_map)
+            FileHandler.write_mat(fn_mask, inst_map, type_map)
             pbar.update(1)
 
     def _get_type_map(self, pannuke_mask: np.ndarray) -> np.ndarray:
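A minimal sketch of the new HDF5 path end to end. It assumes PannukeDataModule also takes a save_dir argument as LizardDataModule does (that parameter sits outside the hunks shown here); the remaining arguments and the prepare_data/setup calls follow the diff above, and the transform choices are illustrative:

    from pathlib import Path

    from cellseg_models_pytorch.datamodules import PannukeDataModule

    # `save_dir` is assumed to exist in the constructor (not shown in this diff);
    # the other arguments appear in the __init__ signature above.
    pannuke_module = PannukeDataModule(
        save_dir=Path.home() / "pannuke",
        fold_split={"train": 1, "valid": 2, "test": 3},
        img_transforms=["blur", "hue_sat"],
        inst_transforms=["cellpose"],
        normalization="percentile",
        dataset_type="hdf5",  # validated in __init__; one of "folder", "hdf5"
    )

    # Downloads and processes the folds; with dataset_type="hdf5" each phase is also
    # written to <save_dir>/<phase>/pannuke_<phase>.h5 via HDF5Writer (tiling=False,
    # i.e. the 256x256 pannuke images are stored as-is, without patching).
    pannuke_module.prepare_data(rm_orig=False)

    # setup() now dispatches on dataset_type: SegmentationHDF5Dataset reads the .h5
    # files, SegmentationFolderDataset reads the images/labels folders.
    pannuke_module.setup()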
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+## Features
+
+- Add option to write pannuke dataset to h5 db in `PannukeDataModule` and `LizardDataModule`.
