Skip to content

Commit d435e05

Browse files
authored
Add hyperparameters.Choice support for zoom_factor of ImageAugmentation block (#1716)
1 parent 8d6a2e4 commit d435e05

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

autokeras/blocks/preprocessing.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,10 @@ class ImageAugmentation(block_module.Block):
157157
represented as a single float, lower = upper.
158158
If left unspecified, it will be tuned automatically.
159159
zoom_factor: A positive float represented as fraction value, or a tuple of 2
160-
representing fraction for zooming vertically and horizontally. For
161-
instance, `zoom_factor=0.2` result in a random zoom factor from 80% to
162-
120%. If left unspecified, it will be tuned automatically.
160+
representing fraction for zooming vertically and horizontally,
161+
or a kerastuner.engine.hyperparameters.Choice range of positive floats.
162+
For instance, `zoom_factor=0.2` result in a random zoom factor from 80%
163+
to 120%. If left unspecified, it will be tuned automatically.
163164
contrast_factor: A positive float represented as fraction of value, or a
164165
tuple of size 2 representing lower and upper bound, or a
165166
kerastuner.engine.hyperparameters.Choice range of floats to find the
@@ -177,7 +178,9 @@ def __init__(
177178
vertical_flip: Optional[bool] = None,
178179
horizontal_flip: Optional[bool] = None,
179180
rotation_factor: Optional[Union[float, hyperparameters.Choice]] = None,
180-
zoom_factor: Optional[Union[float, Tuple[float, float]]] = None,
181+
zoom_factor: Optional[
182+
Union[float, Tuple[float, float], hyperparameters.Choice]
183+
] = None,
181184
contrast_factor: Optional[
182185
Union[float, Tuple[float, float], hyperparameters.Choice]
183186
] = None,
@@ -196,7 +199,11 @@ def __init__(
196199
hyperparameters.Choice("rotation_factor", [0.0, 0.1]),
197200
float,
198201
)
199-
self.zoom_factor = zoom_factor
202+
self.zoom_factor = utils.get_hyperparameter(
203+
zoom_factor,
204+
hyperparameters.Choice("zoom_factor", [0.0, 0.1]),
205+
Union[float, Tuple[float, float]],
206+
)
200207
self.contrast_factor = utils.get_hyperparameter(
201208
contrast_factor,
202209
hyperparameters.Choice("contrast_factor", [0.0, 0.1]),
@@ -247,9 +254,7 @@ def build(self, hp, inputs=None):
247254
output_node = layers.RandomRotation(rotation_factor)(output_node)
248255

249256
# Zoom
250-
zoom_factor = self.zoom_factor
251-
if zoom_factor is None:
252-
zoom_factor = hp.Choice("zoom_factor", [0.0, 0.1])
257+
zoom_factor = utils.add_to_hp(self.zoom_factor, hp)
253258
if zoom_factor not in [0, (0, 0)]:
254259
height_factor, width_factor = self._get_fraction_value(zoom_factor)
255260
# TODO: Add back RandomZoom when it is ready.
@@ -273,7 +278,7 @@ def get_config(self):
273278
"horizontal_flip": self.horizontal_flip,
274279
"vertical_flip": self.vertical_flip,
275280
"rotation_factor": hyperparameters.serialize(self.rotation_factor),
276-
"zoom_factor": self.zoom_factor,
281+
"zoom_factor": hyperparameters.serialize(self.zoom_factor),
277282
"contrast_factor": hyperparameters.serialize(self.contrast_factor),
278283
}
279284
)

0 commit comments

Comments
 (0)