Commit 07f31e8

feat(datamodules): add lizard datamodule, fix bugs

1 parent 8eca06c commit 07f31e8
File tree: 8 files changed, +446 -27 lines
cellseg_models_pytorch/datamodules/__init__.py
Lines changed: 7 additions & 1 deletion
@@ -1,5 +1,11 @@
 from .custom_datamodule import CustomDataModule
 from .downloader import SimpleDownloader
+from .lizard_datamodule import LizardDataModule
 from .pannuke_datamodule import PannukeDataModule
 
-__all__ = ["CustomDataModule", "SimpleDownloader", "PannukeDataModule"]
+__all__ = [
+    "CustomDataModule",
+    "SimpleDownloader",
+    "PannukeDataModule",
+    "LizardDataModule",
+]
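With this re-export, the new datamodule is importable straight from the `datamodules` subpackage. A quick sanity check of the updated public API (a sketch; it only assumes the package is installed, with the import path taken from the docstring example below):

    from cellseg_models_pytorch import datamodules

    # The commit re-exports LizardDataModule alongside the existing datamodules.
    assert "LizardDataModule" in datamodules.__all__
    print(datamodules.__all__)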
cellseg_models_pytorch/datamodules/lizard_datamodule.py
Lines changed: 354 additions & 0 deletions
@@ -0,0 +1,354 @@
+import shutil
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+
+try:
+    from ..datasets import SegmentationFolderDataset, SegmentationHDF5Dataset
+    from ..datasets.dataset_writers.folder_writer import SlidingWindowFolderWriter
+    from ..datasets.dataset_writers.hdf5_writer import SlidingWindowHDF5Writer
+    from ._basemodule import BaseDataModule
+    from .downloader import SimpleDownloader
+except ModuleNotFoundError:
+    raise ModuleNotFoundError(
+        "To use the LizardDataModule, the requests, pytorch-lightning, and "
+        "albumentations libraries are needed. Install with "
+        "`pip install requests pytorch-lightning albumentations`"
+    )
+
+
+class LizardDataModule(BaseDataModule):
+    def __init__(
+        self,
+        save_dir: str,
+        fold_split: Dict[str, int],
+        img_transforms: List[str],
+        inst_transforms: List[str],
+        dataset_type: str = "folder",
+        patch_size: Tuple[int, int] = (256, 256),
+        stride: int = 128,
+        normalization: Optional[str] = None,
+        batch_size: int = 8,
+        num_workers: int = 8,
+        **kwargs,
+    ) -> None:
"""Set up Lizard datamodule. Creates overlapping patches of the Lizard dataset.
37+
38+
The patches will be saved in directories:
39+
- `{save_dir}/train/*`
40+
- `{save_dir}/test/*`
41+
- `{save_dir}/valid/*`
42+
43+
References
44+
----------
45+
Graham, S., Jahanifar, M., Azam, A., Nimir, M., Tsang, Y.W., Dodd, K., Hero, E.,
46+
Sahota, H., Tank, A., Benes, K., & others (2021). Lizard: A Large-Scale Dataset
47+
for Colonic Nuclear Instance Segmentation and Classification. In Proceedings of
48+
the IEEE/CVF International Conference on Computer Vision (pp. 684-693).
49+
50+
Parameters
51+
----------
52+
save_dir : str
53+
Path to directory where the pannuke data will be saved.
54+
fold_split : Dict[str, int]
55+
Defines how the folds are split into train, valid, and test sets.
56+
E.g. {"train": 1, "valid": 2, "test": 3}
57+
img_transforms : List[str]
58+
A list containing all the transformations that are applied to the input
59+
images and corresponding masks. Allowed ones: "blur", "non_spatial",
60+
"non_rigid", "rigid", "hue_sat", "random_crop", "center_crop", "resize"
61+
inst_transforms : List[str]
62+
A list containg all the transformations that are applied to only the
63+
instance labelled masks. Allowed ones: "cellpose", "contour", "dist",
64+
"edgeweight", "hovernet", "omnipose", "smooth_dist", "binarize"
65+
dataset_type : str, default="folder"
66+
The dataset type. One of "folder", "hdf5".
67+
patch_size : Tuple[int, int], default=(256, 256)
68+
The size of the patch extracted from the images.
69+
stride : int, default=128
70+
The stride of the sliding window patcher.
71+
normalization : str, optional
72+
Apply img normalization after all the transformations. One of "minmax",
73+
"norm", "percentile", None.
74+
batch_size : int, default=8
75+
Batch size for the dataloader.
76+
num_workers : int, default=8
77+
number of cpu cores/threads used in the dataloading process.
78+
79+
Example
80+
-------
81+
>>> from pathlib import Path
82+
>>> from cellseg_models_pytorch.datamodules import LizardDataModule
83+
84+
>>> fold_split = {"train": 1, "valid": 2, "test": 3}
85+
>>> save_dir = Path.home() / "pannuke"
86+
>>> pannuke_module = PannukeDataModule(
87+
save_dir=save_dir,
88+
fold_split=fold_split,
89+
inst_transforms=["dist", "stardist"],
90+
img_transforms=["blur", "hue_sat"],
91+
normalization="percentile",
92+
n_rays=32
93+
)
94+
95+
>>> lizard_module = LizardDataModule(
96+
save_dir=save_dir,
97+
fold_split=fold_split,
98+
inst_transforms=["dist", "stardist"],
99+
img_transforms=["blur", "hue_sat"],
100+
normalization="percentile",
101+
dataset_type="hdf5"
102+
)
103+
104+
>>> # lizard_module.download(save_dir) # just the downloading
105+
>>> lizard_module.prepare_data(do_patching=True) # downloading & processing
106+
"""
+        super().__init__(batch_size, num_workers)
+        self.save_dir = Path(save_dir)
+        self.fold_split = fold_split
+        self.patch_size = patch_size
+        self.stride = stride
+        self.img_transforms = img_transforms
+        self.inst_transforms = inst_transforms
+        self.normalization = normalization
+        self.kwargs = kwargs if kwargs is not None else {}
+
+        if dataset_type not in ("folder", "hdf5"):
+            raise ValueError(
+                f"Illegal `dataset_type` arg. Got {dataset_type}. "
+                f"Allowed: {('folder', 'hdf5')}"
+            )
+
+        self.dataset_type = dataset_type
+
+    @staticmethod
+    def download(root: str) -> None:
+        """Download the lizard dataset."""
+        for ix in [1, 2]:
+            fn = f"lizard_images{ix}.zip"
+            url = f"https://warwick.ac.uk/fac/cross_fac/tia/data/lizard/{fn}"
+            SimpleDownloader.download(url, root)
+
+        url = "https://warwick.ac.uk/fac/cross_fac/tia/data/lizard/lizard_labels.zip"
+        SimpleDownloader.download(url, root)
+        LizardDataModule.extract_zips(root, rm=True)
+
+    def prepare_data(self, rm_orig: bool = False, do_patching: bool = True) -> None:
+        """Prepare the lizard datasets.
+
+        1. Download the lizard folds from:
+            "https://warwick.ac.uk/fac/cross_fac/tia/data/lizard/"
+        2. Split the images and masks into train, valid, and test sets.
+        3. Patch the images such that they can be used with the datasets/loaders.
+
+        Parameters
+        ----------
+        rm_orig : bool, default=False
+            Remove the original downloaded data after the splitting and patching
+            are done, to save disk space.
+        do_patching : bool, default=True
+            Whether to do the patching at all. Set to False if you only want to
+            download and split the data and then process it on your own.
+        """
+        # Make sure the save_dir exists before scanning its contents.
+        self.save_dir.mkdir(parents=True, exist_ok=True)
+
+        folders_found = [
+            d.name
+            for d in self.save_dir.iterdir()
+            if d.name.lower() in ("lizard_images1", "lizard_images2", "lizard_labels")
+            and d.is_dir()
+        ]
+        phases_found = [
+            d.name
+            for d in self.save_dir.iterdir()
+            if d.name in ("train", "test", "valid") and d.is_dir()
+        ]
+
+        patches_found = []
+        if phases_found:
+            patches_found = [
+                sub_d.name
+                for d in self.save_dir.iterdir()
+                if d.name in ("train", "test", "valid") and d.is_dir()
+                for sub_d in d.iterdir()
+                if sub_d.name
+                in (
+                    f"{d.name}_im_patches",
+                    f"{d.name}_mask_patches",
+                    f"{d.name}_patches",
+                )
+                and any(sub_d.iterdir())
+            ]
+
+        if len(folders_found) < 3 and not phases_found:
+            print(
+                "Found no data or an incomplete dataset. Downloading the whole thing..."
+            )
+            for d in self.save_dir.iterdir():
+                shutil.rmtree(d)
+            LizardDataModule.download(self.save_dir)
+        else:
+            print("Found all folds. Skipping the download.")
+
+        if not phases_found:
+            print("Splitting the files into train, valid, and test sets.")
+            for phase, fold_ix in self.fold_split.items():
+                img_dir1 = self.save_dir / "Lizard_Images1"
+                img_dir2 = self.save_dir / "Lizard_Images2"
+                label_dir = self.save_dir / "Lizard_Labels"
+                save_im_dir = self.save_dir / phase / "images"
+                save_mask_dir = self.save_dir / phase / "labels"
+
+                self._split_to_fold(
+                    img_dir1,
+                    img_dir2,
+                    label_dir,
+                    save_im_dir,
+                    save_mask_dir,
+                    fold_ix,
+                    not rm_orig,
+                )
+        else:
+            print(
+                "Found split Lizard data. "
+                "If in need of a re-download, please empty the `save_dir` folder."
+            )
+
+        if rm_orig:
+            for d in self.save_dir.iterdir():
+                if "lizard" in d.name.lower() or "macosx" in d.name.lower():
+                    shutil.rmtree(d)
+
+        if do_patching and not patches_found:
+            print("Patching the data... This will take a while...")
+            for phase in self.fold_split.keys():
+                save_im_dir = self.save_dir / phase / "images"
+                save_mask_dir = self.save_dir / phase / "labels"
+
+                if self.dataset_type == "hdf5":
+                    sdir = self.save_dir / phase / f"{phase}_patches"
+                    sdir.mkdir(parents=True, exist_ok=True)
+                    writer = SlidingWindowHDF5Writer(
+                        in_dir_im=save_im_dir,
+                        in_dir_mask=save_mask_dir,
+                        save_dir=sdir,
+                        file_name=f"lizard_{phase}.h5",
+                        patch_size=self.patch_size,
+                        stride=self.stride,
+                        transforms=["rigid"],
+                    )
+                else:
+                    sdir_im = self.save_dir / phase / f"{phase}_im_patches"
+                    sdir_mask = self.save_dir / phase / f"{phase}_mask_patches"
+                    sdir_im.mkdir(parents=True, exist_ok=True)
+                    sdir_mask.mkdir(parents=True, exist_ok=True)
+                    writer = SlidingWindowFolderWriter(
+                        in_dir_im=save_im_dir,
+                        in_dir_mask=save_mask_dir,
+                        save_dir_im=sdir_im,
+                        save_dir_mask=sdir_mask,
+                        patch_size=self.patch_size,
+                        stride=self.stride,
+                        transforms=["rigid"],
+                    )
+                writer.write(pre_proc=self._process_label, msg=phase)
+        else:
+            print(
+                "Found processed Lizard data. "
+                "If in need of a re-processing, please empty the `save_dir` folders."
+            )
+
+    def _get_path(self, phase: str, dstype: str, is_mask: bool = False) -> Path:
+        """Get the path to the patched data for the given phase and dataset type."""
+        if dstype == "hdf5":
+            p = self.save_dir / phase / f"{phase}_patches" / f"lizard_{phase}.h5"
+        else:
+            dtype = "mask" if is_mask else "im"
+            p = self.save_dir / phase / f"{phase}_{dtype}_patches"
+
+        return p
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        """Set up the train, valid, and test datasets."""
+        if self.dataset_type == "hdf5":
+            DS = SegmentationHDF5Dataset
+        else:
+            DS = SegmentationFolderDataset
+
+        self.trainset = DS(
+            path=self._get_path("train", self.dataset_type, is_mask=False),
+            mask_path=self._get_path("train", self.dataset_type, is_mask=True),
+            img_transforms=self.img_transforms,
+            inst_transforms=self.inst_transforms,
+            return_sem=False,
+            normalization=self.normalization,
+            **self.kwargs,
+        )
+
+        self.validset = DS(
+            path=self._get_path("valid", self.dataset_type, is_mask=False),
+            mask_path=self._get_path("valid", self.dataset_type, is_mask=True),
+            img_transforms=self.img_transforms,
+            inst_transforms=self.inst_transforms,
+            return_sem=False,
+            normalization=self.normalization,
+            **self.kwargs,
+        )
+
+        self.testset = DS(
+            path=self._get_path("test", self.dataset_type, is_mask=False),
+            mask_path=self._get_path("test", self.dataset_type, is_mask=True),
+            img_transforms=self.img_transforms,
+            inst_transforms=self.inst_transforms,
+            return_sem=False,
+            normalization=self.normalization,
+            **self.kwargs,
+        )
+
+    def _split_to_fold(
+        self,
+        img_dir1: Path,
+        img_dir2: Path,
+        label_dir: Path,
+        save_im_dir: Path,
+        save_mask_dir: Path,
+        fold: int,
+        copy: bool = True,
+    ) -> None:
+        """Move the downloaded data of one fold into a 'train', 'valid', or 'test' dir."""
+        Path(save_im_dir).mkdir(parents=True, exist_ok=True)
+        Path(save_mask_dir).mkdir(parents=True, exist_ok=True)
+        info_path = label_dir / "info.csv"
+        info = np.genfromtxt(info_path, dtype="str", delimiter=",", skip_header=True)
+        info = info[info[:, -1] == str(fold)]
+        for i in range(info.shape[0]):
+            fn, _, _ = info[i]
+
+            # The images are split between the two downloaded image dirs.
+            p1 = img_dir1 / f"{fn}.png"
+            p2 = img_dir2 / f"{fn}.png"
+            src_im = p1 if p1.exists() else p2
+            src_mask = label_dir / "Labels" / f"{fn}.mat"
+
+            if copy:
+                shutil.copy(src_im, save_im_dir)
+                shutil.copy(src_mask, save_mask_dir)
+            else:
+                src_im.rename(save_im_dir / src_im.name)
+                src_mask.rename(save_mask_dir / src_mask.name)
+
+    def _process_label(self, label: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
+        """Process the labels.
+
+        NOTE: This is done to match the schema that's used by the Dataset classes.
+        """
+        inst_map = label["inst_map"]
+        classes = label["class"]
+        nuclei_id = label["id"]
+
+        type_map = np.zeros_like(inst_map)
+        unique_values = np.unique(inst_map).tolist()[1:]  # drop the 0 (background)
+        nuclei_id = np.squeeze(nuclei_id).tolist()
+        for value in unique_values:
+            # Get the mask of the pixels belonging to this instance
+            inst = np.copy(inst_map == value)
+            idx = nuclei_id.index(value)
+
+            # Paint the class of this instance onto the type map
+            class_ = classes[idx]
+            type_map[inst > 0] = class_
+
+        return {"inst_map": inst_map, "type_map": type_map}

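Taken together, the intended workflow of the new datamodule looks roughly like the sketch below. This is a sketch, not code from the commit: `train_dataloader()` is assumed to come from `BaseDataModule` following the standard pytorch-lightning `LightningDataModule` interface, and `n_rays=32` is assumed to be the extra kwarg consumed by the "stardist" instance transform, as in the docstring example.

    from pathlib import Path

    from cellseg_models_pytorch.datamodules import LizardDataModule

    # Fold 1 -> train, fold 2 -> valid, fold 3 -> test.
    fold_split = {"train": 1, "valid": 2, "test": 3}

    lizard_module = LizardDataModule(
        save_dir=Path.home() / "lizard",
        fold_split=fold_split,
        img_transforms=["blur", "hue_sat"],
        inst_transforms=["dist", "stardist"],
        dataset_type="hdf5",
        normalization="percentile",
        n_rays=32,
    )

    lizard_module.prepare_data(do_patching=True)  # download, split, and patch
    lizard_module.setup()                         # build train/valid/test datasets

    # Assumed to be provided by BaseDataModule (LightningDataModule convention).
    train_loader = lizard_module.train_dataloader()
    batch = next(iter(train_loader))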