Skip to content

Commit 01d0d1d

Browse files
committed
improve rmsln preformance when torch backend
1 parent 6b74cb0 commit 01d0d1d

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

keras/src/ops/nn.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from keras.src.ops import operation_utils
1515
from keras.src.ops.operation import Operation
1616
from keras.src.ops.operation_utils import reduce_shape
17+
from keras.src.utils.python_utils import is_continuous_axis
1718

1819

1920
class Relu(Operation):
@@ -2725,6 +2726,10 @@ def compute_output_spec(self, x):
27252726
return KerasTensor(shape=x.shape)
27262727

27272728
def call(self, x):
2729+
if backend.backend() == "torch" and is_continuous_axis(self.axis):
2730+
import torch.nn.functional as F
2731+
2732+
return F.rms_norm(x, self.axis, self.scale, self.epsilon)
27282733
return _rms_normalization(
27292734
x, scale=self.scale, axis=self.axis, epsilon=self.epsilon
27302735
)

keras/src/utils/python_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,24 @@
55
import types as python_types
66

77

8+
def is_continuous_axis(axis):
9+
# Used to determine whether the dimensions in an axis are continuous
10+
if len(axis) == 1:
11+
return True
12+
positive_order_flag = True
13+
for i in range(len(axis) - 1):
14+
if axis[i + 1] - axis[i] != 1:
15+
positive_order_flag = False
16+
break
17+
18+
negative_order_flag = True
19+
for i in range(len(axis) - 1):
20+
if axis[i + 1] - axis[i] != 1:
21+
negative_order_flag = False
22+
break
23+
return positive_order_flag or negative_order_flag
24+
25+
826
def default(method):
927
"""Decorates a method to detect overrides in subclasses."""
1028
method._is_default = True

0 commit comments

Comments
 (0)