7
7
from absl .testing import parameterized
8
8
9
9
from keras .src import layers
10
+ from keras .src .legacy .saving .legacy_h5_format import save_model_to_hdf5
10
11
from keras .src .models import Sequential
11
12
from keras .src .saving import saving_api
12
13
from keras .src .testing import test_case
@@ -53,7 +54,18 @@ def test_save_h5_format(self):
53
54
"""Test saving model in h5 format."""
54
55
model = self .get_model ()
55
56
filepath_h5 = os .path .join (self .get_temp_dir (), "test_model.h5" )
56
- saving_api .save_model (model , filepath_h5 )
57
+
58
+ # Verify the warning.
59
+ with mock .patch .object (logging , "warning" ) as mock_warn :
60
+ saving_api .save_model (model , filepath_h5 )
61
+ mock_warn .assert_called_once_with (
62
+ "You are saving your model as an HDF5 file via "
63
+ "`model.save()` or `keras.saving.save_model(model)`. "
64
+ "This file format is considered legacy. "
65
+ "We recommend using instead the native Keras format, "
66
+ "e.g. `model.save('my_model.keras')` or "
67
+ "`keras.saving.save_model(model, 'my_model.keras')`. "
68
+ )
57
69
self .assertTrue (os .path .exists (filepath_h5 ))
58
70
os .remove (filepath_h5 )
59
71
@@ -203,18 +215,36 @@ def get_model(self, dtype=None):
203
215
204
216
@parameterized .named_parameters (
205
217
named_product (
218
+ save_format = ["keras" , "weights.h5" , "h5" ],
206
219
source_dtype = ["float64" , "float32" , "float16" , "bfloat16" ],
207
220
dest_dtype = ["float64" , "float32" , "float16" , "bfloat16" ],
208
221
)
209
222
)
210
- def test_load_keras_weights (self , source_dtype , dest_dtype ):
223
+ def test_load_weights (self , save_format , source_dtype , dest_dtype ):
211
224
"""Test loading keras weights."""
212
225
src_model = self .get_model (dtype = source_dtype )
213
- filepath = os .path .join (self .get_temp_dir (), "test_weights.weights.h5" )
214
- src_model .save_weights (filepath )
215
- src_weights = src_model .get_weights ()
226
+ if save_format == "keras" :
227
+ filepath = os .path .join (self .get_temp_dir (), "test_weights.keras" )
228
+ src_model .save (filepath )
229
+ elif save_format == "weights.h5" :
230
+ filepath = os .path .join (
231
+ self .get_temp_dir (), "test_weights.weights.h5"
232
+ )
233
+ src_model .save_weights (filepath )
234
+ elif save_format == "h5" :
235
+ if "bfloat16" in (source_dtype , dest_dtype ):
236
+ raise self .skipTest (
237
+ "bfloat16 dtype is not supported in legacy h5 format."
238
+ )
239
+ filepath = os .path .join (self .get_temp_dir (), "test_weights.h5" )
240
+ save_model_to_hdf5 (src_model , filepath )
241
+ else :
242
+ raise ValueError (f"Unsupported save format: { save_format } " )
243
+
216
244
dest_model = self .get_model (dtype = dest_dtype )
217
245
dest_model .load_weights (filepath )
246
+
247
+ src_weights = src_model .get_weights ()
218
248
dest_weights = dest_model .get_weights ()
219
249
for orig , loaded in zip (src_weights , dest_weights ):
220
250
self .assertAllClose (
@@ -224,13 +254,41 @@ def test_load_keras_weights(self, source_dtype, dest_dtype):
224
254
rtol = 0.01 ,
225
255
)
226
256
227
- def test_load_h5_weights_by_name (self ):
228
- """Test loading h5 weights by name."""
229
- model = self .get_model ()
230
- filepath = os .path .join (self .get_temp_dir (), "test_weights.weights.h5" )
231
- model .save_weights (filepath )
232
- with self .assertRaisesRegex (ValueError , "Invalid keyword arguments" ):
233
- model .load_weights (filepath , by_name = True )
257
+ def test_load_weights_invalid_kwargs (self ):
258
+ src_model = self .get_model ()
259
+ keras_filepath = os .path .join (self .get_temp_dir (), "test_weights.keras" )
260
+ weight_h5_filepath = os .path .join (
261
+ self .get_temp_dir (), "test_weights.weights.h5"
262
+ )
263
+ legacy_h5_filepath = os .path .join (
264
+ self .get_temp_dir (), "test_weights.h5"
265
+ )
266
+ src_model .save (keras_filepath )
267
+ src_model .save_weights (weight_h5_filepath )
268
+ save_model_to_hdf5 (src_model , legacy_h5_filepath )
269
+
270
+ dest_model = self .get_model ()
271
+ # Test keras file.
272
+ with self .assertRaisesRegex (
273
+ ValueError , r"only supports loading '.weights.h5' files."
274
+ ):
275
+ dest_model .load_weights (keras_filepath , objects_to_skip = [])
276
+ with self .assertRaisesRegex (
277
+ ValueError , r"only supports loading legacy '.h5' or '.hdf5' files."
278
+ ):
279
+ dest_model .load_weights (keras_filepath , by_name = True )
280
+ with self .assertRaisesRegex (ValueError , r"Invalid keyword arguments" ):
281
+ dest_model .load_weights (keras_filepath , bad_kwarg = None )
282
+ # Test weights.h5 file.
283
+ with self .assertRaisesRegex (
284
+ ValueError , r"only supports loading legacy '.h5' or '.hdf5' files."
285
+ ):
286
+ dest_model .load_weights (weight_h5_filepath , by_name = True )
287
+ # Test h5 file.
288
+ with self .assertRaisesRegex (
289
+ ValueError , r"only supports loading '.weights.h5' files."
290
+ ):
291
+ dest_model .load_weights (legacy_h5_filepath , objects_to_skip = [])
234
292
235
293
def test_load_weights_invalid_extension (self ):
236
294
"""Test loading weights with unsupported extension."""
@@ -251,29 +309,3 @@ def test_load_sharded_weights(self):
251
309
dest_weights = dest_model .get_weights ()
252
310
for orig , loaded in zip (src_weights , dest_weights ):
253
311
self .assertAllClose (orig , loaded )
254
-
255
-
256
- class SaveModelTestsWarning (test_case .TestCase ):
257
- def get_model (self ):
258
- return Sequential (
259
- [
260
- layers .Dense (5 , input_shape = (3 ,)),
261
- layers .Softmax (),
262
- ]
263
- )
264
-
265
- def test_h5_deprecation_warning (self ):
266
- """Test deprecation warning for h5 format."""
267
- model = self .get_model ()
268
- filepath = os .path .join (self .get_temp_dir (), "test_model.h5" )
269
-
270
- with mock .patch .object (logging , "warning" ) as mock_warn :
271
- saving_api .save_model (model , filepath )
272
- mock_warn .assert_called_once_with (
273
- "You are saving your model as an HDF5 file via "
274
- "`model.save()` or `keras.saving.save_model(model)`. "
275
- "This file format is considered legacy. "
276
- "We recommend using instead the native Keras format, "
277
- "e.g. `model.save('my_model.keras')` or "
278
- "`keras.saving.save_model(model, 'my_model.keras')`. "
279
- )
0 commit comments