__all__ = ["TrainDatasetH5"]

-ALLOWED_KEYS = ("image", "inst", "type", "cyto_inst", "cyto_type", "sem")
-

class TrainDatasetH5(Dataset):
    def __init__(
        self,
        path: str,
-        input_keys: Tuple[str, ...],
+        img_key: str,
+        inst_keys: Tuple[str, ...],
+        mask_keys: Tuple[str, ...],
        transforms: A.Compose,
        inst_transforms: ApplyEach,
-        map_out_keys: Dict[str, str] = None,
    ) -> None:
        """HDF5 train dataset for cell/panoptic segmentation models.

        Parameters:
            path (str):
                Path to the h5 file.
-            input_keys (Tuple[str, ...]):
-                Tuple of keys to be read from the h5 file.
+            img_key (str):
+                Key for the image data in the h5 file.
+            inst_keys (Tuple[str, ...]):
+                Keys for the instance data in the h5 file. These will be transformed
+                with the instance transforms.
+            mask_keys (Tuple[str, ...]):
+                Keys for the semantic masks in the h5 file.
            transforms (A.Compose):
                Albumentations compose object for image and mask transforms.
            inst_transforms (ApplyEach):
                ApplyEach object for instance transforms.
-            map_out_keys (Dict[str, str], default=None):
-                A dictionary to map the default output keys to new output keys.
-                Useful if you want to match the output keys with model output keys.
-                e.g. {"inst": "decoder1-inst", "inst-cellpose": "decoder2-cellpose"}.
-                The default output keys are any of 'image', 'inst', 'type', 'cyto_inst',
-                'cyto_type', 'sem' & inst-{transform.name}, cyto_inst-{transform.name}.

        Raises:
            ModuleNotFoundError: If albumentations is not installed.
            ModuleNotFoundError: If tables is not installed.
-            ValueError: If invalid keys are provided.
-            ValueError: If 'image' key is not present in input_keys.
-            ValueError: If 'inst' key is not present in input_keys.
        """
        if not has_albu:
            raise ModuleNotFoundError(
@@ -72,32 +66,18 @@ def __init__(
                "Install with `pip install tables`"
            )

-        if not all(k in ALLOWED_KEYS for k in input_keys):
-            raise ValueError(
-                f"Invalid keys. Allowed keys are {ALLOWED_KEYS}, got {input_keys}"
-            )
-
-        if "image" not in input_keys:
-            raise ValueError("'image' key must be present in keys")
-
-        if "inst" not in input_keys:
-            raise ValueError("'inst' key must be present in keys")
-
        self.path = path
-        self.keys = input_keys
-        self.mask_keys = [k for k in input_keys if k != "image"]
-        self.inst_in_keys = [k for k in input_keys if "inst" in k]
-        self.inst_out_keys = [
-            f"{key}-{name}"
-            for name in inst_transforms.names
-            for key in self.inst_in_keys
-        ]
+        self.img_key = img_key
+        self.inst_keys = inst_keys
+        self.mask_keys = mask_keys
+        self.keys = [img_key] + list(mask_keys) + list(inst_keys)
        self.transforms = transforms
        self.inst_transforms = inst_transforms
-        self.map_out_keys = map_out_keys

        with tb.open_file(path, "r") as h5:
-            self.n_items = len(h5.root["fname"][:])
+            for array in h5.walk_nodes("/", classname="Array"):
+                self.n_items = len(array)
+                break

    def __len__(self) -> int:
        """Return the number of items in the db."""
@@ -107,49 +87,34 @@ def __getitem__(self, ix: int) -> Dict[str, np.ndarray]:
        data = FileHandler.read_h5(self.path, ix, keys=self.keys)

        # get instance transform kwargs
-        inst_kws = {
-            k: data[k] for k in self.inst_in_keys if data.get(k, None) is not None
-        }
+        inst_kws = {k: data[k] for k in self.inst_keys}

        # apply instance transforms
        aux = self.inst_transforms(**inst_kws)

        # append integer masks and instance transformed masks
-        masks = [d[..., np.newaxis] for k, d in data.items() if k != "image"] + aux
+        masks = [data[k][..., np.newaxis] for k in self.mask_keys] + aux

        # number of channels per non-image data
        mask_chls = [m.shape[2] for m in masks]

        # concatenate all masks + inst transforms
        masks = np.concatenate(masks, axis=-1)
-
-        tr = self.transforms(image=data["image"], masks=[masks])
+        tr = self.transforms(image=data[self.img_key], masks=[masks])

        image = to_tensor(tr["image"])
        masks = to_tensor(tr["masks"][0])
        masks = torch.split(masks, mask_chls, dim=0)

        integer_masks = {
-            n: masks[i].squeeze().long()
-            for i, n in enumerate(self.mask_keys)
-            # n: masks[i].squeeze()
-            # for i, n in enumerate(self.mask_keys)
+            n: masks[i].squeeze().long() for i, n in enumerate(self.mask_keys)
        }
        inst_transformed_masks = {
-            # n: masks[len(integer_masks) + i]
-            # for i, n in enumerate(self.inst_out_keys)
-            n: masks[len(integer_masks) + i].float()
-            for i, n in enumerate(self.inst_out_keys)
+            f"{n}_{tr_n}": masks[len(integer_masks) + i].float()
+            for n in self.inst_keys
+            for i, tr_n in enumerate(self.inst_transforms.names)
        }

-        # out = {"image": image.float(), **integer_masks, **inst_transformed_masks}
-        out = {"image": image.float(), **integer_masks, **inst_transformed_masks}
-
-        if self.map_out_keys is not None:
-            new_out = {}
-            for in_key, out_key in self.map_out_keys.items():
-                if in_key in out:
-                    new_out[out_key] = out.pop(in_key)
-            out = new_out
+        out = {self.img_key: image.float(), **inst_transformed_masks, **integer_masks}

        return out
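
For reference, a minimal usage sketch of the constructor and output format introduced by this change. The file path, key names, and the crop transform below are illustrative assumptions, not values taken from this commit; the ApplyEach instance is left elided because its construction is library-specific.

# Hypothetical usage sketch: path, key names, and crop size are assumptions.
import albumentations as A

inst_transforms = ...  # an ApplyEach instance; construction omitted here

dataset = TrainDatasetH5(
    path="train_patches.h5",             # assumed HDF5 file
    img_key="image",                     # assumed image array name
    inst_keys=("inst",),                 # assumed instance label array name
    mask_keys=("type", "sem"),           # assumed semantic mask array names
    transforms=A.Compose([A.RandomCrop(height=256, width=256)]),
    inst_transforms=inst_transforms,
)

sample = dataset[0]
# Per __getitem__, the dict holds the image tensor under img_key, each mask_key as a
# long tensor, and one float tensor per "{inst_key}_{transform.name}" combination.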