Skip to content

Commit 08764a2

Browse files
authored
fix bug serializing block arguments (#1766)
* use custom serialize function * add unit test
1 parent 490436d commit 08764a2

File tree

6 files changed

+85
-47
lines changed

6 files changed

+85
-47
lines changed

.vscode/settings.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,8 @@
2020
"editor.codeActionsOnSave": {
2121
"source.organizeImports": true
2222
}
23+
},
24+
"python.analysis.diagnosticSeverityOverrides": {
25+
"reportMissingImports": "none"
2326
}
24-
}
27+
}

autokeras/blocks/basic.py

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from autokeras import keras_layers
2626
from autokeras.blocks import reduction
2727
from autokeras.engine import block as block_module
28+
from autokeras.utils import io_utils
2829
from autokeras.utils import layer_utils
2930
from autokeras.utils import utils
3031

@@ -103,19 +104,19 @@ def get_config(self):
103104
config = super().get_config()
104105
config.update(
105106
{
106-
"num_layers": hyperparameters.serialize(self.num_layers),
107-
"num_units": hyperparameters.serialize(self.num_units),
107+
"num_layers": io_utils.serialize_block_arg(self.num_layers),
108+
"num_units": io_utils.serialize_block_arg(self.num_units),
108109
"use_batchnorm": self.use_batchnorm,
109-
"dropout": hyperparameters.serialize(self.dropout),
110+
"dropout": io_utils.serialize_block_arg(self.dropout),
110111
}
111112
)
112113
return config
113114

114115
@classmethod
115116
def from_config(cls, config):
116-
config["num_layers"] = hyperparameters.deserialize(config["num_layers"])
117-
config["num_units"] = hyperparameters.deserialize(config["num_units"])
118-
config["dropout"] = hyperparameters.deserialize(config["dropout"])
117+
config["num_layers"] = io_utils.deserialize_block_arg(config["num_layers"])
118+
config["num_units"] = io_utils.deserialize_block_arg(config["num_units"])
119+
config["dropout"] = io_utils.deserialize_block_arg(config["dropout"])
119120
return cls(**config)
120121

121122
def build(self, hp, inputs=None):
@@ -190,20 +191,20 @@ def get_config(self):
190191
config.update(
191192
{
192193
"return_sequences": self.return_sequences,
193-
"bidirectional": hyperparameters.serialize(self.bidirectional),
194-
"num_layers": hyperparameters.serialize(self.num_layers),
195-
"layer_type": hyperparameters.serialize(self.layer_type),
194+
"bidirectional": io_utils.serialize_block_arg(self.bidirectional),
195+
"num_layers": io_utils.serialize_block_arg(self.num_layers),
196+
"layer_type": io_utils.serialize_block_arg(self.layer_type),
196197
}
197198
)
198199
return config
199200

200201
@classmethod
201202
def from_config(cls, config):
202-
config["bidirectional"] = hyperparameters.deserialize(
203+
config["bidirectional"] = io_utils.deserialize_block_arg(
203204
config["bidirectional"]
204205
)
205-
config["num_layers"] = hyperparameters.deserialize(config["num_layers"])
206-
config["layer_type"] = hyperparameters.deserialize(config["layer_type"])
206+
config["num_layers"] = io_utils.deserialize_block_arg(config["num_layers"])
207+
config["layer_type"] = io_utils.deserialize_block_arg(config["layer_type"])
207208
return cls(**config)
208209

209210
def build(self, hp, inputs=None):
@@ -314,24 +315,24 @@ def get_config(self):
314315
config = super().get_config()
315316
config.update(
316317
{
317-
"kernel_size": hyperparameters.serialize(self.kernel_size),
318-
"num_blocks": hyperparameters.serialize(self.num_blocks),
319-
"num_layers": hyperparameters.serialize(self.num_layers),
320-
"filters": hyperparameters.serialize(self.filters),
318+
"kernel_size": io_utils.serialize_block_arg(self.kernel_size),
319+
"num_blocks": io_utils.serialize_block_arg(self.num_blocks),
320+
"num_layers": io_utils.serialize_block_arg(self.num_layers),
321+
"filters": io_utils.serialize_block_arg(self.filters),
321322
"max_pooling": self.max_pooling,
322323
"separable": self.separable,
323-
"dropout": hyperparameters.serialize(self.dropout),
324+
"dropout": io_utils.serialize_block_arg(self.dropout),
324325
}
325326
)
326327
return config
327328

328329
@classmethod
329330
def from_config(cls, config):
330-
config["kernel_size"] = hyperparameters.deserialize(config["kernel_size"])
331-
config["num_blocks"] = hyperparameters.deserialize(config["num_blocks"])
332-
config["num_layers"] = hyperparameters.deserialize(config["num_layers"])
333-
config["filters"] = hyperparameters.deserialize(config["filters"])
334-
config["dropout"] = hyperparameters.deserialize(config["dropout"])
331+
config["kernel_size"] = io_utils.deserialize_block_arg(config["kernel_size"])
332+
config["num_blocks"] = io_utils.deserialize_block_arg(config["num_blocks"])
333+
config["num_layers"] = io_utils.deserialize_block_arg(config["num_layers"])
334+
config["filters"] = io_utils.deserialize_block_arg(config["filters"])
335+
config["dropout"] = io_utils.deserialize_block_arg(config["dropout"])
335336
return cls(**config)
336337

337338
def build(self, hp, inputs=None):
@@ -560,24 +561,24 @@ def get_config(self):
560561
config.update(
561562
{
562563
"max_features": self.max_features,
563-
"pretraining": hyperparameters.serialize(self.pretraining),
564-
"embedding_dim": hyperparameters.serialize(self.embedding_dim),
565-
"num_heads": hyperparameters.serialize(self.num_heads),
566-
"dense_dim": hyperparameters.serialize(self.dense_dim),
567-
"dropout": hyperparameters.serialize(self.dropout),
564+
"pretraining": io_utils.serialize_block_arg(self.pretraining),
565+
"embedding_dim": io_utils.serialize_block_arg(self.embedding_dim),
566+
"num_heads": io_utils.serialize_block_arg(self.num_heads),
567+
"dense_dim": io_utils.serialize_block_arg(self.dense_dim),
568+
"dropout": io_utils.serialize_block_arg(self.dropout),
568569
}
569570
)
570571
return config
571572

572573
@classmethod
573574
def from_config(cls, config):
574-
config["pretraining"] = hyperparameters.deserialize(config["pretraining"])
575-
config["embedding_dim"] = hyperparameters.deserialize(
575+
config["pretraining"] = io_utils.deserialize_block_arg(config["pretraining"])
576+
config["embedding_dim"] = io_utils.deserialize_block_arg(
576577
config["embedding_dim"]
577578
)
578-
config["num_heads"] = hyperparameters.deserialize(config["num_heads"])
579-
config["dense_dim"] = hyperparameters.deserialize(config["dense_dim"])
580-
config["dropout"] = hyperparameters.deserialize(config["dropout"])
579+
config["num_heads"] = io_utils.deserialize_block_arg(config["num_heads"])
580+
config["dense_dim"] = io_utils.deserialize_block_arg(config["dense_dim"])
581+
config["dropout"] = io_utils.deserialize_block_arg(config["dropout"])
581582
return cls(**config)
582583

583584
def build(self, hp, inputs=None):
@@ -872,18 +873,18 @@ def get_config(self):
872873
config.update(
873874
{
874875
"max_features": self.max_features,
875-
"pretraining": hyperparameters.serialize(self.pretraining),
876-
"embedding_dim": hyperparameters.serialize(self.embedding_dim),
877-
"dropout": hyperparameters.serialize(self.dropout),
876+
"pretraining": io_utils.serialize_block_arg(self.pretraining),
877+
"embedding_dim": io_utils.serialize_block_arg(self.embedding_dim),
878+
"dropout": io_utils.serialize_block_arg(self.dropout),
878879
}
879880
)
880881
return config
881882

882883
@classmethod
883884
def from_config(cls, config):
884-
config["pretraining"] = hyperparameters.deserialize(config["pretraining"])
885-
config["dropout"] = hyperparameters.deserialize(config["dropout"])
886-
config["embedding_dim"] = hyperparameters.deserialize(
885+
config["pretraining"] = io_utils.deserialize_block_arg(config["pretraining"])
886+
config["dropout"] = io_utils.deserialize_block_arg(config["dropout"])
887+
config["embedding_dim"] = io_utils.deserialize_block_arg(
887888
config["embedding_dim"]
888889
)
889890
return cls(**config)
@@ -956,7 +957,7 @@ def get_config(self):
956957
config = super().get_config()
957958
config.update(
958959
{
959-
"max_sequence_length": hyperparameters.serialize(
960+
"max_sequence_length": io_utils.serialize_block_arg(
960961
self.max_sequence_length
961962
)
962963
}
@@ -965,7 +966,7 @@ def get_config(self):
965966

966967
@classmethod
967968
def from_config(cls, config):
968-
config["max_sequence_length"] = hyperparameters.deserialize(
969+
config["max_sequence_length"] = io_utils.deserialize_block_arg(
969970
config["max_sequence_length"]
970971
)
971972
return cls(**config)

autokeras/blocks/preprocessing.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from autokeras import analysers
2424
from autokeras import keras_layers
2525
from autokeras.engine import block as block_module
26+
from autokeras.utils import io_utils
2627
from autokeras.utils import utils
2728

2829

@@ -272,23 +273,34 @@ def get_config(self):
272273
config = super().get_config()
273274
config.update(
274275
{
275-
"translation_factor": hyperparameters.serialize(
276+
"translation_factor": io_utils.serialize_block_arg(
276277
self.translation_factor
277278
),
278279
"horizontal_flip": self.horizontal_flip,
279280
"vertical_flip": self.vertical_flip,
280-
"rotation_factor": hyperparameters.serialize(self.rotation_factor),
281-
"zoom_factor": hyperparameters.serialize(self.zoom_factor),
282-
"contrast_factor": hyperparameters.serialize(self.contrast_factor),
281+
"rotation_factor": io_utils.serialize_block_arg(
282+
self.rotation_factor
283+
),
284+
"zoom_factor": io_utils.serialize_block_arg(self.zoom_factor),
285+
"contrast_factor": io_utils.serialize_block_arg(
286+
self.contrast_factor
287+
),
283288
}
284289
)
285290
return config
286291

287292
@classmethod
288293
def from_config(cls, config):
289-
config["rotation_factor"] = hyperparameters.deserialize(
294+
config["translation_factor"] = io_utils.deserialize_block_arg(
295+
config["translation_factor"]
296+
)
297+
config["rotation_factor"] = io_utils.deserialize_block_arg(
290298
config["rotation_factor"]
291299
)
300+
config["zoom_factor"] = io_utils.deserialize_block_arg(config["zoom_factor"])
301+
config["contrast_factor"] = io_utils.deserialize_block_arg(
302+
config["contrast_factor"]
303+
)
292304
return cls(**config)
293305

294306

autokeras/blocks/preprocessing_test.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import keras_tuner
1616
import tensorflow as tf
17+
from keras_tuner.engine import hyperparameters
1718
from tensorflow import keras
1819
from tensorflow import nest
1920

@@ -88,11 +89,18 @@ def test_augment_build_with_contrast_factor_return_tensor():
8889

8990

9091
def test_augment_deserialize_to_augment():
91-
serialized_block = blocks.serialize(blocks.ImageAugmentation())
92+
serialized_block = blocks.serialize(
93+
blocks.ImageAugmentation(
94+
zoom_factor=0.1,
95+
contrast_factor=hyperparameters.Float("contrast_factor", 0.1, 0.5),
96+
)
97+
)
9298

9399
block = blocks.deserialize(serialized_block)
94100

95101
assert isinstance(block, blocks.ImageAugmentation)
102+
assert block.zoom_factor == 0.1
103+
assert isinstance(block.contrast_factor, hyperparameters.Float)
96104

97105

98106
def test_augment_get_config_has_all_attributes():

autokeras/utils/io_utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import numpy as np
2222
import tensorflow as tf
23+
from keras_tuner.engine import hyperparameters
2324

2425
WHITELIST_FORMATS = (".bmp", ".gif", ".jpeg", ".jpg", ".png")
2526

@@ -380,3 +381,15 @@ def path_to_image(image, num_channels, image_size, interpolation):
380381
image = tf.image.resize(image, image_size, method=interpolation)
381382
image.set_shape((image_size[0], image_size[1], num_channels))
382383
return image
384+
385+
386+
def deserialize_block_arg(arg):
387+
if isinstance(arg, dict):
388+
return hyperparameters.deserialize(arg)
389+
return arg
390+
391+
392+
def serialize_block_arg(arg):
393+
if isinstance(arg, hyperparameters.HyperParameter):
394+
return hyperparameters.serialize(arg)
395+
return arg

shell/coverage.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pytest --cov-report xml:cov.xml --cov autokeras $1

0 commit comments

Comments
 (0)