 from keras.src.backend.common.backend_utils import (
     compute_conv_transpose_output_shape,
 )
-from keras.src.backend.common.keras_tensor import is_keras_tensor
 from keras.src.ops import operation_utils
 from keras.src.ops.operation import Operation
 from keras.src.ops.operation_utils import reduce_shape
@@ -2753,18 +2752,17 @@ def dot_product_attention(
 class RMSNorm(Operation):
-    def __init__(self, scale=1, axis=-1, epsilon=None, *, name=None):
+    def __init__(self, axis=-1, epsilon=None, *, name=None):
         super().__init__(name=name)
         self.axis = axis
-        self.scale = scale
         self.epsilon = epsilon

-    def compute_output_spec(self, x):
-        return KerasTensor(shape=x.shape)
+    def compute_output_spec(self, x, scale):
+        return KerasTensor(shape=x.shape, dtype=x.dtype)

-    def call(self, x):
+    def call(self, x, scale=None):
         return _rms_normalization(
-            x, scale=self.scale, axis=self.axis, epsilon=self.epsilon
+            x, scale=scale, axis=self.axis, epsilon=self.epsilon
         )

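Reviewer note: the point of moving `scale` from `__init__` to `call` is that call arguments participate in symbolic dispatch, so a `KerasTensor` scale can now flow through `compute_output_spec`. A minimal sketch of the resulting usage, assuming this diff is applied (illustrative, not part of the diff):

```python
import keras
from keras import ops

# With scale as a call argument, a symbolic scale tensor is legal input:
x = keras.Input(shape=(10,))
scale = keras.Input(shape=(10,))
# Routed via RMSNorm.symbolic_call, since a constructor argument could
# never be a symbolic tensor.
y = ops.rms_normalization(x, scale=scale)
```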
@@ -2774,7 +2772,7 @@ def call(self, x):
         "keras.ops.nn.rms_normalization",
     ]
 )
-def rms_normalization(x, scale=1, axis=-1, epsilon=None):
+def rms_normalization(x, scale=None, axis=-1, epsilon=None):
     """Performs Root Mean Square (RMS) normalization on `x`.

     The Keras operation implements the operation as described in
@@ -2787,81 +2785,77 @@ def rms_normalization(x, scale=1, axis=-1, epsilon=None):

     Args:
         x: Input tensor.
-        axis: The axis or axes along which to perform normalization.
-            Default to -1.
         scale: Optional scaling factor for the normalization.
-        epsilon: A lower bound value for the norm.
-            Defaults to `backend.epsilon()`.
+        axis: The axis or axes along which to perform normalization. Defaults
+            to `-1`.
+        epsilon: A lower bound value for the norm. Defaults to
+            `backend.epsilon()`.

     Returns:
         The normalized array.

     Example:

-    >>> x = np.random.rand(1, 10)
-    >>> x_norm = keras.ops.rms_normalization(x, (10,))
-    >>> print(x_norm)
+    >>> x = keras.random.normal((1, 10))
+    >>> keras.ops.rms_normalization(x)
     array([[0.69384296, 0.94444374, 0.16551171, 0.05749961, 1.11008865,
-    0.52475186, 1.57686807, 1.69893307, 1.27292764, 0.30819128]])
+        0.52475186, 1.57686807, 1.69893307, 1.27292764, 0.30819128]])
     """
-    if any_symbolic_tensors((x,)):
-        return RMSNorm(scale=scale, axis=axis, epsilon=epsilon).symbolic_call(x)
+    if any_symbolic_tensors((x, scale)):
+        return RMSNorm(axis=axis, epsilon=epsilon).symbolic_call(x, scale=scale)
     return _rms_normalization(x, scale=scale, axis=axis, epsilon=epsilon)


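For eager tensors the dispatch falls through to `_rms_normalization`. A quick usage sketch under the new signature, where `scale=None` now means "no scaling" (example inputs are illustrative):

```python
import numpy as np
import keras

x = np.random.rand(2, 8).astype("float32")
y = keras.ops.rms_normalization(x)  # unscaled RMS norm
y_scaled = keras.ops.rms_normalization(
    x, scale=np.full((8,), 0.5, dtype="float32")  # per-feature scale
)
```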
-def _rms_normalization(x, scale=1, axis=-1, epsilon=None):
+def _rms_normalization(x, scale=None, axis=-1, epsilon=None):
+    if epsilon is None:
+        epsilon = backend.epsilon()
+    original_dtype = backend.standardize_dtype(x.dtype)
+    # Computes in at least float32 precision for stability in half precision
+    # training.
+    compute_dtype = backend.result_type(x.dtype, "float32")
+
+    x = backend.convert_to_tensor(x, dtype=compute_dtype)
+    if scale is not None:
+        scale = backend.convert_to_tensor(scale, x.dtype)
+
     if backend.backend() == "torch" and is_continuous_axis(axis):
         import torch.nn.functional as F

         if isinstance(axis, (tuple, list)):
             normalized_shape = tuple([x.shape[dim] for dim in axis])
         else:
-            normalized_shape = x.shape[axis]
-        return F.rms_norm(x, normalized_shape, scale, epsilon)
-    x = backend.convert_to_tensor(x)
-    if len(x.shape) == 0:
-        x = backend.numpy.expand_dims(x, axis=0)
-    if epsilon is None:
-        epsilon = backend.epsilon()
-
-    if not is_keras_tensor(scale):
-        scale = backend.convert_to_tensor(scale, dtype=x.dtype)
-    if not is_keras_tensor(epsilon):
-        epsilon = backend.convert_to_tensor(epsilon, dtype=x.dtype)
-
-    rrms = backend.math.rsqrt(
-        backend.numpy.mean(backend.numpy.square(x), axis=axis, keepdims=True)
-        + epsilon
-    )
-    return (x * rrms) * scale
+            normalized_shape = (x.shape[axis],)
+        outputs = F.rms_norm(x, normalized_shape, scale, epsilon)
+    else:
+        if len(x.shape) == 0:
+            x = backend.numpy.expand_dims(x, axis=0)
+        rrms = backend.math.rsqrt(
+            backend.numpy.mean(
+                backend.numpy.square(x), axis=axis, keepdims=True
+            )
+            + epsilon
+        )
+        outputs = backend.numpy.multiply(x, rrms)
+        if scale is not None:
+            outputs = backend.numpy.multiply(outputs, scale)
+    return backend.cast(outputs, original_dtype)


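For reference, the math implemented by the non-torch branch above, written as a self-contained NumPy sketch (the `rms_norm_reference` helper is illustrative, not part of Keras):

```python
import numpy as np

def rms_norm_reference(x, scale=None, axis=-1, epsilon=1e-7):
    # x / sqrt(mean(x^2) + eps), optionally multiplied by `scale`.
    rrms = 1.0 / np.sqrt(
        np.mean(np.square(x), axis=axis, keepdims=True) + epsilon
    )
    out = x * rrms
    return out if scale is None else out * scale
```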
 class LayerNorm(Operation):
-    def __init__(
-        self,
-        gamma=None,
-        beta=None,
-        axis=-1,
-        epsilon=None,
-        rms_scaling=False,
-        *,
-        name=None,
-    ):
+    def __init__(self, axis=-1, epsilon=None, rms_scaling=False, *, name=None):
         super().__init__(name=name)
         self.axis = axis
-        self.gamma = gamma
-        self.beta = beta
         self.epsilon = epsilon
         self.rms_scaling = rms_scaling

-    def compute_output_spec(self, x):
-        return KerasTensor(shape=x.shape)
+    def compute_output_spec(self, x, gamma, beta):
+        return KerasTensor(shape=x.shape, dtype=x.dtype)

-    def call(self, x):
-        return _rms_normalization(
+    def call(self, x, gamma=None, beta=None):
+        return _layer_normalization(
             x,
-            gamma=self.gamma,
-            beta=self.beta,
+            gamma=gamma,
+            beta=beta,
             axis=self.axis,
             epsilon=self.epsilon,
             rms_scaling=self.rms_scaling,
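LayerNorm gets the same treatment: `gamma` and `beta` become call arguments (note the hunk above also fixes the old `call` erroneously routing to `_rms_normalization`). A usage sketch under the new public signature, with illustrative inputs:

```python
import numpy as np
import keras

x = np.random.rand(4, 16).astype("float32")
gamma = np.ones((16,), dtype="float32")   # optional scale
beta = np.zeros((16,), dtype="float32")   # optional offset
y = keras.ops.layer_normalization(x, gamma=gamma, beta=beta, axis=-1)
```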
@@ -2883,21 +2877,24 @@ def layer_normalization(
     batch independently, rather than across a batch like Batch Normalization.
     i.e. applies a transformation that maintains the mean activation within each
     example close to 0 and the activation standard deviation close to 1.
+
     Args:
         x: Input tensor.
-        axis: The axis or axes along which to perform normalization.
-            Default to -1.
         gamma: Optional scaling factor for the normalization.
         beta: Optional add offset for the normalized tensor.
+        axis: The axis or axes along which to perform normalization. Defaults
+            to `-1`.
         epsilon: A lower bound value for the norm.
             Defaults to `backend.epsilon()`.

     Returns:
         The normalized array.
-    >>> x = ops.arange(5,dtype = "float32")
-    >>> x_norm = ops.layer_normalization(x)
-    >>> print(x_norm)
-    array([-1.4142135 , -0.70710677, 0., 0.7071067 , 1.4142135 ])
+
+    Example:
+
+    >>> x = keras.ops.arange(5, dtype="float32")
+    >>> keras.ops.layer_normalization(x)
+    array([-1.4142135, -0.70710677, 0.0, 0.7071067, 1.4142135])
     """
     rms_scaling = kwargs.pop("rms_scaling", False)
     if rms_scaling:
@@ -2909,14 +2906,10 @@ def layer_normalization(
             "instead."
         )

-    if any_symbolic_tensors((x,)):
+    if any_symbolic_tensors((x, gamma, beta)):
         return LayerNorm(
-            gamma=gamma,
-            beta=beta,
-            axis=axis,
-            epsilon=epsilon,
-            rms_scaling=rms_scaling,
-        ).symbolic_call(x)
+            axis=axis, epsilon=epsilon, rms_scaling=rms_scaling
+        ).symbolic_call(x, gamma, beta)
     return _layer_normalization(
         x,
         gamma=gamma,
@@ -2928,12 +2921,21 @@ def layer_normalization(


 def _layer_normalization(
-    inputs, gamma=None, beta=None, axis=-1, epsilon=None, rms_scaling=False
+    x, gamma=None, beta=None, axis=-1, epsilon=None, rms_scaling=False
 ):
-    compute_dtype = backend.result_type(inputs.dtype, "float32")
-    # LN is prone to overflow with float16/bfloat16 inputs, so we upcast to
-    # float32 for the subsequent computations.
-    x = backend.cast(inputs, compute_dtype)
+    if epsilon is None:
+        epsilon = backend.epsilon()
+    original_dtype = backend.standardize_dtype(x.dtype)
+    # Computes in at least float32 precision for stability in half precision
+    # training.
+    compute_dtype = backend.result_type(x.dtype, "float32")
+
+    x = backend.convert_to_tensor(x, dtype=compute_dtype)
+    if gamma is not None:
+        gamma = backend.convert_to_tensor(gamma, x.dtype)
+    if beta is not None:
+        beta = backend.convert_to_tensor(beta, x.dtype)
+
     # Compute the axes along which to reduce the mean / variance
     input_shape = x.shape
     ndims = len(input_shape)
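The new preamble follows an upcast-compute-downcast pattern: remember the caller's dtype, do the reduction-heavy math in at least float32, and cast back at the end. A plain NumPy sketch of the same idea (`with_float32_compute` is an illustrative helper, not Keras API):

```python
import numpy as np

def with_float32_compute(x, fn):
    # Upcast half precision to float32, run the computation, then
    # restore the original dtype, mirroring the preamble above.
    original_dtype = x.dtype
    x = x.astype(np.result_type(x.dtype, np.float32))
    return fn(x).astype(original_dtype)

y = with_float32_compute(
    np.ones((3,), dtype=np.float16),
    lambda t: t / np.sqrt(np.mean(np.square(t)) + 1e-7),
)
assert y.dtype == np.float16
```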
@@ -2951,16 +2953,12 @@ def _broadcast(v):
             return backend.numpy.reshape(v, broadcast_shape)
         return v

-    if epsilon is None:
-        epsilon = backend.epsilon()
-
     if rms_scaling:
-        # Calculate outputs with only variance and gamma if rms scaling
-        # is enabled
-        # Calculate the variance along self.axis (layer activations).
         variance = backend.numpy.var(x, axis=axis, keepdims=True)
         inv = backend.math.rsqrt(variance + epsilon)
-        outputs = x * inv * backend.cast(_broadcast(gamma), x.dtype)
+        outputs = x * inv
+        if gamma is not None:
+            outputs = outputs * backend.cast(_broadcast(gamma), x.dtype)
     elif backend.config.backend() == "torch" and is_continuous_axis(axis):
         # when using torch backend, use kernel to improve performance
         import torch.nn.functional as F
@@ -2973,16 +2971,14 @@ def _broadcast(v):
         gamma, beta = _broadcast(gamma), _broadcast(beta)
         inv = backend.math.rsqrt(variance + epsilon)
         if gamma is not None:
-            gamma = backend.cast(gamma, x.dtype)
             inv = inv * gamma

         res = -mean * inv
         if beta is not None:
-            beta = backend.cast(beta, x.dtype)
             res = res + beta

         outputs = x * inv + res
-    return backend.cast(outputs, inputs.dtype)
+    return backend.cast(outputs, original_dtype)


 class Polar(Operation):