Skip to content

Commit 6aef977

Browse files
committed
update
1 parent 01d0d1d commit 6aef977

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

keras/src/ops/nn.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2726,10 +2726,6 @@ def compute_output_spec(self, x):
27262726
return KerasTensor(shape=x.shape)
27272727

27282728
def call(self, x):
2729-
if backend.backend() == "torch" and is_continuous_axis(self.axis):
2730-
import torch.nn.functional as F
2731-
2732-
return F.rms_norm(x, self.axis, self.scale, self.epsilon)
27332729
return _rms_normalization(
27342730
x, scale=self.scale, axis=self.axis, epsilon=self.epsilon
27352731
)
@@ -2777,6 +2773,14 @@ def rms_normalization(x, scale=1, axis=-1, epsilon=None):
27772773

27782774

27792775
def _rms_normalization(x, scale=1, axis=-1, epsilon=None):
2776+
if backend.backend() == "torch" and is_continuous_axis(axis):
2777+
import torch.nn.functional as F
2778+
2779+
if isinstance(axis, (tuple, list)):
2780+
normalized_shape = tuple([x.shape[dim] for dim in axis])
2781+
else:
2782+
normalized_shape = x.shape[axis]
2783+
return F.rms_norm(x, normalized_shape, scale, epsilon)
27802784
x = backend.convert_to_tensor(x)
27812785
if len(x.shape) == 0:
27822786
x = backend.numpy.expand_dims(x, axis=0)

keras/src/utils/python_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
def is_continuous_axis(axis):
99
# Used to determine whether the dimensions in an axis are continuous
10-
if len(axis) == 1:
10+
if isinstance(axis, int) or len(axis) == 1:
1111
return True
1212
positive_order_flag = True
1313
for i in range(len(axis) - 1):

0 commit comments

Comments
 (0)