Commit 764ed95

Raise warning when rms_scaling = True (#21352)
* Fix rms_scaling in LayerNormalization
* Fix numerics UT
* Fix numerics UT (1)
* Fix numerics UT (2)
* Add warning for rms_scaling = True
* Add warning to LayerNormalization layer
* Remove unnecessary comments
1 parent 6fd736c commit 764ed95
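
The commit message mentions adding the warning alongside unit-test fixes. As a hedged sketch only (not part of this diff), the new warning could be exercised in a test roughly like the following; it assumes pytest and a Keras build containing this commit, and relies on `warnings.warn` defaulting to `UserWarning`:

```python
import pytest
import keras


def test_layer_normalization_rms_scaling_warns():
    # Constructing the layer with the deprecated argument should emit a
    # UserWarning whose message names `rms_scaling`.
    with pytest.warns(UserWarning, match="rms_scaling"):
        keras.layers.LayerNormalization(rms_scaling=True)
```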

File tree

keras/src/layers/normalization/layer_normalization.py
keras/src/ops/nn.py

2 files changed (+23, -13 lines)

keras/src/layers/normalization/layer_normalization.py

Lines changed: 12 additions & 8 deletions
@@ -1,3 +1,5 @@
+import warnings
+
 from keras.src import constraints
 from keras.src import initializers
 from keras.src import ops
@@ -82,12 +84,6 @@ class LayerNormalization(Layer):
             When the next layer is linear (also e.g. `nn.relu`), this can be
             disabled since the scaling will be done by the next layer.
             Defaults to `True`.
-        rms_scaling: If True, `center` and `scale` are ignored, and the
-            inputs are scaled by `gamma` and the inverse square root
-            of the square of all inputs. This is an approximate and faster
-            approach that avoids ever computing the mean of the input. Note that
-            this *isn't* equivalent to the computation that the
-            `keras.layers.RMSNormalization` layer performs.
         beta_initializer: Initializer for the beta weight. Defaults to zeros.
         gamma_initializer: Initializer for the gamma weight. Defaults to ones.
         beta_regularizer: Optional regularizer for the beta weight.
@@ -112,7 +108,6 @@ def __init__(
         epsilon=1e-3,
         center=True,
         scale=True,
-        rms_scaling=False,
         beta_initializer="zeros",
         gamma_initializer="ones",
         beta_regularizer=None,
@@ -121,6 +116,15 @@ def __init__(
         gamma_constraint=None,
         **kwargs,
     ):
+        rms_scaling = kwargs.pop("rms_scaling", False)
+        if rms_scaling:
+            warnings.warn(
+                "You passed `rms_scaling=True`, which is deprecated. This "
+                "argument incorrectly scales the input by the variance, not "
+                "the root mean square. To correctly use RMS Normalization, "
+                "please use `keras.layers.RMSNormalization` instead."
+            )
+
         super().__init__(**kwargs)
         if isinstance(axis, (list, tuple)):
             self.axis = list(axis)
@@ -185,7 +189,7 @@ def call(self, inputs):
             self.beta,
             self.axis,
             self.epsilon,
-            self.rms_scaling,
+            rms_scaling=self.rms_scaling,
         )
         return ops.cast(outputs, self.compute_dtype)

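For context, a minimal migration sketch (not part of this commit): the warning above points callers to `keras.layers.RMSNormalization`, and the sketch assumes that layer's default arguments are acceptable. Note the two layers are not numerically equivalent, which is exactly what the warning is about, so outputs will change after migrating.

```python
import keras

# Deprecated: now emits a UserWarning at construction time (see diff above).
old_layer = keras.layers.LayerNormalization(rms_scaling=True)

# Replacement suggested by the warning message. Results differ from the old
# rms_scaling path, which scaled by the variance rather than the root mean
# square of the inputs.
new_layer = keras.layers.RMSNormalization()
```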
keras/src/ops/nn.py

Lines changed: 11 additions & 5 deletions
@@ -2875,7 +2875,7 @@ def call(self, x):
     ]
 )
 def layer_normalization(
-    x, gamma=None, beta=None, axis=-1, epsilon=None, rms_scaling=False
+    x, gamma=None, beta=None, axis=-1, epsilon=None, **kwargs
 ):
     """Layer normalization layer (Ba et al., 2016).

@@ -2889,9 +2889,6 @@ def layer_normalization(
             Default to -1.
         gamma: Optional scaling factor for the normalization.
         beta: Optional add offset for the normalized tensor.
-        rms_scaling:This is an approximate and faster
-            approach that avoids ever computing the mean of the input. Note that
-            this *isn't* equivalent to the computation that rms_normalization
         epsilon: A lower bound value for the norm.
             Defaults to `backend.epsilon()`.

@@ -2902,6 +2899,16 @@ def layer_normalization(
     >>> print(x_norm)
     array([-1.4142135 , -0.70710677, 0., 0.7071067 , 1.4142135 ])
     """
+    rms_scaling = kwargs.pop("rms_scaling", False)
+    if rms_scaling:
+        warnings.warn(
+            "You passed `rms_scaling=True`, which is deprecated. This argument "
+            "incorrectly scales the input by the variance, not the root mean "
+            "square. To correctly use RMS Normalization, please use "
+            "`keras.ops.rms_normalization` / `keras.ops.nn.rms_normalization` "
+            "instead."
+        )
+
     if any_symbolic_tensors((x,)):
         return LayerNorm(
             gamma=gamma,
@@ -2953,7 +2960,6 @@ def _broadcast(v):
         # Calculate the variance along self.axis (layer activations).
         variance = backend.numpy.var(x, axis=axis, keepdims=True)
         inv = backend.math.rsqrt(variance + epsilon)
-
         outputs = x * inv * backend.cast(_broadcast(gamma), x.dtype)
     elif backend.config.backend() == "torch" and is_continuous_axis(axis):
         # when using torch backend,use kernel to improve performance

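A small NumPy sketch (illustration only, not from the diff) of the mismatch the warning describes: the deprecated `rms_scaling=True` path scales by the inverse square root of the variance (the `variance`/`rsqrt` lines in the hunk above), whereas true RMS normalization scales by the inverse square root of the mean of the squared inputs. The two agree only when the inputs have zero mean.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0], dtype="float32")
gamma = 1.0
epsilon = 1e-6

# Deprecated rms_scaling=True path: divide by sqrt of the variance.
variance_scaled = x * gamma / np.sqrt(x.var() + epsilon)

# True RMS normalization: divide by the root mean square of the inputs.
rms_scaled = x * gamma / np.sqrt(np.mean(np.square(x)) + epsilon)

print(variance_scaled)  # approx [0.894, 1.789, 2.683, 3.578]
print(rms_scaled)       # approx [0.365, 0.730, 1.095, 1.461]
```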