
Commit 6a54c36

Update rms_normalization and layer_normalization. (#21438)
- Add tests.
- Fix the numeric stability issue in half precision training.
- Update the signature of the Operation.
1 parent ff17868 commit 6a54c36
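In practice, the signature change means `scale` (for `rms_normalization`) and `gamma`/`beta` (for `layer_normalization`) are now optional call-time arguments instead of attributes stored on the `Operation`. A minimal usage sketch based on the signatures in this diff; the input values are illustrative:

    import numpy as np
    import keras

    x = keras.random.normal((2, 8))
    scale = np.ones((8,), dtype="float32")

    # `scale` is now passed at call time and defaults to None (no scaling).
    y = keras.ops.rms_normalization(x, scale=scale, axis=-1)

    # `gamma` and `beta` are likewise optional call-time arguments.
    z = keras.ops.layer_normalization(x, gamma=scale, beta=None, axis=-1)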

File tree

2 files changed: +165 −94 lines changed

keras/src/ops/nn.py

Lines changed: 79 additions & 83 deletions
@@ -10,7 +10,6 @@
 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_output_shape,
 )
-from keras.src.backend.common.keras_tensor import is_keras_tensor
 from keras.src.ops import operation_utils
 from keras.src.ops.operation import Operation
 from keras.src.ops.operation_utils import reduce_shape
@@ -2753,18 +2752,17 @@ def dot_product_attention(
 
 
 class RMSNorm(Operation):
-    def __init__(self, scale=1, axis=-1, epsilon=None, *, name=None):
+    def __init__(self, axis=-1, epsilon=None, *, name=None):
         super().__init__(name=name)
         self.axis = axis
-        self.scale = scale
         self.epsilon = epsilon
 
-    def compute_output_spec(self, x):
-        return KerasTensor(shape=x.shape)
+    def compute_output_spec(self, x, scale):
+        return KerasTensor(shape=x.shape, dtype=x.dtype)
 
-    def call(self, x):
+    def call(self, x, scale=None):
         return _rms_normalization(
-            x, scale=self.scale, axis=self.axis, epsilon=self.epsilon
+            x, scale=scale, axis=self.axis, epsilon=self.epsilon
         )
 
 
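Because `scale` is now threaded through `symbolic_call`, a symbolic `scale` also triggers graph construction. A quick sketch with symbolic inputs; the shapes are illustrative:

    import keras

    x = keras.KerasTensor((None, 16))
    scale = keras.KerasTensor((16,))

    # Both inputs are symbolic, so an RMSNorm op is built and a
    # KerasTensor with x's shape and dtype is returned.
    y = keras.ops.rms_normalization(x, scale=scale)
    print(y.shape, y.dtype)  # (None, 16) float32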
@@ -2774,7 +2772,7 @@ def call(self, x):
         "keras.ops.nn.rms_normalization",
     ]
 )
-def rms_normalization(x, scale=1, axis=-1, epsilon=None):
+def rms_normalization(x, scale=None, axis=-1, epsilon=None):
     """Performs Root Mean Square (RMS) normalization on `x`.
 
     The Keras operation implements the operation as described in
@@ -2787,81 +2785,77 @@ def rms_normalization(x, scale=1, axis=-1, epsilon=None):
 
     Args:
         x: Input tensor.
-        axis: The axis or axes along which to perform normalization.
-            Default to -1.
         scale: Optional scaling factor for the normalization.
-        epsilon: A lower bound value for the norm.
-            Defaults to `backend.epsilon()`.
+        axis: The axis or axes along which to perform normalization. Defaults
+            to `-1`.
+        epsilon: A lower bound value for the norm. Defaults to
+            `backend.epsilon()`.
 
     Returns:
         The normalized array.
 
     Example:
 
-    >>> x = np.random.rand(1, 10)
-    >>> x_norm = keras.ops.rms_normalization(x, (10,))
-    >>> print(x_norm)
+    >>> x = keras.random.normal((1, 10))
+    >>> keras.ops.rms_normalization(x)
     array([[0.69384296, 0.94444374, 0.16551171, 0.05749961, 1.11008865,
-    0.52475186, 1.57686807, 1.69893307, 1.27292764, 0.30819128]])
+        0.52475186, 1.57686807, 1.69893307, 1.27292764, 0.30819128]])
     """
-    if any_symbolic_tensors((x,)):
-        return RMSNorm(scale=scale, axis=axis, epsilon=epsilon).symbolic_call(x)
+    if any_symbolic_tensors((x, scale)):
+        return RMSNorm(axis=axis, epsilon=epsilon).symbolic_call(x, scale=scale)
     return _rms_normalization(x, scale=scale, axis=axis, epsilon=epsilon)
 
 
-def _rms_normalization(x, scale=1, axis=-1, epsilon=None):
+def _rms_normalization(x, scale=None, axis=-1, epsilon=None):
+    if epsilon is None:
+        epsilon = backend.epsilon()
+    original_dtype = backend.standardize_dtype(x.dtype)
+    # Computes in at least float32 precision for stability in half precision
+    # training.
+    compute_dtype = backend.result_type(x.dtype, "float32")
+
+    x = backend.convert_to_tensor(x, dtype=compute_dtype)
+    if scale is not None:
+        scale = backend.convert_to_tensor(scale, x.dtype)
+
     if backend.backend() == "torch" and is_continuous_axis(axis):
         import torch.nn.functional as F
 
         if isinstance(axis, (tuple, list)):
             normalized_shape = tuple([x.shape[dim] for dim in axis])
         else:
-            normalized_shape = x.shape[axis]
-        return F.rms_norm(x, normalized_shape, scale, epsilon)
-    x = backend.convert_to_tensor(x)
-    if len(x.shape) == 0:
-        x = backend.numpy.expand_dims(x, axis=0)
-    if epsilon is None:
-        epsilon = backend.epsilon()
-
-    if not is_keras_tensor(scale):
-        scale = backend.convert_to_tensor(scale, dtype=x.dtype)
-    if not is_keras_tensor(epsilon):
-        epsilon = backend.convert_to_tensor(epsilon, dtype=x.dtype)
-
-    rrms = backend.math.rsqrt(
-        backend.numpy.mean(backend.numpy.square(x), axis=axis, keepdims=True)
-        + epsilon
-    )
-    return (x * rrms) * scale
+            normalized_shape = (x.shape[axis],)
+        outputs = F.rms_norm(x, normalized_shape, scale, epsilon)
+    else:
+        if len(x.shape) == 0:
+            x = backend.numpy.expand_dims(x, axis=0)
+        rrms = backend.math.rsqrt(
+            backend.numpy.mean(
+                backend.numpy.square(x), axis=axis, keepdims=True
+            )
+            + epsilon
+        )
+        outputs = backend.numpy.multiply(x, rrms)
+    if scale is not None:
+        outputs = backend.numpy.multiply(outputs, scale)
+    return backend.cast(outputs, original_dtype)
 
 
 class LayerNorm(Operation):
-    def __init__(
-        self,
-        gamma=None,
-        beta=None,
-        axis=-1,
-        epsilon=None,
-        rms_scaling=False,
-        *,
-        name=None,
-    ):
+    def __init__(self, axis=-1, epsilon=None, rms_scaling=False, *, name=None):
         super().__init__(name=name)
         self.axis = axis
-        self.gamma = gamma
-        self.beta = beta
         self.epsilon = epsilon
         self.rms_scaling = rms_scaling
 
-    def compute_output_spec(self, x):
-        return KerasTensor(shape=x.shape)
+    def compute_output_spec(self, x, gamma, beta):
+        return KerasTensor(shape=x.shape, dtype=x.dtype)
 
-    def call(self, x):
-        return _rms_normalization(
+    def call(self, x, gamma=None, beta=None):
+        return _layer_normalization(
             x,
-            gamma=self.gamma,
-            beta=self.beta,
+            gamma=gamma,
+            beta=beta,
             axis=self.axis,
             epsilon=self.epsilon,
             rms_scaling=self.rms_scaling,
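For readers tracing the new `_rms_normalization` above: outside the torch fast path it computes `x * rsqrt(mean(x**2, axis) + epsilon)`, optionally multiplied by `scale`, with the reduction done in at least float32 before casting back. A backend-free NumPy restatement of the same computation; the helper name `rms_normalize_ref` is hypothetical:

    import numpy as np

    def rms_normalize_ref(x, scale=None, axis=-1, epsilon=1e-7):
        x = np.asarray(x)
        # Reduce in float32 so mean(x**2) stays finite for float16 inputs.
        x32 = x.astype(np.float32)
        rrms = 1.0 / np.sqrt(
            np.mean(np.square(x32), axis=axis, keepdims=True) + epsilon
        )
        outputs = x32 * rrms
        if scale is not None:
            outputs = outputs * np.asarray(scale, dtype=np.float32)
        # Cast back to the caller's dtype, mirroring the final backend.cast.
        return outputs.astype(x.dtype)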
@@ -2883,21 +2877,24 @@ def layer_normalization(
     batch independently, rather than across a batch like Batch Normalization.
     i.e. applies a transformation that maintains the mean activation within each
     example close to 0 and the activation standard deviation close to 1.
+
     Args:
         x: Input tensor.
-        axis: The axis or axes along which to perform normalization.
-            Default to -1.
         gamma: Optional scaling factor for the normalization.
         beta: Optional add offset for the normalized tensor.
+        axis: The axis or axes along which to perform normalization. Defaults
+            to `-1`.
         epsilon: A lower bound value for the norm.
             Defaults to `backend.epsilon()`.
 
     Returns:
         The normalized array.
-    >>> x = ops.arange(5,dtype = "float32")
-    >>> x_norm = ops.layer_normalization(x)
-    >>> print(x_norm)
-    array([-1.4142135 , -0.70710677, 0., 0.7071067 , 1.4142135 ])
+
+    Example:
+
+    >>> x = keras.ops.arange(5, dtype="float32")
+    >>> keras.ops.layer_normalization(x)
+    array([-1.4142135, -0.70710677, 0.0, 0.7071067, 1.4142135])
     """
     rms_scaling = kwargs.pop("rms_scaling", False)
     if rms_scaling:
@@ -2909,14 +2906,10 @@
             "instead."
         )
 
-    if any_symbolic_tensors((x,)):
+    if any_symbolic_tensors((x, gamma, beta)):
         return LayerNorm(
-            gamma=gamma,
-            beta=beta,
-            axis=axis,
-            epsilon=epsilon,
-            rms_scaling=rms_scaling,
-        ).symbolic_call(x)
+            axis=axis, epsilon=epsilon, rms_scaling=rms_scaling
+        ).symbolic_call(x, gamma, beta)
     return _layer_normalization(
         x,
         gamma=gamma,
@@ -2928,12 +2921,21 @@
 
 
 def _layer_normalization(
-    inputs, gamma=None, beta=None, axis=-1, epsilon=None, rms_scaling=False
+    x, gamma=None, beta=None, axis=-1, epsilon=None, rms_scaling=False
 ):
-    compute_dtype = backend.result_type(inputs.dtype, "float32")
-    # LN is prone to overflow with float16/bfloat16 inputs, so we upcast to
-    # float32 for the subsequent computations.
-    x = backend.cast(inputs, compute_dtype)
+    if epsilon is None:
+        epsilon = backend.epsilon()
+    original_dtype = backend.standardize_dtype(x.dtype)
+    # Computes in at least float32 precision for stability in half precision
+    # training.
+    compute_dtype = backend.result_type(x.dtype, "float32")
+
+    x = backend.convert_to_tensor(x, dtype=compute_dtype)
+    if gamma is not None:
+        gamma = backend.convert_to_tensor(gamma, x.dtype)
+    if beta is not None:
+        beta = backend.convert_to_tensor(beta, x.dtype)
+
     # Compute the axes along which to reduce the mean / variance
     input_shape = x.shape
     ndims = len(input_shape)
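Why the upcast matters: float16's largest finite value is 65504, so squaring values with magnitude above roughly 256 already overflows before the mean is taken. A small demonstration of the failure mode this commit guards against:

    import numpy as np

    x = np.full((4,), 300.0, dtype=np.float16)

    # 300**2 = 90000 > 65504 (float16 max), so the reduction yields inf.
    print(np.mean(np.square(x)))  # inf

    # Upcasting first, as the new code does, keeps the statistics finite.
    print(np.mean(np.square(x.astype(np.float32))))  # 90000.0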
@@ -2951,16 +2953,12 @@ def _broadcast(v):
             return backend.numpy.reshape(v, broadcast_shape)
         return v
 
-    if epsilon is None:
-        epsilon = backend.epsilon()
-
     if rms_scaling:
-        # Calculate outputs with only variance and gamma if rms scaling
-        # is enabled
-        # Calculate the variance along self.axis (layer activations).
         variance = backend.numpy.var(x, axis=axis, keepdims=True)
         inv = backend.math.rsqrt(variance + epsilon)
-        outputs = x * inv * backend.cast(_broadcast(gamma), x.dtype)
+        outputs = x * inv
+        if gamma is not None:
+            outputs = outputs * backend.cast(_broadcast(gamma), x.dtype)
     elif backend.config.backend() == "torch" and is_continuous_axis(axis):
         # When using the torch backend, use the kernel to improve performance.
         import torch.nn.functional as F
@@ -2973,16 +2971,14 @@ def _broadcast(v):
         gamma, beta = _broadcast(gamma), _broadcast(beta)
         inv = backend.math.rsqrt(variance + epsilon)
         if gamma is not None:
-            gamma = backend.cast(gamma, x.dtype)
             inv = inv * gamma
 
         res = -mean * inv
         if beta is not None:
-            beta = backend.cast(beta, x.dtype)
             res = res + beta
 
         outputs = x * inv + res
-    return backend.cast(outputs, inputs.dtype)
+    return backend.cast(outputs, original_dtype)
 
 
 class Polar(Operation):
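A note on the generic branch above: rather than computing `(x - mean) / sqrt(var + eps) * gamma + beta` directly, it folds `gamma` into the reciprocal standard deviation (`inv`) and `-mean * inv + beta` into an additive term (`res`), so the final line is a single multiply-add per element. A NumPy restatement of that factoring; the helper name `layer_normalize_ref` is hypothetical:

    import numpy as np

    def layer_normalize_ref(x, gamma=None, beta=None, axis=-1, epsilon=1e-7):
        x = np.asarray(x)
        x32 = x.astype(np.float32)
        mean = np.mean(x32, axis=axis, keepdims=True)
        inv = 1.0 / np.sqrt(np.var(x32, axis=axis, keepdims=True) + epsilon)
        if gamma is not None:
            inv = inv * gamma  # fold the scale into the multiplier
        res = -mean * inv  # equals -mean * gamma / std when gamma is given
        if beta is not None:
            res = res + beta
        # x * inv + res == (x - mean) / std * gamma + beta
        return (x32 * inv + res).astype(x.dtype)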
