
Commit c289f80

Allow per-variable optimizer, add DispatchOptimizer.
- Adds a property `variable.optimizer` that defaults to `None`.
- Adds a `DispatchOptimizer` that scans the list of trainable variables during build, collects all unique per-variable optimizers, then dispatches the `apply`/`stateless_apply` calls to the correct optimizer where applicable.
- Modifies the `trainer` so that, during the optimizer build stage, it checks whether any variables have a custom optimizer attached and, if so, inserts a `DispatchOptimizer` to handle them. This keeps the mechanism hidden from the user.

Context: for large embedding tables, we need special optimizers so that the tables can be updated in place rather than returning large gradients. The layer handles setting the custom optimizers, but the trainer needs to be aware of them and dispatch the embedding tables to the appropriate optimizers.
1 parent 8a6e83b commit c289f80
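
To make the mechanism concrete, here is a minimal, illustrative sketch of the dispatch idea described above. The `dispatch_apply` helper is a hypothetical stand-in for exposition, not the `DispatchOptimizer` implementation added by this commit, and it assumes the `variable.optimizer` setter introduced here.

```python
import keras
from keras import optimizers


def dispatch_apply(default_optimizer, grads, variables):
    """Illustrative stand-in: route each (grad, variable) pair to the
    variable's own optimizer if one is attached, else to the default."""
    groups = {}  # optimizer id -> (optimizer, grads, variables)
    for g, v in zip(grads, variables):
        opt = getattr(v, "optimizer", None) or default_optimizer
        entry = groups.setdefault(id(opt), (opt, [], []))
        entry[1].append(g)
        entry[2].append(v)
    for opt, gs, vs in groups.values():
        opt.apply(gs, vs)  # each optimizer only updates its own variables


# Attach a dedicated optimizer to one variable; the new setter also accepts
# string identifiers such as "sgd" (resolved via `keras.optimizers.get`).
table = keras.Variable(
    initializer=keras.initializers.RandomNormal(), shape=(4, 2)
)
table.optimizer = "sgd"
dense = keras.Variable(initializer=keras.initializers.Zeros(), shape=(2,))

grads = [keras.ops.ones((4, 2)), keras.ops.ones((2,))]
dispatch_apply(optimizers.Adam(), grads, [table, dense])
```

In the actual commit, the trainer inserts a `DispatchOptimizer` automatically when any trainable variable carries its own optimizer, so user code does not need to do this grouping by hand.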


9 files changed (+547, -7 lines)


keras/src/backend/common/variables.py

+22
@@ -154,6 +154,8 @@ def __init__(
         # whether this variable should be overwritten by the computed gradient.
         # Ref: https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py
         self._overwrite_with_gradient = False
+        # Per-variable optimizer.
+        self._optimizer = None
         if isinstance(initializer, str):
             from keras.src import initializers

@@ -372,6 +374,26 @@ def regularizer(self, value):
             )
         self._regularizer = value

+    @property
+    def optimizer(self):
+        """Per-variable custom optimizer."""
+        return self._optimizer
+
+    @optimizer.setter
+    def optimizer(self, value):
+        from keras.src import optimizers
+
+        if isinstance(value, str):
+            value = optimizers.get(value)
+
+        if value is not None and not isinstance(value, optimizers.Optimizer):
+            raise ValueError(
+                "Invalid value for attribute `optimizer`. Expected an "
+                "instance of `keras.optimizers.Optimizer`, or `None`. "
+                f"Received: optimizer={value}"
+            )
+        self._optimizer = value
+
     @property
     def constraint(self):
         return self._constraint

keras/src/backend/common/variables_test.py

+13
@@ -8,6 +8,7 @@
 from keras.src import backend
 from keras.src import initializers
 from keras.src import ops
+from keras.src import optimizers
 from keras.src.backend.common import dtypes
 from keras.src.backend.common.variables import AutocastScope
 from keras.src.backend.common.variables import shape_equal
@@ -419,6 +420,18 @@ def test_deferred_initialize_within_stateless_scope(self):
         ):
             v._deferred_initialize()

+    def test_optimizer_setter(self):
+        v = backend.Variable(
+            initializer=initializers.RandomNormal(),
+            shape=(2, 2),
+        )
+        self.assertIsNone(v.optimizer)
+        v.optimizer = "sgd"
+        self.assertTrue(isinstance(v.optimizer, optimizers.Optimizer))
+
+        with self.assertRaisesRegex(ValueError, "Invalid value"):
+            v.optimizer = True
+

 class VariableDtypeShapeNdimRepr(test_case.TestCase):
     """tests for dtype, shape, ndim, __repr__"""

keras/src/backend/tensorflow/optimizer.py

+19 -4
@@ -71,9 +71,25 @@ def assign_sub(self, variable, value):
         else:
             variable.assign_sub(value)

-    def _var_key(self, variable):
+    def _convert_to_tf_variable(self, variable):
         if isinstance(variable, backend.Variable):
-            variable = variable.value  # Convert to tf.Variable
+            tf_variable = variable.value
+            # Copy additional properties.
+            if getattr(variable, "optimizer", None) is not None:
+                tf_variable.optimizer = variable.optimizer
+            if getattr(variable, "overwrite_with_gradient", False):
+                tf_variable.overwrite_with_gradient = True
+            return tf_variable
+        elif isinstance(variable, tf.Variable):
+            return variable
+        else:
+            raise ValueError(
+                f"Variable {variable} must be of type keras.Variable or "
+                f"tf.Variable, received {variable.__class__.__name__}."
+            )
+
+    def _var_key(self, variable):
+        variable = self._convert_to_tf_variable(variable)
         if hasattr(variable, "_distributed_container"):
             variable = variable._distributed_container()
         elif (
@@ -98,8 +114,7 @@ def weight_decay_fn(variable):
                     variable.assign_sub(variable * wd * lr)

             for variable in variables:
-                if isinstance(variable, backend.Variable):
-                    variable = variable.value  # Convert to tf.Variable
+                variable = self._convert_to_tf_variable(variable)
                 distribution.extended.update(
                     variable, weight_decay_fn, group=False
                 )

keras/src/optimizers/__init__.py

+1
@@ -5,6 +5,7 @@
 from keras.src.optimizers.adam import Adam
 from keras.src.optimizers.adamax import Adamax
 from keras.src.optimizers.adamw import AdamW
+from keras.src.optimizers.dispatch_optimizer import DispatchOptimizer
 from keras.src.optimizers.ftrl import Ftrl
 from keras.src.optimizers.lion import Lion
 from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer

keras/src/optimizers/base_optimizer.py

+15 -1
@@ -204,13 +204,27 @@ def iterations(self):
     def _track_variable(self, variable):
         self._tracker.add_to_store("variables", variable)

+    def _has_custom_optimizer(self, variable):
+        return (
+            hasattr(variable, "optimizer")
+            and variable.optimizer is not None
+            and variable.optimizer != self
+        )
+
     @tracking.no_automatic_dependency_tracking
     def build(self, variables):
         if self.use_ema:
             self._model_variables_moving_average = []
         if self.gradient_accumulation_steps:
             self._accumulated_gradients = []
         for i, variable in enumerate(variables):
+            if self._has_custom_optimizer(variable):
+                warnings.warn(
+                    f"Variable {variable} has a custom optimizer "
+                    f"{variable.optimizer} that is being ignored. "
+                    "See `keras.optimizers.DispatchOptimizer` to allow "
+                    "dispatching to the correct per-variable optimizer."
+                )
             self._trainable_variables_indices[self._var_key(variable)] = i
             if self.use_ema:
                 self._model_variables_moving_average.append(
@@ -568,7 +582,7 @@ def stateless_apply(self, optimizer_variables, grads, trainable_variables):
             )
         if len(trainable_variables) != len(self._trainable_variables):
             raise ValueError(
-                "Argument `optimizer_variables` must be a list of tensors "
+                "Argument `trainable_variables` must be a list of tensors "
                 "corresponding 1:1 to the trainable variables list that "
                 "the optimizer was built with. Received "
                 f"len(trainable_variables) == {len(trainable_variables)} "
