Skip to content

Commit 713172a

Browse files
authored
Only allow deserialization of KerasSaveables by module and name. (#21429)
Arbitrary functions and classes are not allowed. - Made `Operation` extend `KerasSaveable`, this required moving imports to avoid circular imports - `Layer` no longer need to extend `KerasSaveable` directly - Made feature space `Cross` and `Feature` extend `KerasSaveable` - Also dissallow public function `enable_unsafe_deserialization`
1 parent 744b8be commit 713172a

File tree

7 files changed

+61
-16
lines changed

7 files changed

+61
-16
lines changed

keras/src/layers/layer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
from keras.src.metrics.metric import Metric
4545
from keras.src.ops.node import Node
4646
from keras.src.ops.operation import Operation
47-
from keras.src.saving.keras_saveable import KerasSaveable
4847
from keras.src.utils import python_utils
4948
from keras.src.utils import summary_utils
5049
from keras.src.utils import traceback_utils
@@ -67,7 +66,7 @@
6766

6867

6968
@keras_export(["keras.Layer", "keras.layers.Layer"])
70-
class Layer(BackendLayer, Operation, KerasSaveable):
69+
class Layer(BackendLayer, Operation):
7170
"""This is the class from which all layers inherit.
7271
7372
A layer is a callable object that takes as input one or more tensors and

keras/src/layers/preprocessing/feature_space.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
77
from keras.src.saving import saving_lib
88
from keras.src.saving import serialization_lib
9+
from keras.src.saving.keras_saveable import KerasSaveable
910
from keras.src.utils import backend_utils
1011
from keras.src.utils.module_utils import tensorflow as tf
1112
from keras.src.utils.naming import auto_name
1213

1314

14-
class Cross:
15+
class Cross(KerasSaveable):
1516
def __init__(self, feature_names, crossing_dim, output_mode="one_hot"):
1617
if output_mode not in {"int", "one_hot"}:
1718
raise ValueError(
@@ -23,6 +24,9 @@ def __init__(self, feature_names, crossing_dim, output_mode="one_hot"):
2324
self.crossing_dim = crossing_dim
2425
self.output_mode = output_mode
2526

27+
def _obj_type(self):
28+
return "Cross"
29+
2630
@property
2731
def name(self):
2832
return "_X_".join(self.feature_names)
@@ -39,7 +43,7 @@ def from_config(cls, config):
3943
return cls(**config)
4044

4145

42-
class Feature:
46+
class Feature(KerasSaveable):
4347
def __init__(self, dtype, preprocessor, output_mode):
4448
if output_mode not in {"int", "one_hot", "float"}:
4549
raise ValueError(
@@ -55,6 +59,9 @@ def __init__(self, dtype, preprocessor, output_mode):
5559
self.preprocessor = preprocessor
5660
self.output_mode = output_mode
5761

62+
def _obj_type(self):
63+
return "Feature"
64+
5865
def get_config(self):
5966
return {
6067
"dtype": self.dtype,

keras/src/legacy/saving/legacy_h5_format.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from absl import logging
77

88
from keras.src import backend
9-
from keras.src import optimizers
109
from keras.src.backend.common import global_state
1110
from keras.src.legacy.saving import json_utils
1211
from keras.src.legacy.saving import saving_options
@@ -161,6 +160,8 @@ def load_model_from_hdf5(filepath, custom_objects=None, compile=True):
161160
# Set optimizer weights.
162161
if "optimizer_weights" in f:
163162
try:
163+
from keras.src import optimizers
164+
164165
if isinstance(model.optimizer, optimizers.Optimizer):
165166
model.optimizer.build(model._trainable_variables)
166167
else:
@@ -249,6 +250,8 @@ def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer):
249250
hdf5_group: HDF5 group.
250251
optimizer: optimizer instance.
251252
"""
253+
from keras.src import optimizers
254+
252255
if isinstance(optimizer, optimizers.Optimizer):
253256
symbolic_weights = optimizer.variables
254257
else:

keras/src/legacy/saving/saving_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,8 @@
44
from absl import logging
55

66
from keras.src import backend
7-
from keras.src import layers
87
from keras.src import losses
98
from keras.src import metrics as metrics_module
10-
from keras.src import models
11-
from keras.src import optimizers
129
from keras.src import tree
1310
from keras.src.legacy.saving import serialization
1411
from keras.src.saving import object_registration
@@ -49,6 +46,9 @@ def model_from_config(config, custom_objects=None):
4946
global MODULE_OBJECTS
5047

5148
if not hasattr(MODULE_OBJECTS, "ALL_OBJECTS"):
49+
from keras.src import layers
50+
from keras.src import models
51+
5252
MODULE_OBJECTS.ALL_OBJECTS = layers.__dict__
5353
MODULE_OBJECTS.ALL_OBJECTS["InputLayer"] = layers.InputLayer
5454
MODULE_OBJECTS.ALL_OBJECTS["Functional"] = models.Functional
@@ -132,6 +132,8 @@ def compile_args_from_training_config(training_config, custom_objects=None):
132132
custom_objects = {}
133133

134134
with object_registration.CustomObjectScope(custom_objects):
135+
from keras.src import optimizers
136+
135137
optimizer_config = training_config["optimizer_config"]
136138
optimizer = optimizers.deserialize(optimizer_config)
137139
# Ensure backwards compatibility for optimizers in legacy H5 files

keras/src/ops/operation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
from keras.src.api_export import keras_export
88
from keras.src.backend.common.keras_tensor import any_symbolic_tensors
99
from keras.src.ops.node import Node
10+
from keras.src.saving.keras_saveable import KerasSaveable
1011
from keras.src.utils import python_utils
1112
from keras.src.utils import traceback_utils
1213
from keras.src.utils.naming import auto_name
1314

1415

1516
@keras_export("keras.Operation")
16-
class Operation:
17+
class Operation(KerasSaveable):
1718
def __init__(self, name=None):
1819
if name is None:
1920
name = auto_name(self.__class__.__name__)
@@ -311,6 +312,9 @@ def _get_node_attribute_at_index(self, node_index, attr, attr_name):
311312
else:
312313
return values
313314

315+
def _obj_type(self):
316+
return "Operation"
317+
314318
# Hooks for backend layer classes
315319
def _post_build(self):
316320
"""Can be overridden for per backend post build actions."""

keras/src/saving/saving_lib.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,9 @@
1616

1717
from keras.src import backend
1818
from keras.src.backend.common import global_state
19-
from keras.src.layers.layer import Layer
20-
from keras.src.losses.loss import Loss
21-
from keras.src.metrics.metric import Metric
22-
from keras.src.optimizers.optimizer import Optimizer
2319
from keras.src.saving.serialization_lib import ObjectSharingScope
2420
from keras.src.saving.serialization_lib import deserialize_keras_object
2521
from keras.src.saving.serialization_lib import serialize_keras_object
26-
from keras.src.trainers.compile_utils import CompileMetrics
2722
from keras.src.utils import dtype_utils
2823
from keras.src.utils import file_utils
2924
from keras.src.utils import io_utils
@@ -1584,32 +1579,60 @@ def get_attr_skipset(obj_type):
15841579
"_self_unconditional_dependency_names",
15851580
]
15861581
)
1582+
if obj_type == "Operation":
1583+
from keras.src.ops.operation import Operation
1584+
1585+
ref_obj = Operation()
1586+
skipset.update(dir(ref_obj))
15871587
if obj_type == "Layer":
1588+
from keras.src.layers.layer import Layer
1589+
15881590
ref_obj = Layer()
15891591
skipset.update(dir(ref_obj))
15901592
elif obj_type == "Functional":
1593+
from keras.src.layers.layer import Layer
1594+
15911595
ref_obj = Layer()
15921596
skipset.update(dir(ref_obj) + ["operations", "_operations"])
15931597
elif obj_type == "Sequential":
1598+
from keras.src.layers.layer import Layer
1599+
15941600
ref_obj = Layer()
15951601
skipset.update(dir(ref_obj) + ["_functional"])
15961602
elif obj_type == "Metric":
1603+
from keras.src.metrics.metric import Metric
1604+
from keras.src.trainers.compile_utils import CompileMetrics
1605+
15971606
ref_obj_a = Metric()
15981607
ref_obj_b = CompileMetrics([], [])
15991608
skipset.update(dir(ref_obj_a) + dir(ref_obj_b))
16001609
elif obj_type == "Optimizer":
1610+
from keras.src.optimizers.optimizer import Optimizer
1611+
16011612
ref_obj = Optimizer(1.0)
16021613
skipset.update(dir(ref_obj))
16031614
skipset.remove("variables")
16041615
elif obj_type == "Loss":
1616+
from keras.src.losses.loss import Loss
1617+
16051618
ref_obj = Loss()
16061619
skipset.update(dir(ref_obj))
1620+
elif obj_type == "Cross":
1621+
from keras.src.layers.preprocessing.feature_space import Cross
1622+
1623+
ref_obj = Cross((), 1)
1624+
skipset.update(dir(ref_obj))
1625+
elif obj_type == "Feature":
1626+
from keras.src.layers.preprocessing.feature_space import Feature
1627+
1628+
ref_obj = Feature("int32", lambda x: x, "int")
1629+
skipset.update(dir(ref_obj))
16071630
else:
16081631
raise ValueError(
16091632
f"get_attr_skipset got invalid {obj_type=}. "
16101633
"Accepted values for `obj_type` are "
16111634
"['Layer', 'Functional', 'Sequential', 'Metric', "
1612-
"'Optimizer', 'Loss']"
1635+
"'Optimizer', 'Loss', 'Cross', 'Feature']"
16131636
)
16141637

16151638
global_state.set_global_attribute(

keras/src/saving/serialization_lib.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from keras.src.api_export import keras_export
1313
from keras.src.backend.common import global_state
1414
from keras.src.saving import object_registration
15+
from keras.src.saving.keras_saveable import KerasSaveable
1516
from keras.src.utils import python_utils
1617
from keras.src.utils.module_utils import tensorflow as tf
1718

@@ -32,6 +33,7 @@
3233

3334
LOADING_APIS = frozenset(
3435
{
36+
"keras.config.enable_unsafe_deserialization",
3537
"keras.models.load_model",
3638
"keras.preprocessing.image.load_img",
3739
"keras.saving.load_model",
@@ -817,8 +819,13 @@ def _retrieve_class_or_fn(
817819
try:
818820
mod = importlib.import_module(module)
819821
obj = vars(mod).get(name, None)
820-
if obj is not None:
822+
if isinstance(obj, type) and issubclass(obj, KerasSaveable):
821823
return obj
824+
else:
825+
raise ValueError(
826+
f"Could not deserialize '{module}.{name}' because "
827+
"it is not a KerasSaveable subclass"
828+
)
822829
except ModuleNotFoundError:
823830
raise TypeError(
824831
f"Could not deserialize {obj_type} '{name}' because "

0 commit comments

Comments
 (0)