Skip to content

Commit 5ea4513

Browse files
committed
feat(transforms): add strongaugment aug pipeline
1 parent 928b35a commit 5ea4513

File tree

5 files changed

+508
-4
lines changed

5 files changed

+508
-4
lines changed

cellseg_models_pytorch/datasets/_base_dataset.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
IMG_TRANSFORMS,
99
INST_TRANSFORMS,
1010
NORM_TRANSFORMS,
11+
StrongAugment,
1112
apply_each,
1213
compose,
1314
to_tensorv3,
@@ -41,7 +42,8 @@ def __init__(
4142
img_transforms : List[str]
4243
A list containing all the transformations that are applied to the input
4344
images and corresponding masks. Allowed ones: "blur", "non_spatial",
44-
"non_rigid", "rigid", "hue_sat", "random_crop", "center_crop", "resize"
45+
"non_rigid", "rigid", "hue_sat", "random_crop", "center_crop", "resize",
46+
"strong_augment"
4547
inst_transforms : List[str]
4648
A list containg all the transformations that are applied to only the
4749
instance labelled masks. Allowed ones: "cellpose", "contour", "dist",
@@ -87,8 +89,17 @@ def __init__(
8789
self.return_type = return_type
8890
self.return_sem = return_sem
8991

90-
# Set transformations
91-
img_transforms = [IMG_TRANSFORMS[tr](**kwargs) for tr in img_transforms]
92+
# add strong augment pipeline
93+
self.strong_aug = None
94+
if "strong_augment" in img_transforms:
95+
self.strong_aug = StrongAugment()
96+
97+
# Set rest of the transforms, (these are applied before straonaug in _getitem)
98+
img_transforms = [
99+
IMG_TRANSFORMS[tr](**kwargs)
100+
for tr in img_transforms
101+
if tr != "strong_augment"
102+
]
92103
if normalization is not None:
93104
img_transforms.append(NORM_TRANSFORMS[normalization]())
94105

@@ -123,7 +134,6 @@ def _getitem(
123134
A dictionary containing all the augmented data patches.
124135
Keys are: "im", "inst", "type", "sem". Image shape: (B, 3, H, W).
125136
Mask shapes: (B, C_mask, H, W).
126-
127137
"""
128138
inputs = read_input_func(ix, self.return_type, self.return_sem)
129139

@@ -133,7 +143,11 @@ def _getitem(
133143
data = dict(image=inputs["image"], masks=masks)
134144

135145
# transform + convert to tensor
146+
if self.strong_aug is not None:
147+
data = self.strong_aug(**data)
148+
136149
aug = self.img_transforms(**data)
150+
137151
aux = self.inst_transforms(image=aug["image"], inst_map=aug["masks"][0])
138152
data = self.to_tensor(image=aug["image"], masks=aug["masks"], aux=aux)
139153

cellseg_models_pytorch/transforms/albu_transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
minmaxnorm_transform,
2828
percentilenorm_transform,
2929
)
30+
from .strong_augment import StrongAugment
3031

3132
IMG_TRANSFORMS = {
3233
"blur": blur_transforms,
@@ -37,6 +38,7 @@
3738
"non_spatial": non_spatial_transforms,
3839
"resize": resize,
3940
"random_crop": random_crop,
41+
"strong_augment": StrongAugment,
4042
}
4143

4244
INST_TRANSFORMS = {
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from typing import Any, Dict, Optional, Sequence, Tuple, Union
2+
3+
import numpy as np
4+
from albumentations.core.composition import BaseCompose
5+
from albumentations.core.transforms_interface import BasicTransform, ImageOnlyTransform
6+
7+
from ..functional.generic_transforms import (
8+
AUGMENT_SPACE,
9+
_apply_operation,
10+
_check_augment_space,
11+
_magnitude_kwargs,
12+
)
13+
14+
TransformType = Union[BasicTransform, "BaseCompose"]
15+
TransformsSeqType = Sequence[TransformType]
16+
17+
__all__ = ["StrongAugment", "StrongAugTransform"]
18+
19+
20+
class StrongAugTransform(ImageOnlyTransform):
21+
def __init__(self, operation_name: str, **kwargs) -> None:
22+
"""Create StronAugment transformation.
23+
24+
This is a albumentations wrapper for the StrongAugment transformations.
25+
26+
Parameters
27+
----------
28+
operation_name : str
29+
Name of the transformation to apply.
30+
"""
31+
super().__init__(always_apply=True, p=1.0)
32+
self.op_name = operation_name
33+
34+
def apply(self, image: np.ndarray, **kwargs) -> np.ndarray:
35+
"""Apply a transformation from the StrognAugment augmentation space.
36+
37+
Parameters
38+
----------
39+
image : np.ndarray:
40+
Input image to be normalized. Shape (H, W, C)|(H, W).
41+
42+
Returns
43+
-------
44+
np.ndarray:
45+
Transformed image. Same shape as input. dtype: float32.
46+
"""
47+
return _apply_operation(image, self.op_name, **kwargs)
48+
49+
def get_transform_init_args_names(self):
50+
"""Get the names of the transformation arguments."""
51+
return ("op_name",)
52+
53+
def update_params(self, params: Dict[str, Any], **kwargs) -> Dict[str, Any]:
54+
"""Update the transformation parameters."""
55+
params.update({kw: it for kw, it in kwargs.items() if kw != "image"})
56+
return params
57+
58+
59+
class StrongAugment(BaseCompose):
60+
def __init__(
61+
self,
62+
augment_space: Dict[str, tuple] = AUGMENT_SPACE,
63+
operations: Tuple[int] = (3, 4, 5),
64+
probabilites: Tuple[float] = (0.2, 0.3, 0.5),
65+
seed: Optional[int] = None,
66+
p=1.0,
67+
) -> None:
68+
"""Strong augment augmentation policy.
69+
70+
Augment like there's no tomorrow: Consistently performing neural networks for
71+
medical imaging: https://arxiv.org/abs/2206.15274
72+
73+
Parameters
74+
----------
75+
augment_space : Dict[str, tuple], default: AUGMENT_SPACE
76+
Augmentation space to sample operations from.
77+
operations : Tuple[int], default: [3, 4, 5].
78+
Number of operations to apply. If None, sample from
79+
[1, len(augment_space)].
80+
probabilites : Tuple[float], default: [0.2, 0.3, 0.5]
81+
Probabilities of sampling operations. If None, sample from
82+
the uniform distribution.
83+
seed : Optional[int], default: None
84+
Random seed.
85+
p : float, default: 1.0
86+
Probability of applying the transform.
87+
"""
88+
_check_augment_space(augment_space)
89+
if len(operations) != len(probabilites):
90+
raise ValueError("Operation length does not match probabilities length.")
91+
92+
transforms = [StrongAugTransform(op) for op in augment_space.keys()]
93+
self.rng = np.random.RandomState(seed=seed)
94+
self.augment_space = augment_space
95+
self.operations = operations
96+
self.probabilites = probabilites
97+
self.last_operations = dict()
98+
super().__init__(transforms, p=p)
99+
100+
def __call__(self, *args, force_apply: bool = False, **data) -> Dict[str, Any]:
101+
"""Apply the StrongAugment transformation pipeline."""
102+
image = data["image"].copy()
103+
masks = data["masks"].copy()
104+
105+
num_ops = np.random.choice(self.operations, p=self.probabilites)
106+
idx = self.rng.choice(len(self.transforms), size=num_ops, replace=False)
107+
108+
rs = np.random.random()
109+
if force_apply or rs < self.p:
110+
for i in idx:
111+
t = self.transforms[i]
112+
name = t.op_name
113+
kwargs = dict(
114+
name=name,
115+
**_magnitude_kwargs(
116+
name, bounds=self.augment_space[name], rng=self.rng
117+
),
118+
)
119+
120+
data = t(image=image, masks=masks, force_apply=True, **kwargs)
121+
self.last_operations[name] = kwargs
122+
123+
return {k: d for k, d in data.items() if k in ("image", "masks")}
124+
125+
def __repr__(self) -> str:
126+
"""Return the string representation of the StrongAugment object."""
127+
return (
128+
f"{self.__class__.__name__}("
129+
f"operations={self.operations}, "
130+
f"probabilites={self.probabilites}, "
131+
f"augment_space={self.augment_space})"
132+
)

0 commit comments

Comments
 (0)