Skip to content

Commit 3d6022a

Browse files
authored
Disable loading functions within deserialization. (#21412)
Loading files while loading a model is not allowed.
1 parent be9b002 commit 3d6022a

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

keras/src/saving/serialization_lib.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,27 @@
1818
PLAIN_TYPES = (str, int, float, bool)
1919

2020
# List of Keras modules with built-in string representations for Keras defaults
21-
BUILTIN_MODULES = (
22-
"activations",
23-
"constraints",
24-
"initializers",
25-
"losses",
26-
"metrics",
27-
"optimizers",
28-
"regularizers",
21+
BUILTIN_MODULES = frozenset(
22+
{
23+
"activations",
24+
"constraints",
25+
"initializers",
26+
"losses",
27+
"metrics",
28+
"optimizers",
29+
"regularizers",
30+
}
31+
)
32+
33+
LOADING_APIS = frozenset(
34+
{
35+
"keras.models.load_model",
36+
"keras.preprocessing.image.load_img",
37+
"keras.saving.load_model",
38+
"keras.saving.load_weights",
39+
"keras.utils.get_file",
40+
"keras.utils.load_img",
41+
}
2942
)
3043

3144

@@ -765,6 +778,12 @@ def _retrieve_class_or_fn(
765778
if module == "keras" or module.startswith("keras."):
766779
api_name = module + "." + name
767780

781+
if api_name in LOADING_APIS:
782+
raise ValueError(
783+
f"Cannot deserialize `{api_name}`, loading functions are "
784+
"not allowed during deserialization"
785+
)
786+
768787
obj = api_export.get_symbol_from_name(api_name)
769788
if obj is not None:
770789
return obj

0 commit comments

Comments
 (0)