8
8
from ..utils import FileHandler , fix_duplicates
9
9
10
10
try :
11
- from ..datasets import SegmentationFolderDataset
11
+ from ..datasets import SegmentationFolderDataset , SegmentationHDF5Dataset
12
+ from ..datasets .dataset_writers .hdf5_writer import HDF5Writer
12
13
from ._basemodule import BaseDataModule
13
14
from .downloader import SimpleDownloader
14
15
except ModuleNotFoundError :
@@ -26,6 +27,7 @@ def __init__(
26
27
fold_split : Dict [str , int ],
27
28
img_transforms : List [str ],
28
29
inst_transforms : List [str ],
30
+ dataset_type : str = "folder" ,
29
31
normalization : str = None ,
30
32
batch_size : int = 8 ,
31
33
num_workers : int = 8 ,
@@ -65,6 +67,8 @@ def __init__(
65
67
A list containg all the transformations that are applied to only the
66
68
instance labelled masks. Allowed ones: "cellpose", "contour", "dist",
67
69
"edgeweight", "hovernet", "omnipose", "smooth_dist", "binarize"
70
+ dataset_type : str, default="folder"
71
+ The dataset type. One of "folder", "hdf5".
68
72
normalization : str, optional
69
73
Apply img normalization after all the transformations. One of "minmax",
70
74
"norm", "percentile", None.
@@ -107,6 +111,14 @@ def __init__(
107
111
self .normalization = normalization
108
112
self .kwargs = kwargs if kwargs is not None else {}
109
113
114
+ if dataset_type not in ("folder" , "hdf5" ):
115
+ raise ValueError (
116
+ f"Illegal `dataset_type` arg. Got { dataset_type } . "
117
+ f"Allowed: { ('folder' , 'hdf5' )} "
118
+ )
119
+
120
+ self .dataset_type = dataset_type
121
+
110
122
@property
111
123
def type_classes (self ) -> Dict [str , int ]:
112
124
"""Pannuke cell type classes."""
@@ -127,7 +139,7 @@ def download(root: str) -> None:
127
139
SimpleDownloader .download (url , root )
128
140
PannukeDataModule .extract_zips (root , rm = True )
129
141
130
- def prepare_data (self , rm_orig : bool = True ) -> None :
142
+ def prepare_data (self , rm_orig : bool = False ) -> None :
131
143
"""Prepare the pannuke datasets.
132
144
133
145
1. Download pannuke folds from:
@@ -167,6 +179,18 @@ def prepare_data(self, rm_orig: bool = True) -> None:
167
179
self ._process_pannuke_fold (
168
180
fold_paths , save_im_dir , save_mask_dir , fold_ix , phase
169
181
)
182
+
183
+ if self .dataset_type == "hdf5" :
184
+ writer = HDF5Writer (
185
+ in_dir_im = save_im_dir ,
186
+ in_dir_mask = save_mask_dir ,
187
+ save_dir = self .save_dir / phase ,
188
+ file_name = f"pannuke_{ phase } .h5" ,
189
+ patch_size = None ,
190
+ stride = None ,
191
+ transforms = None ,
192
+ )
193
+ writer .write (tiling = False , msg = phase )
170
194
else :
171
195
print (
172
196
"Found processed pannuke data. "
@@ -178,31 +202,45 @@ def prepare_data(self, rm_orig: bool = True) -> None:
178
202
if "fold" in d .name .lower ():
179
203
shutil .rmtree (d )
180
204
205
+ def _get_path (self , phase : str , dstype : str , is_mask : bool = False ) -> Path :
206
+ if dstype == "hdf5" :
207
+ p = self .save_dir / phase / f"pannuke_{ phase } .h5"
208
+ else :
209
+ dtype = "labels" if is_mask else "images"
210
+ p = self .save_dir / phase / dtype
211
+
212
+ return p
213
+
181
214
def setup (self , stage : Optional [str ] = None ) -> None :
182
215
"""Set up the train, valid, and test datasets."""
183
- self .trainset = SegmentationFolderDataset (
184
- path = self .save_dir / "train" / "images" ,
185
- mask_path = self .save_dir / "train" / "labels" ,
216
+ if self .dataset_type == "hdf5" :
217
+ DS = SegmentationHDF5Dataset
218
+ else :
219
+ DS = SegmentationFolderDataset
220
+
221
+ self .trainset = DS (
222
+ path = self ._get_path ("train" , self .dataset_type , is_mask = False ),
223
+ mask_path = self ._get_path ("train" , self .dataset_type , is_mask = True ),
186
224
img_transforms = self .img_transforms ,
187
225
inst_transforms = self .inst_transforms ,
188
226
return_sem = False ,
189
227
normalization = self .normalization ,
190
228
** self .kwargs ,
191
229
)
192
230
193
- self .validset = SegmentationFolderDataset (
194
- path = self .save_dir / "valid" / "images" ,
195
- mask_path = self .save_dir / "valid" / "labels" ,
231
+ self .validset = DS (
232
+ path = self ._get_path ( "valid" , self . dataset_type , is_mask = False ) ,
233
+ mask_path = self ._get_path ( "valid" , self . dataset_type , is_mask = True ) ,
196
234
img_transforms = self .img_transforms ,
197
235
inst_transforms = self .inst_transforms ,
198
236
return_sem = False ,
199
237
normalization = self .normalization ,
200
238
** self .kwargs ,
201
239
)
202
240
203
- self .testset = SegmentationFolderDataset (
204
- path = self .save_dir / "test" / "images" ,
205
- mask_path = self .save_dir / "test" / "labels" ,
241
+ self .testset = DS (
242
+ path = self ._get_path ( "test" , self . dataset_type , is_mask = False ) ,
243
+ mask_path = self ._get_path ( "test" , self . dataset_type , is_mask = True ) ,
206
244
img_transforms = self .img_transforms ,
207
245
inst_transforms = self .inst_transforms ,
208
246
return_sem = False ,
@@ -256,7 +294,7 @@ def _process_pannuke_fold(
256
294
inst_map = self ._get_inst_map (temp_mask [..., 0 :5 ])
257
295
258
296
fn_mask = Path (save_mask_dir / name ).with_suffix (".mat" )
259
- FileHandler .write_mask (fn_mask , inst_map , type_map )
297
+ FileHandler .write_mat (fn_mask , inst_map , type_map )
260
298
pbar .update (1 )
261
299
262
300
def _get_type_map (self , pannuke_mask : np .ndarray ) -> np .ndarray :
0 commit comments