Skip to content

Commit e5c45f0

Browse files
Improve load_weights and the test coverage. (#21454)
1 parent 18ab462 commit e5c45f0

File tree

3 files changed

+105
-48
lines changed

3 files changed

+105
-48
lines changed

keras/src/legacy/saving/legacy_h5_format.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,12 +318,14 @@ def save_attributes_to_hdf5_group(group, name, data):
318318
group.attrs[name] = data
319319

320320

321-
def load_weights_from_hdf5_group(f, model):
321+
def load_weights_from_hdf5_group(f, model, skip_mismatch=False):
322322
"""Implements topological (order-based) weight loading.
323323
324324
Args:
325325
f: A pointer to a HDF5 group.
326326
model: Model instance.
327+
skip_mismatch: Boolean, whether to skip loading of weights
328+
where there is a mismatch in the shape of the weights.
327329
328330
Raises:
329331
ValueError: in case of mismatch between provided layers
@@ -379,6 +381,7 @@ def load_weights_from_hdf5_group(f, model):
379381
layer,
380382
symbolic_weights,
381383
weight_values,
384+
skip_mismatch=skip_mismatch,
382385
name=f"layer #{k} (named {layer.name})",
383386
)
384387

@@ -403,6 +406,7 @@ def load_weights_from_hdf5_group(f, model):
403406
model,
404407
symbolic_weights,
405408
weight_values,
409+
skip_mismatch=skip_mismatch,
406410
name="top-level model",
407411
)
408412

keras/src/saving/saving_api.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -249,32 +249,51 @@ def save_weights(
249249
@keras_export("keras.saving.load_weights")
250250
def load_weights(model, filepath, skip_mismatch=False, **kwargs):
251251
filepath_str = str(filepath)
252+
253+
# Get the legacy kwargs.
254+
objects_to_skip = kwargs.pop("objects_to_skip", None)
255+
by_name = kwargs.pop("by_name", None)
256+
if kwargs:
257+
raise ValueError(f"Invalid keyword arguments: {kwargs}")
258+
252259
if filepath_str.endswith(".keras"):
253-
if kwargs:
254-
raise ValueError(f"Invalid keyword arguments: {kwargs}")
260+
if objects_to_skip is not None:
261+
raise ValueError(
262+
"`objects_to_skip` only supports loading '.weights.h5' files."
263+
f" Received: {filepath}"
264+
)
265+
if by_name is not None:
266+
raise ValueError(
267+
"`by_name` only supports loading legacy '.h5' or '.hdf5' "
268+
f"files. Received: {filepath}"
269+
)
255270
saving_lib.load_weights_only(
256271
model, filepath, skip_mismatch=skip_mismatch
257272
)
258273
elif filepath_str.endswith(".weights.h5") or filepath_str.endswith(
259274
".weights.json"
260275
):
261-
objects_to_skip = kwargs.pop("objects_to_skip", None)
262-
if kwargs:
263-
raise ValueError(f"Invalid keyword arguments: {kwargs}")
276+
if by_name is not None:
277+
raise ValueError(
278+
"`by_name` only supports loading legacy '.h5' or '.hdf5' "
279+
f"files. Received: {filepath}"
280+
)
264281
saving_lib.load_weights_only(
265282
model,
266283
filepath,
267284
skip_mismatch=skip_mismatch,
268285
objects_to_skip=objects_to_skip,
269286
)
270287
elif filepath_str.endswith(".h5") or filepath_str.endswith(".hdf5"):
271-
by_name = kwargs.pop("by_name", False)
272-
if kwargs:
273-
raise ValueError(f"Invalid keyword arguments: {kwargs}")
274288
if not h5py:
275289
raise ImportError(
276290
"Loading a H5 file requires `h5py` to be installed."
277291
)
292+
if objects_to_skip is not None:
293+
raise ValueError(
294+
"`objects_to_skip` only supports loading '.weights.h5' files."
295+
f" Received: {filepath}"
296+
)
278297
with h5py.File(filepath, "r") as f:
279298
if "layer_names" not in f.attrs and "model_weights" in f:
280299
f = f["model_weights"]
@@ -283,7 +302,9 @@ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
283302
f, model, skip_mismatch
284303
)
285304
else:
286-
legacy_h5_format.load_weights_from_hdf5_group(f, model)
305+
legacy_h5_format.load_weights_from_hdf5_group(
306+
f, model, skip_mismatch
307+
)
287308
else:
288309
raise ValueError(
289310
f"File format not supported: filepath={filepath}. "

keras/src/saving/saving_api_test.py

Lines changed: 70 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from absl.testing import parameterized
88

99
from keras.src import layers
10+
from keras.src.legacy.saving.legacy_h5_format import save_model_to_hdf5
1011
from keras.src.models import Sequential
1112
from keras.src.saving import saving_api
1213
from keras.src.testing import test_case
@@ -53,7 +54,18 @@ def test_save_h5_format(self):
5354
"""Test saving model in h5 format."""
5455
model = self.get_model()
5556
filepath_h5 = os.path.join(self.get_temp_dir(), "test_model.h5")
56-
saving_api.save_model(model, filepath_h5)
57+
58+
# Verify the warning.
59+
with mock.patch.object(logging, "warning") as mock_warn:
60+
saving_api.save_model(model, filepath_h5)
61+
mock_warn.assert_called_once_with(
62+
"You are saving your model as an HDF5 file via "
63+
"`model.save()` or `keras.saving.save_model(model)`. "
64+
"This file format is considered legacy. "
65+
"We recommend using instead the native Keras format, "
66+
"e.g. `model.save('my_model.keras')` or "
67+
"`keras.saving.save_model(model, 'my_model.keras')`. "
68+
)
5769
self.assertTrue(os.path.exists(filepath_h5))
5870
os.remove(filepath_h5)
5971

@@ -203,18 +215,36 @@ def get_model(self, dtype=None):
203215

204216
@parameterized.named_parameters(
205217
named_product(
218+
save_format=["keras", "weights.h5", "h5"],
206219
source_dtype=["float64", "float32", "float16", "bfloat16"],
207220
dest_dtype=["float64", "float32", "float16", "bfloat16"],
208221
)
209222
)
210-
def test_load_keras_weights(self, source_dtype, dest_dtype):
223+
def test_load_weights(self, save_format, source_dtype, dest_dtype):
211224
"""Test loading weights across the supported file formats."""
212225
src_model = self.get_model(dtype=source_dtype)
213-
filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5")
214-
src_model.save_weights(filepath)
215-
src_weights = src_model.get_weights()
226+
if save_format == "keras":
227+
filepath = os.path.join(self.get_temp_dir(), "test_weights.keras")
228+
src_model.save(filepath)
229+
elif save_format == "weights.h5":
230+
filepath = os.path.join(
231+
self.get_temp_dir(), "test_weights.weights.h5"
232+
)
233+
src_model.save_weights(filepath)
234+
elif save_format == "h5":
235+
if "bfloat16" in (source_dtype, dest_dtype):
236+
raise self.skipTest(
237+
"bfloat16 dtype is not supported in legacy h5 format."
238+
)
239+
filepath = os.path.join(self.get_temp_dir(), "test_weights.h5")
240+
save_model_to_hdf5(src_model, filepath)
241+
else:
242+
raise ValueError(f"Unsupported save format: {save_format}")
243+
216244
dest_model = self.get_model(dtype=dest_dtype)
217245
dest_model.load_weights(filepath)
246+
247+
src_weights = src_model.get_weights()
218248
dest_weights = dest_model.get_weights()
219249
for orig, loaded in zip(src_weights, dest_weights):
220250
self.assertAllClose(
@@ -224,13 +254,41 @@ def test_load_keras_weights(self, source_dtype, dest_dtype):
224254
rtol=0.01,
225255
)
226256

227-
def test_load_h5_weights_by_name(self):
228-
"""Test loading h5 weights by name."""
229-
model = self.get_model()
230-
filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5")
231-
model.save_weights(filepath)
232-
with self.assertRaisesRegex(ValueError, "Invalid keyword arguments"):
233-
model.load_weights(filepath, by_name=True)
257+
def test_load_weights_invalid_kwargs(self):
258+
src_model = self.get_model()
259+
keras_filepath = os.path.join(self.get_temp_dir(), "test_weights.keras")
260+
weight_h5_filepath = os.path.join(
261+
self.get_temp_dir(), "test_weights.weights.h5"
262+
)
263+
legacy_h5_filepath = os.path.join(
264+
self.get_temp_dir(), "test_weights.h5"
265+
)
266+
src_model.save(keras_filepath)
267+
src_model.save_weights(weight_h5_filepath)
268+
save_model_to_hdf5(src_model, legacy_h5_filepath)
269+
270+
dest_model = self.get_model()
271+
# Test keras file.
272+
with self.assertRaisesRegex(
273+
ValueError, r"only supports loading '.weights.h5' files."
274+
):
275+
dest_model.load_weights(keras_filepath, objects_to_skip=[])
276+
with self.assertRaisesRegex(
277+
ValueError, r"only supports loading legacy '.h5' or '.hdf5' files."
278+
):
279+
dest_model.load_weights(keras_filepath, by_name=True)
280+
with self.assertRaisesRegex(ValueError, r"Invalid keyword arguments"):
281+
dest_model.load_weights(keras_filepath, bad_kwarg=None)
282+
# Test weights.h5 file.
283+
with self.assertRaisesRegex(
284+
ValueError, r"only supports loading legacy '.h5' or '.hdf5' files."
285+
):
286+
dest_model.load_weights(weight_h5_filepath, by_name=True)
287+
# Test h5 file.
288+
with self.assertRaisesRegex(
289+
ValueError, r"only supports loading '.weights.h5' files."
290+
):
291+
dest_model.load_weights(legacy_h5_filepath, objects_to_skip=[])
234292

235293
def test_load_weights_invalid_extension(self):
236294
"""Test loading weights with unsupported extension."""
@@ -251,29 +309,3 @@ def test_load_sharded_weights(self):
251309
dest_weights = dest_model.get_weights()
252310
for orig, loaded in zip(src_weights, dest_weights):
253311
self.assertAllClose(orig, loaded)
254-
255-
256-
class SaveModelTestsWarning(test_case.TestCase):
257-
def get_model(self):
258-
return Sequential(
259-
[
260-
layers.Dense(5, input_shape=(3,)),
261-
layers.Softmax(),
262-
]
263-
)
264-
265-
def test_h5_deprecation_warning(self):
266-
"""Test deprecation warning for h5 format."""
267-
model = self.get_model()
268-
filepath = os.path.join(self.get_temp_dir(), "test_model.h5")
269-
270-
with mock.patch.object(logging, "warning") as mock_warn:
271-
saving_api.save_model(model, filepath)
272-
mock_warn.assert_called_once_with(
273-
"You are saving your model as an HDF5 file via "
274-
"`model.save()` or `keras.saving.save_model(model)`. "
275-
"This file format is considered legacy. "
276-
"We recommend using instead the native Keras format, "
277-
"e.g. `model.save('my_model.keras')` or "
278-
"`keras.saving.save_model(model, 'my_model.keras')`. "
279-
)

0 commit comments

Comments
 (0)