import shutil
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np

try:
    from ..datasets import SegmentationFolderDataset, SegmentationHDF5Dataset
    from ..datasets.dataset_writers.folder_writer import SlidingWindowFolderWriter
    from ..datasets.dataset_writers.hdf5_writer import SlidingWindowHDF5Writer
    from ._basemodule import BaseDataModule
    from .downloader import SimpleDownloader
except ModuleNotFoundError:
    raise ModuleNotFoundError(
        "To use the LizardDataModule, the requests, pytorch-lightning, and "
        "albumentations libraries are needed. Install them with "
        "`pip install requests pytorch-lightning albumentations`."
    )


class LizardDataModule(BaseDataModule):
    def __init__(
        self,
        save_dir: str,
        fold_split: Dict[str, int],
        img_transforms: List[str],
        inst_transforms: List[str],
        dataset_type: str = "folder",
        patch_size: Tuple[int, int] = (256, 256),
        stride: int = 128,
        normalization: Optional[str] = None,
        batch_size: int = 8,
        num_workers: int = 8,
        **kwargs,
    ) -> None:
        """Set up the Lizard datamodule. Creates overlapping patches of the Lizard dataset.

        The patches will be saved in the directories:
        - `{save_dir}/train/*`
        - `{save_dir}/test/*`
        - `{save_dir}/valid/*`

        References
        ----------
        Graham, S., Jahanifar, M., Azam, A., Nimir, M., Tsang, Y.W., Dodd, K., Hero, E.,
        Sahota, H., Tank, A., Benes, K., & others (2021). Lizard: A Large-Scale Dataset
        for Colonic Nuclear Instance Segmentation and Classification. In Proceedings of
        the IEEE/CVF International Conference on Computer Vision (pp. 684-693).

        Parameters
        ----------
        save_dir : str
            Path to the directory where the Lizard data will be saved.
        fold_split : Dict[str, int]
            Defines how the folds are split into train, valid, and test sets.
            E.g. {"train": 1, "valid": 2, "test": 3}
        img_transforms : List[str]
            A list containing all the transformations that are applied to the input
            images and corresponding masks. Allowed ones: "blur", "non_spatial",
            "non_rigid", "rigid", "hue_sat", "random_crop", "center_crop", "resize"
        inst_transforms : List[str]
            A list containing all the transformations that are applied to only the
            instance labelled masks. Allowed ones: "cellpose", "contour", "dist",
            "edgeweight", "hovernet", "omnipose", "smooth_dist", "binarize"
        dataset_type : str, default="folder"
            The dataset type. One of "folder", "hdf5".
        patch_size : Tuple[int, int], default=(256, 256)
            The size of the patches extracted from the images.
        stride : int, default=128
            The stride of the sliding-window patcher.
        normalization : str, optional
            Apply image normalization after all the transformations. One of "minmax",
            "norm", "percentile", None.
        batch_size : int, default=8
            Batch size for the dataloader.
        num_workers : int, default=8
            Number of CPU cores/threads used in the dataloading process.

        Example
        -------
        >>> from pathlib import Path
        >>> from cellseg_models_pytorch.datamodules import LizardDataModule

        >>> fold_split = {"train": 1, "valid": 2, "test": 3}
        >>> save_dir = Path.home() / "lizard"
        >>> lizard_module = LizardDataModule(
                save_dir=save_dir,
                fold_split=fold_split,
                inst_transforms=["dist", "stardist"],
                img_transforms=["blur", "hue_sat"],
                normalization="percentile",
                n_rays=32,
            )

        >>> lizard_hdf5_module = LizardDataModule(
                save_dir=save_dir,
                fold_split=fold_split,
                inst_transforms=["dist", "stardist"],
                img_transforms=["blur", "hue_sat"],
                normalization="percentile",
                dataset_type="hdf5",
            )

        >>> # lizard_module.download(save_dir)  # just the downloading
        >>> lizard_module.prepare_data(do_patching=True)  # downloading & processing
        """
        super().__init__(batch_size, num_workers)
        self.save_dir = Path(save_dir)
        self.fold_split = fold_split
        self.patch_size = patch_size
        self.stride = stride
        self.img_transforms = img_transforms
        self.inst_transforms = inst_transforms
        self.normalization = normalization
        self.kwargs = kwargs  # extra kwargs are passed through to the dataset classes

        if dataset_type not in ("folder", "hdf5"):
            raise ValueError(
                f"Illegal `dataset_type` arg. Got {dataset_type}. "
                f"Allowed: {('folder', 'hdf5')}"
            )

        self.dataset_type = dataset_type

    @staticmethod
    def download(root: str) -> None:
        """Download the Lizard dataset from its online source."""
        for ix in [1, 2]:
            fn = f"lizard_images{ix}.zip"
            url = f"https://warwick.ac.uk/fac/cross_fac/tia/data/lizard/{fn}"
            SimpleDownloader.download(url, root)

        url = "https://warwick.ac.uk/fac/cross_fac/tia/data/lizard/lizard_labels.zip"
        SimpleDownloader.download(url, root)
        LizardDataModule.extract_zips(root, rm=True)

    def prepare_data(self, rm_orig: bool = False, do_patching: bool = True) -> None:
        """Prepare the Lizard datasets.

        1. Download the Lizard folds from:
           "https://warwick.ac.uk/fac/cross_fac/tia/data/lizard/"
        2. Split the images and masks into train, valid, and test sets.
        3. Patch the images such that they can be used with the datasets/loaders.

        Parameters
        ----------
        rm_orig : bool, default=False
            Flag, whether to remove the original downloaded data after the split
            and patching are done. Frees up disk space.
        do_patching : bool, default=True
            Flag, whether to do patching at all. Can be used if you only want to
            download and split the data and then work it out on your own.
        """
        # `iterdir` raises `FileNotFoundError` if `save_dir` does not exist yet.
        self.save_dir.mkdir(parents=True, exist_ok=True)

        folders_found = [
            d.name
            for d in self.save_dir.iterdir()
            if d.name.lower() in ("lizard_images1", "lizard_images2", "lizard_labels")
            and d.is_dir()
        ]
        phases_found = [
            d.name
            for d in self.save_dir.iterdir()
            if d.name in ("train", "test", "valid") and d.is_dir()
        ]

        # Non-empty patch folders/files that already exist for any phase.
        patches_found = []
        if phases_found:
            patches_found = [
                sub_d.name
                for d in self.save_dir.iterdir()
                if d.name in ("train", "test", "valid") and d.is_dir()
                for sub_d in d.iterdir()
                if sub_d.name
                in (
                    f"{d.name}_im_patches",
                    f"{d.name}_mask_patches",
                    f"{d.name}_patches",
                )
                and any(sub_d.iterdir())
            ]

        if len(folders_found) < 3 and not phases_found:
            print(
                "Found no data or an incomplete dataset. Downloading the whole thing..."
            )
            for d in self.save_dir.iterdir():
                shutil.rmtree(d)
            LizardDataModule.download(self.save_dir)
        else:
            print("Found all folds. Skipping the download.")

        if not phases_found:
            print("Splitting the files into train, valid, and test sets.")
            for phase, fold_ix in self.fold_split.items():
                img_dir1 = self.save_dir / "Lizard_Images1"
                img_dir2 = self.save_dir / "Lizard_Images2"
                label_dir = self.save_dir / "Lizard_Labels"
                save_im_dir = self.save_dir / phase / "images"
                save_mask_dir = self.save_dir / phase / "labels"

                self._split_to_fold(
                    img_dir1,
                    img_dir2,
                    label_dir,
                    save_im_dir,
                    save_mask_dir,
                    fold_ix,
                    not rm_orig,
                )
        else:
            print(
                "Found split Lizard data. "
                "If in need of a re-download, please empty the `save_dir` folder."
            )

        # Optionally clean up the original downloaded folders.
        if rm_orig:
            for d in self.save_dir.iterdir():
                if "lizard" in d.name.lower() or "macosx" in d.name.lower():
                    shutil.rmtree(d)

        if do_patching and not patches_found:
            print("Patching the data... This will take a while...")
            for phase in self.fold_split.keys():
                save_im_dir = self.save_dir / phase / "images"
                save_mask_dir = self.save_dir / phase / "labels"

                if self.dataset_type == "hdf5":
                    sdir = self.save_dir / phase / f"{phase}_patches"
                    sdir.mkdir(parents=True, exist_ok=True)
                    writer = SlidingWindowHDF5Writer(
                        in_dir_im=save_im_dir,
                        in_dir_mask=save_mask_dir,
                        save_dir=sdir,
                        file_name=f"lizard_{phase}.h5",
                        patch_size=self.patch_size,
                        stride=self.stride,
                        transforms=["rigid"],
                    )
                else:
                    sdir_im = self.save_dir / phase / f"{phase}_im_patches"
                    sdir_mask = self.save_dir / phase / f"{phase}_mask_patches"
                    sdir_im.mkdir(parents=True, exist_ok=True)
                    sdir_mask.mkdir(parents=True, exist_ok=True)
                    writer = SlidingWindowFolderWriter(
                        in_dir_im=save_im_dir,
                        in_dir_mask=save_mask_dir,
                        save_dir_im=sdir_im,
                        save_dir_mask=sdir_mask,
                        patch_size=self.patch_size,
                        stride=self.stride,
                        transforms=["rigid"],
                    )
                writer.write(pre_proc=self._process_label, msg=phase)
        else:
            print(
                "Found processed Lizard data. "
                "If in need of a re-process, please empty the `save_dir` folders."
            )

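    # For orientation, a hypothetical `save_dir="/data/lizard"` ends up with roughly
    # the following tree after `prepare_data` (both dataset types shown):
    #
    #   /data/lizard/train/images/*.png
    #   /data/lizard/train/labels/*.mat
    #   /data/lizard/train/train_patches/lizard_train.h5  # dataset_type="hdf5"
    #   /data/lizard/train/train_im_patches/*             # dataset_type="folder"
    #   /data/lizard/train/train_mask_patches/*           # dataset_type="folder"
    #
    # and likewise for `valid/` and `test/`. `_get_path` below resolves these paths.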
    def _get_path(self, phase: str, dstype: str, is_mask: bool = False) -> Path:
        if dstype == "hdf5":
            p = self.save_dir / phase / f"{phase}_patches" / f"lizard_{phase}.h5"
        else:
            dtype = "mask" if is_mask else "im"
            p = self.save_dir / phase / f"{phase}_{dtype}_patches"

        return p

    def setup(self, stage: Optional[str] = None) -> None:
        """Set up the train, valid, and test datasets."""
        if self.dataset_type == "hdf5":
            DS = SegmentationHDF5Dataset
        else:
            DS = SegmentationFolderDataset

        self.trainset = DS(
            path=self._get_path("train", self.dataset_type, is_mask=False),
            mask_path=self._get_path("train", self.dataset_type, is_mask=True),
            img_transforms=self.img_transforms,
            inst_transforms=self.inst_transforms,
            return_sem=False,
            normalization=self.normalization,
            **self.kwargs,
        )

        self.validset = DS(
            path=self._get_path("valid", self.dataset_type, is_mask=False),
            mask_path=self._get_path("valid", self.dataset_type, is_mask=True),
            img_transforms=self.img_transforms,
            inst_transforms=self.inst_transforms,
            return_sem=False,
            normalization=self.normalization,
            **self.kwargs,
        )

        self.testset = DS(
            path=self._get_path("test", self.dataset_type, is_mask=False),
            mask_path=self._get_path("test", self.dataset_type, is_mask=True),
            img_transforms=self.img_transforms,
            inst_transforms=self.inst_transforms,
            return_sem=False,
            normalization=self.normalization,
            **self.kwargs,
        )

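    # NOTE: the train/valid/test dataloader hooks are assumed to be provided by
    # `BaseDataModule` (which received `batch_size` and `num_workers` in
    # `__init__`) and to read `self.trainset`, `self.validset`, and `self.testset`.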
    def _split_to_fold(
        self,
        img_dir1: Path,
        img_dir2: Path,
        label_dir: Path,
        save_im_dir: Path,
        save_mask_dir: Path,
        fold: int,
        copy: bool = True,
    ) -> None:
        """Move or copy the files of one fold into a 'train', 'valid', or 'test' dir."""
        save_im_dir.mkdir(parents=True, exist_ok=True)
        save_mask_dir.mkdir(parents=True, exist_ok=True)

        # The first column of `info.csv` holds the filename, the last one the fold.
        info_path = label_dir / "info.csv"
        info = np.genfromtxt(info_path, dtype="str", delimiter=",", skip_header=1)
        info = info[info[:, -1] == str(fold)]
        for i in range(info.shape[0]):
            fn, _, _ = info[i]

            # The source images are split between the two downloaded image folders.
            p1 = img_dir1 / f"{fn}.png"
            p2 = img_dir2 / f"{fn}.png"
            src_im = p1 if p1.exists() else p2
            src_mask = label_dir / "Labels" / f"{fn}.mat"

            if copy:
                shutil.copy(src_im, save_im_dir)
                shutil.copy(src_mask, save_mask_dir)
            else:
                src_im.rename(save_im_dir / src_im.name)
                src_mask.rename(save_mask_dir / src_mask.name)

    def _process_label(self, label: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """Process the labels.

        NOTE: this is done to match the schema that's used by the Dataset classes.
        """
        inst_map = label["inst_map"]
        classes = label["class"]
        nuclei_id = label["id"]

        # Build a type (semantic class) map by assigning every nucleus instance
        # the class found at its id's position in the `classes` lookup.
        type_map = np.zeros_like(inst_map)
        unique_values = np.unique(inst_map).tolist()[1:]  # drop the background (0)
        nuclei_id = np.squeeze(nuclei_id).tolist()
        for value in unique_values:
            # Mask of the pixels belonging to the current instance.
            inst = inst_map == value
            idx = nuclei_id.index(value)

            class_ = classes[idx]
            type_map[inst] = class_

        return {"inst_map": inst_map, "type_map": type_map}
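
# A minimal end-to-end usage sketch (non-authoritative; the "/data/lizard" path
# and the lightning `Trainer` hand-off at the end are assumptions):
#
#   dm = LizardDataModule(
#       save_dir="/data/lizard",
#       fold_split={"train": 1, "valid": 2, "test": 3},
#       img_transforms=["blur", "hue_sat"],
#       inst_transforms=["dist"],
#       dataset_type="hdf5",
#   )
#   dm.prepare_data(do_patching=True)  # download, split, and patch
#   dm.setup()                         # build trainset/validset/testset
#   # trainer.fit(model, datamodule=dm)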