
No mixed precision support with GradientAccumulateOptimizer? #113


Description

@dPys

Love this package!

Curious -- is the lack of mixed-precision support for GradientAccumulateOptimizer intentional (e.g., observed stability issues?), or is it something on the to-do list that just hasn't been implemented yet? If it's the latter, I've put together a rough draft of how this might look (assuming the input optimizer is either SGD or Adam and has already been wrapped in a LossScaleOptimizer):

import tensorflow as tf

# Base class for the wrapper; this draft targets the slot-based optimizer API
# (on TF >= 2.11 this lives at tf.keras.optimizers.legacy.Optimizer).
opt = tf.keras.optimizers.Optimizer


@tf.keras.utils.register_keras_serializable("gradient-accumulator")
class GradientAccumulateOptimizer(opt):
    """Optimizer wrapper for gradient accumulation.
    """

    def __init__(
        self,
        optimizer="SGD",
        accum_steps=1,
        reduction: str = "MEAN",
        mixed_precision: bool = False,
        name: str = "GradientAccumulateOptimizer",
        **kwargs
    ):
        """Construct a new GradientAccumulateOptimizer optimizer.

        Adding support for sparse tensors was tricky, but this resource was
        helpful. Note that you need to implement both _resource_apply_sparse()
        and _resource_apply_sparse_duplicate_indices() for it to work as
        intended.

        See here for more information regarding implementation:
        * https://github.com/tensorflow/addons/blob/master/tensorflow_addons/optimizers/average_wrapper.py#L93  # noqa

        Args:
            optimizer: str or `tf.keras.optimizers.Optimizer` that will be
                used to compute and apply gradients.
            accum_steps: int > 0. Apply the accumulated gradients every
                `accum_steps` steps.
            reduction: str. Which gradient reduction method to use. Defaults
                to 'MEAN'.
            mixed_precision: bool. Whether to use mixed precision training.
            name: Optional name for the operations created when applying
                gradients. Defaults to "GradientAccumulateOptimizer".
            **kwargs: keyword arguments. Allowed to be {`clipnorm`,
                `clipvalue`, `lr`, `decay`}. `clipnorm` is clip gradients by
                norm; `clipvalue` is clip gradients by value, `decay` is
                included for backward compatibility to allow time inverse
                decay of learning rate. `lr` is included for backward
                compatibility, recommended to use `learning_rate` instead.
        """
        self._optimizer = tf.keras.optimizers.get(optimizer)
        self._accum_steps = accum_steps
        self._reduction = reduction
        self._step = None
        self.mixed_precision = mixed_precision
        super().__init__(name, **kwargs)

    def get_slot(self, *args, **kwargs):
        return self._optimizer.get_slot(*args, **kwargs)

    def add_slot(self, *args, **kwargs):
        return self._optimizer.add_slot(*args, **kwargs)

    def _create_slots(self, var_list):
        """Creates slots for optimizer gradients.

        Args:
            var_list: list of trainable variables.
        """
        if self.mixed_precision:
            # With mixed precision, self._optimizer is a LossScaleOptimizer,
            # so the slots belong to its inner optimizer. Note that Adam's
            # own _create_slots already builds its "m"/"v" slots, so no
            # manual slot creation is needed for it here.
            self._optimizer.inner_optimizer._create_slots(var_list=var_list)
        else:
            self._optimizer._create_slots(var_list=var_list)

        for var in var_list:
            self.add_slot(var, "ga")

        self._gradients = [self.get_slot(var, "ga") for var in var_list]

    @property
    def step(self):
        """The number of training steps this Optimizer has run.
        Initializes step variable if None.

        Returns:
            Current number of optimizer steps.
        """
        if self._step is None:
            with self._distribution_strategy_scope():
                self._step = self.add_weight(
                    "iter",
                    shape=[],
                    initializer="ones",
                    dtype=tf.int64,
                    trainable=False,
                    aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
                )
            self._weights.append(self._step)
        return self._step

    @step.setter
    def step(self, variable):  # pragma: no cover
        """Sets the step value."""
        if self._step is not None:
            raise RuntimeError(
                "Cannot set `step` to a new Variable after "
                "the Optimizer weights have been created"
            )
        self._step = variable
        self._weights.append(self._step)

    @property
    def gradients(self):  # pragma: no cover
        """The accumulated gradients on the current replica.

        Returns:
            Current gradients in optimizer.
        """
        if not self._gradients:
            raise ValueError(
                "The accumulator should be called first to initialize the "
                "gradients"
            )
        return list(
            gradient.read_value() if gradient is not None else gradient
            for gradient in self._gradients
        )

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        """Updates weights using gradients.

        Args:
            grads_and_vars: list of (gradient, variable) pairs.
            name: name to set when applying gradients.
            **kwargs: keyword arguments.
        Returns:
            The op that increments the step counter once the gradients have
            been applied.
        """
        train_op = super().apply_gradients(grads_and_vars, name, **kwargs)
        with tf.control_dependencies([train_op]):
            with tf.control_dependencies(
                [
                    self._optimizer.iterations.assign_add(
                        tf.cast(
                            tf.where(self.step % self._accum_steps == 0, 1, 0),
                            tf.int64,
                        ),
                        read_value=False,
                    )
                ]
            ):
                return self.step.assign_add(1, read_value=False)

    def _resource_apply_dense(self, grad, var, apply_state=None):  # pragma: no cover
        """Performs gradient update on dense tensor.

        Args:
            grad: current gradient.
            var: current variable.
            apply_state: state dict passed through to the wrapped optimizer.
        Returns:
            apply_op.
        """

        if self.mixed_precision:
            opt_to_use = self._optimizer.inner_optimizer
        else:
            opt_to_use = self._optimizer

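        # Accumulate a 1/accum_steps share of this step's gradient so the
        # slot ends up holding the mean over the accumulation window.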
        accum_gradient = self.get_slot(var, "ga")
        if accum_gradient is not None and grad is not None:
            accum_gradient.assign_add(
                tf.math.divide(grad, self._accum_steps),
                use_locking=self._use_locking,
                read_value=False,
            )

        def _apply(accum_gradient, var, apply_state):
            grad = tf.where(
                self.step % self._accum_steps == 0,
                accum_gradient,
                tf.zeros_like(var),
            )

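            # Gradients produced under a LossScaleOptimizer are loss-scaled;
            # unscale before handing them to the inner optimizer (this
            # assumes self._optimizer is a LossScaleOptimizer whenever
            # mixed_precision is True).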
            if self.mixed_precision:
                grad = self.optimizer.get_unscaled_gradients([grad])[0]

            train_op = opt_to_use._resource_apply_dense(
                grad,
                var,
                apply_state if "apply_state" in opt_to_use._sparse_apply_args else None,
            )

            # Zero the accumulator on apply steps and keep it otherwise.
            # Gating on the step counter (rather than comparing grad with the
            # accumulator) stays correct after the gradient has been unscaled.
            reset_val = tf.where(
                self.step % self._accum_steps == 0,
                tf.zeros_like(accum_gradient),
                accum_gradient,
            )
            reset_op = accum_gradient.assign(
                reset_val,
                use_locking=self._use_locking,
                read_value=False,
            )

            return tf.group(train_op, reset_op)

        return _apply(accum_gradient, var, apply_state)

    def _resource_apply_sparse(
        self, grad, var, indices, apply_state=None
    ):  # pragma: no cover
        """Performs gradient update on sparse tensor.

        Args:
            grad: current gradient.
            var: current variable.
            indices: relevant indices to be used for masking the sparse tensor
                during update.
            apply_state: state dict passed through to the wrapped optimizer.
        Returns:
            apply_op.
        """

        if self.mixed_precision:
            opt_to_use = self._optimizer.inner_optimizer
        else:
            opt_to_use = self._optimizer

        accum_gradient = self.get_slot(var, "ga")

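        # For sparse gradients, accumulate only the rows touched by
        # `indices` via a scatter-add into the slot variable.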
        if accum_gradient is not None and grad is not None:
            self._resource_scatter_add(
                accum_gradient, indices, tf.math.divide(grad, self._accum_steps)
            )

        def _apply(accum_gradient, var, apply_state):
            grad = tf.where(
                self.step % self._accum_steps == 0,
                accum_gradient,
                tf.zeros_like(var),
            )

            if self.mixed_precision:
                grad = self.optimizer.get_unscaled_gradients([grad])[0]

            # Pass the conditional (and, under mixed precision, unscaled)
            # gradient rows rather than the raw accumulator, mirroring the
            # dense path.
            train_op = opt_to_use._resource_apply_sparse(
                tf.gather(grad, indices),
                var,
                apply_state if "apply_state" in opt_to_use._sparse_apply_args else None,
            )

            # Same step-gated reset as the dense path.
            reset_val = tf.where(
                self.step % self._accum_steps == 0,
                tf.zeros_like(accum_gradient),
                accum_gradient,
            )
            reset_op = accum_gradient.assign(
                reset_val,
                use_locking=self._use_locking,
                read_value=False,
            )

            return tf.group(train_op, reset_op)

        return _apply(accum_gradient, var, apply_state)

    # TODO: needs to be updated and tested
    def _resource_apply_sparse_duplicate_indices(
        self, grad, var, indices, apply_state=None
    ):  # pragma: no cover
        """Performs gradient update on sparse tensor.

        Args:
            grad: current gradient.
            var: current variable.
            indices: relevant indices to be used for masking the sparse tensor
                during update.
            apply_state: state dict passed through to the wrapped optimizer.
        Returns:
            apply_op.
        """
        if self.mixed_precision:
            opt_to_use = self._optimizer.inner_optimizer
        else:
            opt_to_use = self._optimizer

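        # Mirrors _resource_apply_sparse; the underlying scatter-add sums
        # contributions from duplicate indices.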
        accum_gradient = self.get_slot(var, "ga")

        if accum_gradient is not None and grad is not None:
            self._resource_scatter_add(
                accum_gradient, indices, tf.math.divide(grad, self._accum_steps)
            )

        def _apply(accum_gradient, var, apply_state):
            grad = tf.where(
                self.step % self._accum_steps == 0,
                accum_gradient,
                tf.zeros_like(var),
            )

            if self.mixed_precision:
                grad = self.optimizer.get_unscaled_gradients([grad])[0]

            train_op = opt_to_use._resource_apply_sparse_duplicate_indices(
                tf.gather(grad, indices),
                var,
                apply_state if "apply_state" in opt_to_use._sparse_apply_args else None,
            )

            # Same step-gated reset as the dense path.
            reset_val = tf.where(
                self.step % self._accum_steps == 0,
                tf.zeros_like(accum_gradient),
                accum_gradient,
            )
            reset_op = accum_gradient.assign(
                reset_val,
                use_locking=self._use_locking,
                read_value=False,
            )

            return tf.group(train_op, reset_op)

        return _apply(accum_gradient, var, apply_state)

    def reset(self):  # pragma: no cover
        """Resets the accumulated gradients on the current replica."""
        assign_ops = []
        if not self._gradients:
            return assign_ops

        for gradient in self._gradients:
            if gradient is not None:
                assign_ops.append(
                    gradient.assign(
                        tf.zeros_like(gradient),
                        use_locking=self._use_locking,
                        read_value=False,
                    )
                )

        return tf.group(assign_ops)

    @property
    def optimizer(self):
        """The optimizer that this AccumOptimizer is wrapping."""
        return self._optimizer

    @property
    def iterations(self):
        """Returns current iteration value of optimizer.

        Returns:
            iterations of optimizer."""
        return self._optimizer.iterations

    @iterations.setter
    def iterations(self, variable):
        """Sets the iterations value of optimizer."""
        self._optimizer.iterations = variable

    @property
    def learning_rate(self):  # pragma: no cover
        """Returns the learning rate of the optimizer.

        Returns:
            learning rate of optimizer.
        """
        return self._optimizer._get_hyper("learning_rate")

    @learning_rate.setter
    def learning_rate(self, learning_rate):  # pragma: no cover
        """Sets the learning rate of the optimizer.

        Args:
            learning_rate: which learning rate to set in the optimizer.
        """
        self._optimizer._set_hyper("learning_rate", learning_rate)

    def get_config(self):
        """Returns the configuration as dict."""
        config = {
            "optimizer": tf.keras.optimizers.serialize(self._optimizer),
            "accum_steps": self._accum_steps,
            "reduction": self._reduction,
            "mixed_precision": self.mixed_precision,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Gets config of original optimizer and deserializes it."""
        optimizer = tf.keras.optimizers.deserialize(
            config.pop("optimizer"), custom_objects=custom_objects
        )
        return cls(optimizer, **config)
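
For context, here's roughly how I'd expect this to be wired up (a minimal sketch; the mixed_float16 policy call and the explicit LossScaleOptimizer wrapping reflect the workflow I'm assuming, not something I've tested end-to-end against this draft):

import tensorflow as tf

tf.keras.mixed_precision.set_global_policy("mixed_float16")

# Wrap the inner optimizer in a LossScaleOptimizer first; the draft assumes
# this wrapping has already happened whenever mixed_precision=True.
inner = tf.keras.optimizers.Adam(learning_rate=1e-3)
scaled = tf.keras.mixed_precision.LossScaleOptimizer(inner)
opt = GradientAccumulateOptimizer(
    optimizer=scaled,
    accum_steps=4,
    mixed_precision=True,
)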

Curious to hear your thoughts!
