File tree 2 files changed +23
-0
lines changed
2 files changed +23
-0
lines changed Original file line number Diff line number Diff line change 14
14
from keras .src .ops import operation_utils
15
15
from keras .src .ops .operation import Operation
16
16
from keras .src .ops .operation_utils import reduce_shape
17
+ from keras .src .utils .python_utils import is_continuous_axis
17
18
18
19
19
20
class Relu (Operation ):
@@ -2725,6 +2726,10 @@ def compute_output_spec(self, x):
2725
2726
return KerasTensor (shape = x .shape )
2726
2727
2727
2728
def call (self , x ):
2729
+ if backend .backend () == "torch" and is_continuous_axis (self .axis ):
2730
+ import torch .nn .functional as F
2731
+
2732
+ return F .rms_norm (x , self .axis , self .scale , self .epsilon )
2728
2733
return _rms_normalization (
2729
2734
x , scale = self .scale , axis = self .axis , epsilon = self .epsilon
2730
2735
)
Original file line number Diff line number Diff line change 5
5
import types as python_types
6
6
7
7
8
+ def is_continuous_axis (axis ):
9
+ # Used to determine whether the dimensions in an axis are continuous
10
+ if len (axis ) == 1 :
11
+ return True
12
+ positive_order_flag = True
13
+ for i in range (len (axis ) - 1 ):
14
+ if axis [i + 1 ] - axis [i ] != 1 :
15
+ positive_order_flag = False
16
+ break
17
+
18
+ negative_order_flag = True
19
+ for i in range (len (axis ) - 1 ):
20
+ if axis [i + 1 ] - axis [i ] != 1 :
21
+ negative_order_flag = False
22
+ break
23
+ return positive_order_flag or negative_order_flag
24
+
25
+
8
26
def default (method ):
9
27
"""Decorates a method to detect overrides in subclasses."""
10
28
method ._is_default = True
You can’t perform that action at this time.
0 commit comments