Skip to content

Commit 307f118

Browse files
committed
fix: use albumentations in torch datasets
1 parent 1bd5d23 commit 307f118

File tree

2 files changed

+45
-68
lines changed

2 files changed

+45
-68
lines changed

cellseg_models_pytorch/torch_datasets/hdf5_dataset_train.py

Lines changed: 24 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -23,42 +23,36 @@
2323

2424
__all__ = ["TrainDatasetH5"]
2525

26-
ALLOWED_KEYS = ("image", "inst", "type", "cyto_inst", "cyto_type", "sem")
27-
2826

2927
class TrainDatasetH5(Dataset):
3028
def __init__(
3129
self,
3230
path: str,
33-
input_keys: Tuple[str, ...],
31+
img_key: str,
32+
inst_keys: Tuple[str, ...],
33+
mask_keys: Tuple[str, ...],
3434
transforms: A.Compose,
3535
inst_transforms: ApplyEach,
36-
map_out_keys: Dict[str, str] = None,
3736
) -> None:
3837
"""HDF5 train dataset for cell/panoptic segmentation models.
3938
4039
Parameters:
4140
path (str):
4241
Path to the h5 file.
43-
input_keys (Tuple[str, ...]):
44-
Tuple of keys to be read from the h5 file.
42+
img_key (str):
43+
Key for the image data in the h5 file.
44+
inst_keys (Tuple[str, ...]):
45+
Keys for the instance data in the h5 file. These will be transformed by `inst_transforms`.
46+
mask_keys (Tuple[str, ...]):
47+
Keys for the semantic masks in the h5 file.
4548
transforms (A.Compose):
4649
Albumentations compose object for image and mask transforms.
4750
inst_transforms (ApplyEach):
4851
ApplyEach object for instance transforms.
49-
map_out_keys (Dict[str, str], default=None):
50-
A dictionary to map the default output keys to new output keys.
51-
Useful if you want to match the output keys with model output keys.
52-
e.g. {"inst": "decoder1-inst", "inst-cellpose": decoder2-cellpose}.
53-
The default output keys are any of 'image', 'inst', 'type', 'cyto_inst',
54-
'cyto_type', 'sem' & inst-{transform.name}, cyto_inst-{transform.name}.
5552
5653
Raises:
5754
ModuleNotFoundError: If albumentations is not installed.
5855
ModuleNotFoundError: If tables is not installed.
59-
ValueError: If invalid keys are provided.
60-
ValueError: If 'image' key is not present in input_keys.
61-
ValueError: If 'inst' key is not present in input_keys.
6256
"""
6357
if not has_albu:
6458
raise ModuleNotFoundError(
@@ -72,32 +66,18 @@ def __init__(
7266
"Install with `pip install tables`"
7367
)
7468

75-
if not all(k in ALLOWED_KEYS for k in input_keys):
76-
raise ValueError(
77-
f"Invalid keys. Allowed keys are {ALLOWED_KEYS}, got {input_keys}"
78-
)
79-
80-
if "image" not in input_keys:
81-
raise ValueError("'image' key must be present in keys")
82-
83-
if "inst" not in input_keys:
84-
raise ValueError("'inst' key must be present in keys")
85-
8669
self.path = path
87-
self.keys = input_keys
88-
self.mask_keys = [k for k in input_keys if k != "image"]
89-
self.inst_in_keys = [k for k in input_keys if "inst" in k]
90-
self.inst_out_keys = [
91-
f"{key}-{name}"
92-
for name in inst_transforms.names
93-
for key in self.inst_in_keys
94-
]
70+
self.img_key = img_key
71+
self.inst_keys = inst_keys
72+
self.mask_keys = mask_keys
73+
self.keys = [img_key] + list(mask_keys) + list(inst_keys)
9574
self.transforms = transforms
9675
self.inst_transforms = inst_transforms
97-
self.map_out_keys = map_out_keys
9876

9977
with tb.open_file(path, "r") as h5:
100-
self.n_items = len(h5.root["fname"][:])
78+
for array in h5.walk_nodes("/", classname="Array"):
79+
self.n_items = len(array)
80+
break
10181

10282
def __len__(self) -> int:
10383
"""Return the number of items in the db."""
@@ -107,49 +87,34 @@ def __getitem__(self, ix: int) -> Dict[str, np.ndarray]:
10787
data = FileHandler.read_h5(self.path, ix, keys=self.keys)
10888

10989
# get instance transform kwargs
110-
inst_kws = {
111-
k: data[k] for k in self.inst_in_keys if data.get(k, None) is not None
112-
}
90+
inst_kws = {k: data[k] for k in self.inst_keys}
11391

11492
# apply instance transforms
11593
aux = self.inst_transforms(**inst_kws)
11694

11795
# append integer masks and instance transformed masks
118-
masks = [d[..., np.newaxis] for k, d in data.items() if k != "image"] + aux
96+
masks = [data[k][..., np.newaxis] for k in self.mask_keys] + aux
11997

12098
# number of channels per non image data
12199
mask_chls = [m.shape[2] for m in masks]
122100

123101
# concatenate all masks + inst transforms
124102
masks = np.concatenate(masks, axis=-1)
125-
126-
tr = self.transforms(image=data["image"], masks=[masks])
103+
tr = self.transforms(image=data[self.img_key], masks=[masks])
127104

128105
image = to_tensor(tr["image"])
129106
masks = to_tensor(tr["masks"][0])
130107
masks = torch.split(masks, mask_chls, dim=0)
131108

132109
integer_masks = {
133-
n: masks[i].squeeze().long()
134-
for i, n in enumerate(self.mask_keys)
135-
# n: masks[i].squeeze()
136-
# for i, n in enumerate(self.mask_keys)
110+
n: masks[i].squeeze().long() for i, n in enumerate(self.mask_keys)
137111
}
138112
inst_transformed_masks = {
139-
# n: masks[len(integer_masks) + i]
140-
# for i, n in enumerate(self.inst_out_keys)
141-
n: masks[len(integer_masks) + i].float()
142-
for i, n in enumerate(self.inst_out_keys)
113+
f"{n}_{tr_n}": masks[len(integer_masks) + i].float()
114+
for n in self.inst_keys
115+
for i, tr_n in enumerate(self.inst_transforms.names)
143116
}
144117

145-
# out = {"image": image.float(), **integer_masks, **inst_transformed_masks}
146-
out = {"image": image.float(), **integer_masks, **inst_transformed_masks}
147-
148-
if self.map_out_keys is not None:
149-
new_out = {}
150-
for in_key, out_key in self.map_out_keys.items():
151-
if in_key in out:
152-
new_out[out_key] = out.pop(in_key)
153-
out = new_out
118+
out = {self.img_key: image.float(), **inst_transformed_masks, **integer_masks}
154119

155120
return out

cellseg_models_pytorch/torch_datasets/wsi_dataset_infer.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,20 @@
2424
SOFTWARE.
2525
"""
2626

27-
from typing import Any, Callable, Dict, Iterator, Optional
27+
from typing import Dict, Iterator
2828

2929
import numpy as np
3030
from torch.utils.data import Dataset
3131

3232
from cellseg_models_pytorch.wsi.reader import SlideReader
3333

34+
try:
35+
import albumentations as A
36+
37+
has_albu = True
38+
except ModuleNotFoundError:
39+
has_albu = False
40+
3441
__all__ = ["WSIDatasetInfer"]
3542

3643

@@ -40,7 +47,7 @@ def __init__(
4047
reader: SlideReader,
4148
coordinates: Iterator[tuple[int, int, int, int]],
4249
level: int = 0,
43-
transform: Optional[Callable[[np.ndarray], Any]] = None,
50+
transforms: A.Compose = None,
4451
) -> None:
4552
"""Initialize WSIReaderDataset.
4653
@@ -50,27 +57,32 @@ def __init__(
5057
coordinates (Iterator[tuple[int, int, int, int]]):
5158
Iterator of xywh-coordinates.
5259
level (int):
53-
Slide level for reading tile image. Defaults to 0.
54-
transform (Optional[Callable[[np.ndarray], Any]]):
55-
Transform function for tile images. Defaults to None.
60+
Slide level for reading tile images.
61+
transforms (A.Compose, default=None):
62+
Albumentations Compose object containing transformations for tile images.
5663
5764
Raises:
5865
ImportError: Could not import `PyTorch`.
5966
"""
6067
super().__init__()
68+
if not has_albu:
69+
raise ModuleNotFoundError(
70+
"The albumentations lib is needed for TrainDatasetH5. "
71+
"Install with `pip install albumentations`"
72+
)
73+
6174
self.reader = reader
6275
self.coordinates = coordinates
6376
self.level = level
64-
self.transform = transform
77+
self.transforms = transforms
6578

6679
def __len__(self) -> int:
6780
return len(self.coordinates)
6881

6982
def __getitem__(self, index: int) -> Dict[str, np.ndarray]:
7083
xywh = self.coordinates[index]
7184
tile = self.reader.read_region(xywh, level=self.level)
72-
73-
if self.transform is not None:
74-
tile = self.transform(tile)
85+
if self.transforms is not None:
86+
tile = self.transforms(image=tile)["image"]
7587

7688
return {"image": tile, "name": self.reader.name, "coords": np.array(xywh)}

0 commit comments

Comments
 (0)