Skip to content

Commit a0b7599

Browse files
committed
fix(datasets): fix dataset writers shortcomings
1 parent a2eecd2 commit a0b7599

File tree

5 files changed

+466
-158
lines changed

5 files changed

+466
-158
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from ._base_writer import BaseWriter
2-
from .folder_writer import SlidingWindowFolderWriter
2+
from .folder_writer import FolderWriter
33

4-
__all__ = ["BaseWriter", "SlidingWindowFolderWriter"]
4+
__all__ = ["BaseWriter", "FolderWriter"]

cellseg_models_pytorch/datasets/dataset_writers/_base_writer.py

Lines changed: 203 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pathos.multiprocessing import ThreadPool as Pool
77

88
from ...transforms.albu_transforms import IMG_TRANSFORMS, compose
9-
from ...utils import FileHandler, TilerStitcher, fix_duplicates
9+
from ...utils import FileHandler, TilerStitcher, fix_duplicates, remap_label
1010

1111
__all__ = ["BaseWriter"]
1212

@@ -18,48 +18,57 @@ class BaseWriter(ABC):
1818
def __init__(
1919
self,
2020
in_dir_im: str,
21-
in_dir_mask: str,
22-
patch_size: Tuple[int, int],
23-
stride: int,
21+
in_dir_mask: str = None,
22+
patch_size: Tuple[int, int] = None,
23+
stride: int = None,
2424
transforms: Optional[List[str]] = None,
2525
**kwargs,
2626
) -> None:
2727
"""Init base class for sliding window data writers."""
28-
self.im_dir = Path(in_dir_im)
29-
self.mask_dir = Path(in_dir_mask)
3028
self.stride = stride
29+
30+
if isinstance(patch_size, int):
31+
patch_size = (patch_size, patch_size)
32+
3133
self.patch_size = patch_size
34+
self.im_dir = Path(in_dir_im)
3235

36+
# Imgs
3337
if not self.im_dir.exists():
3438
raise ValueError(f"folder: {self.im_dir} does not exist")
3539

3640
if not self.im_dir.is_dir():
3741
raise ValueError(f"path: {self.im_dir} is not a folder")
3842

39-
if not all([f.suffix in IMG_SUFFIXES for f in self.im_dir.iterdir()]):
40-
raise ValueError(
41-
f"files formats in given folder need to be in {IMG_SUFFIXES}"
42-
)
43+
im_files = []
44+
for types in IMG_SUFFIXES:
45+
im_files.extend(self.im_dir.glob(f"*{types}"))
46+
self.fnames_im = sorted(im_files)
4347

44-
if not self.mask_dir.exists():
45-
raise ValueError(f"folder: {self.mask_dir} does not exist")
48+
# Masks
49+
self.mask_dir = in_dir_mask
50+
self.fnames_mask = None
51+
if in_dir_mask is not None:
52+
self.mask_dir = Path(in_dir_mask)
4653

47-
if not self.mask_dir.is_dir():
48-
raise ValueError(f"path: {self.mask_dir} is not a folder")
54+
if not self.mask_dir.exists():
55+
raise ValueError(f"folder: {self.mask_dir} does not exist")
4956

50-
if not all([f.suffix in MASK_SUFFIXES for f in self.mask_dir.iterdir()]):
51-
raise ValueError(
52-
f"files formats in given folder need to be in {MASK_SUFFIXES}"
53-
)
57+
if not self.mask_dir.is_dir():
58+
raise ValueError(f"path: {self.mask_dir} is not a folder")
5459

55-
self.fnames_im = sorted(self.im_dir.glob("*"))
56-
self.fnames_mask = sorted(self.mask_dir.glob("*"))
57-
if len(self.fnames_im) != len(self.fnames_mask):
58-
raise ValueError(
59-
f"Found different number of files in {self.im_dir.as_posix()} and "
60-
f"{self.mask_dir.as_posix()}."
61-
)
60+
mask_files = []
61+
for types in MASK_SUFFIXES:
62+
mask_files.extend(self.mask_dir.glob(f"*{types}"))
63+
self.fnames_mask = sorted(mask_files)
6264

65+
if len(self.fnames_im) != len(self.fnames_mask):
66+
raise ValueError(
67+
f"Found different number of files in {self.im_dir.as_posix()} and "
68+
f"{self.mask_dir.as_posix()}."
69+
)
70+
71+
# Transformations
6372
self.transforms = None
6473
if transforms is not None:
6574
allowed = list(IMG_TRANSFORMS.keys())
@@ -77,78 +86,200 @@ def write(self):
7786
"""Patch images and mask to and write them to disk."""
7887
raise NotImplementedError
7988

80-
def _get_tiles(
89+
def get_array(
8190
self,
8291
img_path: Union[str, Path],
83-
mask_path: Union[str, Path],
92+
mask_path: Optional[Union[str, Path]] = None,
93+
tiling: Optional[bool] = False,
94+
pre_proc: Optional[Callable] = None,
95+
) -> Tuple[np.ndarray, Union[None, Dict[str, np.ndarray]]]:
96+
"""Pipeline that (optionally) patches and transforms input images and masks.
97+
98+
Parameters
99+
----------
100+
img_path : str or Path
101+
Path to an image file.
102+
mask_path : str or Path, optional
103+
Path to a .mat mask file.
104+
tiling : bool, default=False, optional
105+
Flag, whether to do tiling on the images (and masks).
106+
pre_proc : Callable, optional
107+
A pre-processing function that can be used to pre-process given input
108+
masks before the pipeline.
109+
110+
Raises
111+
------
112+
ValueError if self.stride or self.patch_size are not set to integer values.
113+
114+
Returns
115+
-------
116+
Tuple[np.ndarray, Union[None, Dict[str, np.ndarray]]]:
117+
The processed image & masks. If `mask_path=None`, returns no masks.
118+
Img shape w/o tiling: (H, W, C). Dtype: uint8.
119+
Img shape w/ tiling: (N, pH, pW, C). Dtype: uint8.
120+
Masks w/o tiling: Shapes: (H, W). Dtypes: int32.
121+
Masks w/ tiling: Shapes: (N, pH, pW). Dtypes: int32.
122+
"""
123+
im, masks = self._read_files(img_path, mask_path, pre_proc)
124+
125+
# do tiling first if flag set to True
126+
if tiling:
127+
if not isinstance(self.stride, int) and not isinstance(
128+
self.patch_size, int
129+
):
130+
raise ValueError(
131+
"`self.stride` and `self.patch_size` need to be integers. Got: "
132+
f"self.stride={self.stride}, self.patch_size={self.patch_size}"
133+
)
134+
135+
im, masks = self._get_tiles(im, masks)
136+
137+
if masks is not None:
138+
if "inst_map" in masks.keys():
139+
masks["inst_map"] = self._fix_instances_tiles(masks["inst_map"])
140+
141+
if self.transforms is not None:
142+
im, masks = self._transform_tiles(im, masks)
143+
else:
144+
if masks is not None:
145+
if "inst_map" in masks.keys():
146+
masks["inst_map"] = self._fix_instances_one(masks["inst_map"])
147+
148+
if self.transforms is not None:
149+
im, masks = self._transform_one(im, masks)
150+
151+
return im, masks
152+
153+
def _fix_instances_one(self, inst_map: np.ndarray) -> np.ndarray:
154+
"""Fix duplicate instances and remap instance labels."""
155+
return remap_label(fix_duplicates(inst_map))
156+
157+
def _read_files(
158+
self,
159+
img_path: Union[str, Path],
160+
mask_path: Union[str, Path] = None,
84161
pre_proc: Callable = None,
85-
) -> Dict[str, np.ndarray]:
86-
"""Read one image and corresponding masks and do tiling on them."""
87-
# im, masks = self._get_arrays()
162+
) -> Tuple[np.ndarray, Union[None, Dict[str, np.ndarray]]]:
163+
"""Read image and corresponding masks if there are such."""
88164
im = FileHandler.read_img(img_path)
89-
masks = FileHandler.read_mat(mask_path, return_all=True)
90165

91-
if pre_proc is not None:
92-
masks = pre_proc(masks)
166+
masks = None
167+
if mask_path is not None:
168+
masks = FileHandler.read_mat(mask_path, return_all=True)
169+
170+
if pre_proc is not None:
171+
masks = pre_proc(masks)
93172

94-
inst = None
95-
types = None
96-
sem = None
97-
if "inst_map" in masks.keys():
98-
inst = masks["inst_map"]
99-
if "type_map" in masks.keys():
100-
types = masks["type_map"]
101-
if "sem_map" in masks.keys():
102-
sem = masks["sem_map"]
173+
masks = {
174+
key: arr
175+
for key, arr in masks.items()
176+
if key in ("inst_map", "type_map", "sem_map")
177+
}
103178

179+
return im, masks
180+
181+
def _get_tiles(
182+
self,
183+
im: np.ndarray,
184+
masks: Union[Dict[str, np.ndarray], None] = None,
185+
) -> Tuple[Dict[str, np.ndarray], Union[Dict[str, np.ndarray], None]]:
186+
"""Do tiling on an image and corresponding masks if there are such."""
187+
# Init Tilers
104188
im_tiler = TilerStitcher(
105189
im_shape=im.shape, patch_shape=self.patch_size + (3,), stride=self.stride
106190
)
191+
im_tiles = im_tiler.patch(im)
192+
193+
# Tile masks if there are such.
194+
mask_tiles = None
195+
if masks is not None:
196+
mask_tiles = {}
197+
inst = None
198+
types = None
199+
sem = None
200+
if "inst_map" in masks.keys():
201+
inst = masks["inst_map"]
202+
if "type_map" in masks.keys():
203+
types = masks["type_map"]
204+
if "sem_map" in masks.keys():
205+
sem = masks["sem_map"]
206+
207+
mask_tiler = TilerStitcher(
208+
im_shape=inst.shape,
209+
patch_shape=self.patch_size + (1,),
210+
stride=self.stride,
211+
)
107212

108-
mask_tiler = TilerStitcher(
109-
im_shape=inst.shape, patch_shape=self.patch_size + (1,), stride=self.stride
110-
)
213+
if inst is not None:
214+
mask_tiles["inst_map"] = mask_tiler.patch(inst).squeeze()
215+
if types is not None:
216+
mask_tiles["type_map"] = mask_tiler.patch(types).squeeze()
217+
if sem is not None:
218+
mask_tiles["sem_map"] = mask_tiler.patch(sem).squeeze()
219+
220+
return im_tiles, mask_tiles
111221

112-
tiles = {}
113-
tiles["image"] = im_tiler.patch(im)
114-
if inst is not None:
115-
tiles["inst_map"] = self._fix_duplicates(mask_tiler.patch(inst).squeeze())
116-
if types is not None:
117-
tiles["type_map"] = mask_tiler.patch(types).squeeze()
118-
if sem is not None:
119-
tiles["sem_map"] = mask_tiler.patch(sem).squeeze()
222+
def _transform_one(
223+
self, im: np.ndarray, masks: Dict[str, np.ndarray] = None
224+
) -> Tuple[np.ndarray, Union[Dict[str, np.ndarray], None]]:
225+
"""Transform an image and corresponding mask if there is one."""
226+
if masks is not None:
227+
mask_names = [name for name in masks.keys()]
228+
masks = [mask for mask in masks.values()]
229+
out = self.transforms(image=im, masks=masks)
230+
masks = {n: mask for n, mask in zip(mask_names, out["masks"])}
231+
else:
232+
out = self.transforms(image=im)
120233

121-
if self.transforms is not None:
122-
tiles = self._transform(tiles)
234+
im = out["image"]
123235

124-
return tiles
236+
return im, masks
125237

126-
def _transform(self, tiles: Dict[str, np.ndarray]) -> np.ndarray:
238+
def _transform_tiles(
239+
self,
240+
im_tiles: np.ndarray,
241+
mask_tiles: Union[Dict[str, np.ndarray], None] = None,
242+
) -> Tuple[np.ndarray, Union[Dict[str, np.ndarray], None]]:
127243
"""Apply transformations to the tiles one by one."""
128-
n_tiles = tiles["image"].shape[0]
129-
masks = [arr for key, arr in tiles.items() if key != "image"]
130-
mask_names = [key for key in tiles.keys() if key != "image"]
244+
n_tiles = im_tiles.shape[0]
245+
out_im_tiles = []
246+
247+
out_mask_tiles = None
248+
if mask_tiles is not None:
249+
mask_names = [key for key in mask_tiles.keys()]
250+
out_mask_tiles = {k: [] for k in mask_tiles.keys()}
131251

132-
out_tiles = {k: [] for k in tiles.keys()}
133252
for i in range(n_tiles):
134-
m = [mask[i] for mask in masks]
135-
out = self.transforms(image=tiles["image"][i], masks=m)
136-
out_tiles["image"].append(out["image"])
253+
# Get one img tile
254+
im = im_tiles[i]
255+
256+
# Get one set of mask tiles
257+
masks = None
258+
if mask_tiles is not None:
259+
masks = {n: mask_tiles[n][i] for n in mask_names}
260+
261+
# transform imgs & masks
262+
im_tr, masks_tr = self._transform_one(im, masks)
137263

138-
for j, mname in enumerate(mask_names):
139-
out_tiles[mname].append(out["masks"][j])
264+
out_im_tiles.append(im_tr)
265+
if mask_tiles is not None:
266+
for mask_name in mask_names:
267+
out_mask_tiles[mask_name].append(masks_tr[mask_name])
140268

141-
for k, mask in out_tiles.items():
142-
out_tiles[k] = np.array(mask)
269+
# convert list of 2D-arrays to np.ndarray
270+
out_im_tiles = np.array(out_im_tiles)
271+
if mask_tiles is not None:
272+
for k, arr in out_mask_tiles.items():
273+
out_mask_tiles[k] = np.array(arr)
143274

144-
return out_tiles
275+
return out_im_tiles, out_mask_tiles
145276

146-
def _fix_duplicates(self, patches_inst: np.ndarray) -> np.ndarray:
147-
"""Fix repeatded labels in a patched instance labelled mask."""
277+
def _fix_instances_tiles(self, patches_inst: np.ndarray) -> np.ndarray:
278+
"""Fix repeated labels and remap them in a patched instance labelled mask."""
148279
insts = []
149280

150281
for i in range(patches_inst.shape[0]):
151-
insts.append(fix_duplicates(patches_inst[i]))
282+
insts.append(self._fix_instances_one(patches_inst[i]))
152283

153284
insts = np.array(insts)
154285

0 commit comments

Comments
 (0)