Commit 18f927b

improve LayerNormalization performance with the torch backend
1 parent 6810267 commit 18f927b

File tree

1 file changed: +28 −0 lines


keras/src/layers/normalization/layer_normalization.py

Lines changed: 28 additions & 0 deletions
@@ -7,6 +7,24 @@
 from keras.src.layers.layer import Layer


+def is_continue_axis(axis):
+    # Used to determine whether the dimensions in an axis are contiguous
+    if len(axis) == 1:
+        return True
+    positive_order_flag = True
+    for i in range(len(axis) - 1):
+        if axis[i + 1] - axis[i] != 1:
+            positive_order_flag = False
+            break
+
+    negative_order_flag = True
+    for i in range(len(axis) - 1):
+        if axis[i + 1] - axis[i] != -1:
+            negative_order_flag = False
+            break
+    return positive_order_flag or negative_order_flag
+
+
 @keras_export("keras.layers.LayerNormalization")
 class LayerNormalization(Layer):
     """Layer normalization layer (Ba et al., 2016).
@@ -214,6 +232,16 @@ def _broadcast(v):
             outputs = (
                 inputs * inv * ops.cast(_broadcast(self.gamma), inputs.dtype)
             )
+        elif backend.config.backend() == "torch" and is_continue_axis(
+            self.axis
+        ):
+            # On the torch backend, use the fused kernel to improve performance
+            import torch.nn.functional as F
+
+            normalized_shape = tuple([input_shape[dim] for dim in self.axis])
+            outputs = F.layer_norm(
+                inputs, normalized_shape, self.gamma, self.beta, self.epsilon
+            )
         else:
             # Calculate the mean & variance along self.axis (layer activations).
             mean, variance = ops.moments(inputs, axes=self.axis, keepdims=True)
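
For intuition, the fast path matches the moments-based fallback when the
normalized axes are contiguous and trailing; note that F.layer_norm always
normalizes over the trailing dimensions of its input. A hedged equivalence
sketch in plain torch (illustrative names, assumes a local torch install,
not part of the commit):

    # Hypothetical equivalence check between the fused kernel and the
    # explicit mean/variance computation (illustrative only).
    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 4, 8)
    axis = (1, 2)  # contiguous trailing axes
    normalized_shape = tuple(x.shape[dim] for dim in axis)  # (4, 8)

    # Fast path: fused kernel, affine parameters omitted for simplicity.
    fast = F.layer_norm(x, normalized_shape, eps=1e-5)

    # Fallback path: explicit mean/variance over the same axes.
    mean = x.mean(dim=axis, keepdim=True)
    var = x.var(dim=axis, unbiased=False, keepdim=True)
    manual = (x - mean) / torch.sqrt(var + 1e-5)

    assert torch.allclose(fast, manual, atol=1e-5)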
