18
18
19
19
import keras_tuner
20
20
import tensorflow as tf
21
- from keras_tuner .engine import hypermodel as hm_module
22
21
from tensorflow .keras import callbacks as tf_callbacks
23
22
from tensorflow .keras .layers .experimental import preprocessing
24
23
from tensorflow .python .util import nest
@@ -43,7 +42,7 @@ class AutoTuner(keras_tuner.engine.tuner.Tuner):
43
42
44
43
# Arguments
45
44
oracle: keras_tuner Oracle.
46
- hypermodel: keras_tuner KerasHyperModel .
45
+ hypermodel: keras_tuner HyperModel .
47
46
**kwargs: The args supported by KerasTuner.
48
47
"""
49
48
@@ -52,15 +51,15 @@ def __init__(self, oracle, hypermodel, **kwargs):
52
51
self ._finished = False
53
52
super ().__init__ (oracle , hypermodel , ** kwargs )
54
53
# Save or load the HyperModel.
55
- self .hypermodel .hypermodel . save (os .path .join (self .project_dir , "graph" ))
54
+ self .hypermodel .save (os .path .join (self .project_dir , "graph" ))
56
55
self .hyper_pipeline = None
57
56
58
57
def _populate_initial_space (self ):
59
58
# Override the function to prevent building the model during initialization.
60
59
return
61
60
62
61
def get_best_model (self ):
63
- with hm_module .maybe_distribute (self .distribution_strategy ):
62
+ with keras_tuner . engine . tuner .maybe_distribute (self .distribution_strategy ):
64
63
model = tf .keras .models .load_model (self .best_model_path )
65
64
return model
66
65
@@ -80,27 +79,27 @@ def _prepare_model_build(self, hp, **kwargs):
80
79
pipeline = self .hyper_pipeline .build (hp , dataset )
81
80
pipeline .fit (dataset )
82
81
dataset = pipeline .transform (dataset )
83
- self .hypermodel .hypermodel . set_io_shapes (data_utils .dataset_shape (dataset ))
82
+ self .hypermodel .set_io_shapes (data_utils .dataset_shape (dataset ))
84
83
85
84
if "validation_data" in kwargs :
86
85
validation_data = pipeline .transform (kwargs ["validation_data" ])
87
86
else :
88
87
validation_data = None
89
88
return pipeline , dataset , validation_data
90
89
91
- def _build_and_fit_model (self , trial , fit_args , fit_kwargs ):
90
+ def _build_and_fit_model (self , trial , * args , ** kwargs ):
91
+ model = self .hypermodel .build (trial .hyperparameters )
92
92
(
93
93
pipeline ,
94
- fit_kwargs ["x" ],
95
- fit_kwargs ["validation_data" ],
96
- ) = self ._prepare_model_build (trial .hyperparameters , ** fit_kwargs )
94
+ kwargs ["x" ],
95
+ kwargs ["validation_data" ],
96
+ ) = self ._prepare_model_build (trial .hyperparameters , ** kwargs )
97
97
pipeline .save (self ._pipeline_path (trial .trial_id ))
98
98
99
- model = self .hypermodel .build (trial .hyperparameters )
100
- self .adapt (model , fit_kwargs ["x" ])
99
+ self .adapt (model , kwargs ["x" ])
101
100
102
101
_ , history = utils .fit_with_adaptive_batch_size (
103
- model , self .hypermodel .hypermodel . batch_size , ** fit_kwargs
102
+ model , self .hypermodel .batch_size , ** kwargs
104
103
)
105
104
return history
106
105
@@ -165,7 +164,7 @@ def search(
165
164
if callbacks is None :
166
165
callbacks = []
167
166
168
- self .hypermodel .hypermodel . set_fit_args (validation_split , epochs = epochs )
167
+ self .hypermodel .set_fit_args (validation_split , epochs = epochs )
169
168
170
169
# Insert early-stopping for adaptive number of epochs.
171
170
epochs_provided = True
@@ -216,9 +215,7 @@ def search(
216
215
)
217
216
copied_fit_kwargs .pop ("validation_data" )
218
217
219
- self .hypermodel .hypermodel .set_fit_args (
220
- 0 , epochs = copied_fit_kwargs ["epochs" ]
221
- )
218
+ self .hypermodel .set_fit_args (0 , epochs = copied_fit_kwargs ["epochs" ])
222
219
pipeline , model , history = self .final_fit (** copied_fit_kwargs )
223
220
else :
224
221
# TODO: Add return history functionality in Keras Tuner
@@ -270,7 +267,7 @@ def final_fit(self, **kwargs):
270
267
model = self ._build_best_model ()
271
268
self .adapt (model , kwargs ["x" ])
272
269
model , history = utils .fit_with_adaptive_batch_size (
273
- model , self .hypermodel .hypermodel . batch_size , ** kwargs
270
+ model , self .hypermodel .batch_size , ** kwargs
274
271
)
275
272
return pipeline , model , history
276
273
0 commit comments