6
6
from pathos .multiprocessing import ThreadPool as Pool
7
7
8
8
from ...transforms .albu_transforms import IMG_TRANSFORMS , compose
9
- from ...utils import FileHandler , TilerStitcher , fix_duplicates
9
+ from ...utils import FileHandler , TilerStitcher , fix_duplicates , remap_label
10
10
11
11
__all__ = ["BaseWriter" ]
12
12
@@ -18,48 +18,57 @@ class BaseWriter(ABC):
18
18
def __init__ (
19
19
self ,
20
20
in_dir_im : str ,
21
- in_dir_mask : str ,
22
- patch_size : Tuple [int , int ],
23
- stride : int ,
21
+ in_dir_mask : str = None ,
22
+ patch_size : Tuple [int , int ] = None ,
23
+ stride : int = None ,
24
24
transforms : Optional [List [str ]] = None ,
25
25
** kwargs ,
26
26
) -> None :
27
27
"""Init base class for sliding window data writers."""
28
- self .im_dir = Path (in_dir_im )
29
- self .mask_dir = Path (in_dir_mask )
30
28
self .stride = stride
29
+
30
+ if isinstance (patch_size , int ):
31
+ patch_size = (patch_size , patch_size )
32
+
31
33
self .patch_size = patch_size
34
+ self .im_dir = Path (in_dir_im )
32
35
36
+ # Imgs
33
37
if not self .im_dir .exists ():
34
38
raise ValueError (f"folder: { self .im_dir } does not exist" )
35
39
36
40
if not self .im_dir .is_dir ():
37
41
raise ValueError (f"path: { self .im_dir } is not a folder" )
38
42
39
- if not all ([ f . suffix in IMG_SUFFIXES for f in self . im_dir . iterdir ()]):
40
- raise ValueError (
41
- f"files formats in given folder need to be in { IMG_SUFFIXES } "
42
- )
43
+ im_files = []
44
+ for types in IMG_SUFFIXES :
45
+ im_files . extend ( self . im_dir . glob ( f"* { types } " ))
46
+ self . fnames_im = sorted ( im_files )
43
47
44
- if not self .mask_dir .exists ():
45
- raise ValueError (f"folder: { self .mask_dir } does not exist" )
48
+ # Masks
49
+ self .mask_dir = in_dir_mask
50
+ self .fnames_mask = None
51
+ if in_dir_mask is not None :
52
+ self .mask_dir = Path (in_dir_mask )
46
53
47
- if not self .mask_dir .is_dir ():
48
- raise ValueError (f"path : { self .mask_dir } is not a folder " )
54
+ if not self .mask_dir .exists ():
55
+ raise ValueError (f"folder : { self .mask_dir } does not exist " )
49
56
50
- if not all ([f .suffix in MASK_SUFFIXES for f in self .mask_dir .iterdir ()]):
51
- raise ValueError (
52
- f"files formats in given folder need to be in { MASK_SUFFIXES } "
53
- )
57
+ if not self .mask_dir .is_dir ():
58
+ raise ValueError (f"path: { self .mask_dir } is not a folder" )
54
59
55
- self .fnames_im = sorted (self .im_dir .glob ("*" ))
56
- self .fnames_mask = sorted (self .mask_dir .glob ("*" ))
57
- if len (self .fnames_im ) != len (self .fnames_mask ):
58
- raise ValueError (
59
- f"Found different number of files in { self .im_dir .as_posix ()} and "
60
- f"{ self .mask_dir .as_posix ()} ."
61
- )
60
+ mask_files = []
61
+ for types in MASK_SUFFIXES :
62
+ mask_files .extend (self .mask_dir .glob (f"*{ types } " ))
63
+ self .fnames_mask = sorted (mask_files )
62
64
65
+ if len (self .fnames_im ) != len (self .fnames_mask ):
66
+ raise ValueError (
67
+ f"Found different number of files in { self .im_dir .as_posix ()} and "
68
+ f"{ self .mask_dir .as_posix ()} ."
69
+ )
70
+
71
+ # Transformations
63
72
self .transforms = None
64
73
if transforms is not None :
65
74
allowed = list (IMG_TRANSFORMS .keys ())
@@ -77,78 +86,200 @@ def write(self):
77
86
"""Patch images and mask to and write them to disk."""
78
87
raise NotImplementedError
79
88
80
- def _get_tiles (
89
+ def get_array (
81
90
self ,
82
91
img_path : Union [str , Path ],
83
- mask_path : Union [str , Path ],
92
+ mask_path : Optional [Union [str , Path ]] = None ,
93
+ tiling : Optional [bool ] = False ,
94
+ pre_proc : Optional [Callable ] = None ,
95
+ ) -> Tuple [np .ndarray , Union [None , Dict [str , np .ndarray ]]]:
96
+ """Pipeline that (optionally) patches and transforms input images and masks.
97
+
98
+ Parameters
99
+ ----------
100
+ img_path : str or Path
101
+ Path to an image file.
102
+ mask_path : str or Path, optional
103
+ Path to a .mat mask file.
104
+ tiling : bool, default=False, optional
105
+ Flag, whether to do tiling on the images (and masks).
106
+ pre_proc : Callable, optional
107
+ A pre-processing function that can be used to pre-process given input
108
+ masks before the pipeline.
109
+
110
+ Raises
111
+ ------
112
+ ValueError if self.stride or self.patch_size are not set to integer values.
113
+
114
+ Returns
115
+ -------
116
+ Tuple[np.ndarray, Union[None, Dict[str, np.ndarray]]]:
117
+ The processed image & masks. If `mask_path=None`, returns no masks.
118
+ Img shape w/o tiling: (H, W, C). Dtype: uint8.
119
+ Img shape w/ tiling: (N, pH, pW, C). Dtype: uint8.
120
+ Masks w/o tiling: Shapes: (H, W). Dtypes: int32.
121
+ Masks w/ tiling: Shapes: (N, pH, pW). Dtypes: int32.
122
+ """
123
+ im , masks = self ._read_files (img_path , mask_path , pre_proc )
124
+
125
+ # do tiling first if flag set to True
126
+ if tiling :
127
+ if not isinstance (self .stride , int ) and not isinstance (
128
+ self .patch_size , int
129
+ ):
130
+ raise ValueError (
131
+ "`self.stride` and `self.patch_size` need to be integers. Got: "
132
+ f"self.stride={ self .stride } , self.patch_size={ self .patch_size } "
133
+ )
134
+
135
+ im , masks = self ._get_tiles (im , masks )
136
+
137
+ if masks is not None :
138
+ if "inst_map" in masks .keys ():
139
+ masks ["inst_map" ] = self ._fix_instances_tiles (masks ["inst_map" ])
140
+
141
+ if self .transforms is not None :
142
+ im , masks = self ._transform_tiles (im , masks )
143
+ else :
144
+ if masks is not None :
145
+ if "inst_map" in masks .keys ():
146
+ masks ["inst_map" ] = self ._fix_instances_one (masks ["inst_map" ])
147
+
148
+ if self .transforms is not None :
149
+ im , masks = self ._transform_one (im , masks )
150
+
151
+ return im , masks
152
+
153
+ def _fix_instances_one (self , inst_map : np .ndarray ) -> np .ndarray :
154
+ """Fix duplicate instances and remap instance labels."""
155
+ return remap_label (fix_duplicates (inst_map ))
156
+
157
+ def _read_files (
158
+ self ,
159
+ img_path : Union [str , Path ],
160
+ mask_path : Union [str , Path ] = None ,
84
161
pre_proc : Callable = None ,
85
- ) -> Dict [str , np .ndarray ]:
86
- """Read one image and corresponding masks and do tiling on them."""
87
- # im, masks = self._get_arrays()
162
+ ) -> Tuple [np .ndarray , Union [None , Dict [str , np .ndarray ]]]:
163
+ """Read image and corresponding masks if there are such."""
88
164
im = FileHandler .read_img (img_path )
89
- masks = FileHandler .read_mat (mask_path , return_all = True )
90
165
91
- if pre_proc is not None :
92
- masks = pre_proc (masks )
166
+ masks = None
167
+ if mask_path is not None :
168
+ masks = FileHandler .read_mat (mask_path , return_all = True )
169
+
170
+ if pre_proc is not None :
171
+ masks = pre_proc (masks )
93
172
94
- inst = None
95
- types = None
96
- sem = None
97
- if "inst_map" in masks .keys ():
98
- inst = masks ["inst_map" ]
99
- if "type_map" in masks .keys ():
100
- types = masks ["type_map" ]
101
- if "sem_map" in masks .keys ():
102
- sem = masks ["sem_map" ]
173
+ masks = {
174
+ key : arr
175
+ for key , arr in masks .items ()
176
+ if key in ("inst_map" , "type_map" , "sem_map" )
177
+ }
103
178
179
+ return im , masks
180
+
181
+ def _get_tiles (
182
+ self ,
183
+ im : np .ndarray ,
184
+ masks : Union [Dict [str , np .ndarray ], None ] = None ,
185
+ ) -> Tuple [Dict [str , np .ndarray ], Union [Dict [str , np .ndarray ], None ]]:
186
+ """Do tiling on an image and corresponding masks if there are such."""
187
+ # Init Tilers
104
188
im_tiler = TilerStitcher (
105
189
im_shape = im .shape , patch_shape = self .patch_size + (3 ,), stride = self .stride
106
190
)
191
+ im_tiles = im_tiler .patch (im )
192
+
193
+ # Tile masks if there are such.
194
+ mask_tiles = None
195
+ if masks is not None :
196
+ mask_tiles = {}
197
+ inst = None
198
+ types = None
199
+ sem = None
200
+ if "inst_map" in masks .keys ():
201
+ inst = masks ["inst_map" ]
202
+ if "type_map" in masks .keys ():
203
+ types = masks ["type_map" ]
204
+ if "sem_map" in masks .keys ():
205
+ sem = masks ["sem_map" ]
206
+
207
+ mask_tiler = TilerStitcher (
208
+ im_shape = inst .shape ,
209
+ patch_shape = self .patch_size + (1 ,),
210
+ stride = self .stride ,
211
+ )
107
212
108
- mask_tiler = TilerStitcher (
109
- im_shape = inst .shape , patch_shape = self .patch_size + (1 ,), stride = self .stride
110
- )
213
+ if inst is not None :
214
+ mask_tiles ["inst_map" ] = mask_tiler .patch (inst ).squeeze ()
215
+ if types is not None :
216
+ mask_tiles ["type_map" ] = mask_tiler .patch (types ).squeeze ()
217
+ if sem is not None :
218
+ mask_tiles ["sem_map" ] = mask_tiler .patch (sem ).squeeze ()
219
+
220
+ return im_tiles , mask_tiles
111
221
112
- tiles = {}
113
- tiles ["image" ] = im_tiler .patch (im )
114
- if inst is not None :
115
- tiles ["inst_map" ] = self ._fix_duplicates (mask_tiler .patch (inst ).squeeze ())
116
- if types is not None :
117
- tiles ["type_map" ] = mask_tiler .patch (types ).squeeze ()
118
- if sem is not None :
119
- tiles ["sem_map" ] = mask_tiler .patch (sem ).squeeze ()
222
+ def _transform_one (
223
+ self , im : np .ndarray , masks : Dict [str , np .ndarray ] = None
224
+ ) -> Tuple [np .ndarray , Union [Dict [str , np .ndarray ], None ]]:
225
+ """Transform an image and corresponding mask if there is one."""
226
+ if masks is not None :
227
+ mask_names = [name for name in masks .keys ()]
228
+ masks = [mask for mask in masks .values ()]
229
+ out = self .transforms (image = im , masks = masks )
230
+ masks = {n : mask for n , mask in zip (mask_names , out ["masks" ])}
231
+ else :
232
+ out = self .transforms (image = im )
120
233
121
- if self .transforms is not None :
122
- tiles = self ._transform (tiles )
234
+ im = out ["image" ]
123
235
124
- return tiles
236
+ return im , masks
125
237
126
- def _transform (self , tiles : Dict [str , np .ndarray ]) -> np .ndarray :
238
+ def _transform_tiles (
239
+ self ,
240
+ im_tiles : np .ndarray ,
241
+ mask_tiles : Union [Dict [str , np .ndarray ], None ] = None ,
242
+ ) -> Tuple [np .ndarray , Union [Dict [str , np .ndarray ], None ]]:
127
243
"""Apply transformations to the tiles one by one."""
128
- n_tiles = tiles ["image" ].shape [0 ]
129
- masks = [arr for key , arr in tiles .items () if key != "image" ]
130
- mask_names = [key for key in tiles .keys () if key != "image" ]
244
+ n_tiles = im_tiles .shape [0 ]
245
+ out_im_tiles = []
246
+
247
+ out_mask_tiles = None
248
+ if mask_tiles is not None :
249
+ mask_names = [key for key in mask_tiles .keys ()]
250
+ out_mask_tiles = {k : [] for k in mask_tiles .keys ()}
131
251
132
- out_tiles = {k : [] for k in tiles .keys ()}
133
252
for i in range (n_tiles ):
134
- m = [mask [i ] for mask in masks ]
135
- out = self .transforms (image = tiles ["image" ][i ], masks = m )
136
- out_tiles ["image" ].append (out ["image" ])
253
+ # Get one img tile
254
+ im = im_tiles [i ]
255
+
256
+ # Get one set of mask tiles
257
+ masks = None
258
+ if mask_tiles is not None :
259
+ masks = {n : mask_tiles [n ][i ] for n in mask_names }
260
+
261
+ # transform imgs & masks
262
+ im_tr , masks_tr = self ._transform_one (im , masks )
137
263
138
- for j , mname in enumerate (mask_names ):
139
- out_tiles [mname ].append (out ["masks" ][j ])
264
+ out_im_tiles .append (im_tr )
265
+ if mask_tiles is not None :
266
+ for mask_name in mask_names :
267
+ out_mask_tiles [mask_name ].append (masks_tr [mask_name ])
140
268
141
- for k , mask in out_tiles .items ():
142
- out_tiles [k ] = np .array (mask )
269
+ # convert list of 2D-arrays to np.ndarray
270
+ out_im_tiles = np .array (out_im_tiles )
271
+ if mask_tiles is not None :
272
+ for k , arr in out_mask_tiles .items ():
273
+ out_mask_tiles [k ] = np .array (arr )
143
274
144
- return out_tiles
275
+ return out_im_tiles , out_mask_tiles
145
276
146
- def _fix_duplicates (self , patches_inst : np .ndarray ) -> np .ndarray :
147
- """Fix repeatded labels in a patched instance labelled mask."""
277
+ def _fix_instances_tiles (self , patches_inst : np .ndarray ) -> np .ndarray :
278
+ """Fix repeated labels and remap them in a patched instance labelled mask."""
148
279
insts = []
149
280
150
281
for i in range (patches_inst .shape [0 ]):
151
- insts .append (fix_duplicates (patches_inst [i ]))
282
+ insts .append (self . _fix_instances_one (patches_inst [i ]))
152
283
153
284
insts = np .array (insts )
154
285
0 commit comments