@@ -7,6 +7,24 @@
 from keras.src.layers.layer import Layer


+def is_continue_axis(axis):
+    # Whether the dims in `axis` are contiguous (ascending or descending).
+    if isinstance(axis, int) or len(axis) == 1:
+        return True
+    positive_order_flag = True
+    for i in range(len(axis) - 1):
+        if axis[i + 1] - axis[i] != 1:
+            positive_order_flag = False
+            break
+
+    negative_order_flag = True
+    for i in range(len(axis) - 1):
+        if axis[i + 1] - axis[i] != -1:
+            negative_order_flag = False
+            break
+    return positive_order_flag or negative_order_flag
+
+
 @keras_export("keras.layers.LayerNormalization")
 class LayerNormalization(Layer):
     """Layer normalization layer (Ba et al., 2016).
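For reference, a quick illustration of which `axis` values the new helper treats as contiguous. The snippet duplicates `is_continue_axis` as written above, purely so the asserts are self-contained and runnable:

```python
# Standalone copy of the helper added above, so the checks below run on
# their own.
def is_continue_axis(axis):
    if isinstance(axis, int) or len(axis) == 1:
        return True
    positive_order_flag = True
    for i in range(len(axis) - 1):
        if axis[i + 1] - axis[i] != 1:
            positive_order_flag = False
            break
    negative_order_flag = True
    for i in range(len(axis) - 1):
        if axis[i + 1] - axis[i] != -1:
            negative_order_flag = False
            break
    return positive_order_flag or negative_order_flag


# Ascending, consecutive dims: eligible for the fused torch path.
assert is_continue_axis((1, 2, 3))
# Descending, consecutive negative dims are also accepted.
assert is_continue_axis((-1, -2, -3))
# A gap between dims disqualifies the fused path.
assert not is_continue_axis((1, 3))
```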
@@ -214,6 +232,16 @@ def _broadcast(v):
             outputs = (
                 inputs * inv * ops.cast(_broadcast(self.gamma), inputs.dtype)
             )
+        elif backend.config.backend() == "torch" and is_continue_axis(
+            self.axis
+        ):
+            # On the torch backend, use the fused layer_norm kernel for speed.
+            import torch.nn.functional as F
+
+            normalized_shape = tuple(input_shape[dim] for dim in self.axis)
+            outputs = F.layer_norm(
+                inputs, normalized_shape, self.gamma, self.beta, self.epsilon
+            )
         else:
             # Calculate the mean & variance along self.axis (layer activations).
             mean, variance = ops.moments(inputs, axes=self.axis, keepdims=True)
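As a sanity check on the fast path, a small standalone sketch (plain PyTorch, no Keras) comparing `torch.nn.functional.layer_norm` with the manual mean/variance normalization that the `else` branch above performs. Note that `F.layer_norm` normalizes over the trailing dims given by `normalized_shape`; the shapes and tolerances here are arbitrary, for illustration only:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)
normalized_shape = (4, 8)  # contiguous trailing axes (1, 2) of a 3D input
weight = torch.ones(normalized_shape)
bias = torch.zeros(normalized_shape)
eps = 1e-5

# Fused kernel path, as used by the new elif branch.
fast = F.layer_norm(x, normalized_shape, weight, bias, eps)

# Manual path: biased mean/variance over the normalized axes, then
# scale and shift, mirroring the moments-based else branch.
mean = x.mean(dim=(1, 2), keepdim=True)
var = x.var(dim=(1, 2), unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + eps) * weight + bias

# The two paths agree up to floating-point tolerance.
torch.testing.assert_close(fast, manual, rtol=1e-4, atol=1e-5)
```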