@@ -211,7 +211,7 @@ def from_config(cls, config):
211
211
212
212
213
213
214
- class RotaryEmbedding (tf .keras .layers .Layer ):
214
+ def RotaryEmbedding (tf .keras .layers .Layer ):
215
215
def __init__ (self , dim , max_seq_len = 1024 , temperature = 10000.0 , ** kwargs ):
216
216
super ().__init__ (** kwargs )
217
217
self .dim = dim
@@ -225,15 +225,18 @@ def build(self, input_shape):
225
225
sinusoid = tf .einsum ("i,j->ij" , position , inv_freq )
226
226
sin = tf .sin (sinusoid )
227
227
cos = tf .cos (sinusoid )
228
- self .sin_cache = tf . concat ([ sin , sin ], axis = - 1 )
229
- self .cos_cache = tf . concat ([ cos , cos ], axis = - 1 )
228
+ self .sin_cache = sin
229
+ self .cos_cache = cos
230
230
231
231
def call (self , x , seq_len = None ):
232
232
batch_size = tf .shape (x )[0 ]
233
233
seq_len = tf .shape (x )[1 ] if seq_len is None else seq_len
234
234
sin = self .sin_cache [:seq_len ]
235
235
cos = self .cos_cache [:seq_len ]
236
- return tf .cast (sin , x .dtype ), tf .cast (cos , x .dtype )
236
+ sin = tf .cast (tf .repeat (sin [..., tf .newaxis ], self .dim // 2 , axis = - 1 ), x .dtype )
237
+ cos = tf .cast (tf .repeat (cos [..., tf .newaxis ], self .dim // 2 , axis = - 1 ), x .dtype )
238
+ return sin , cos
239
+
237
240
238
241
def split_alternate (x ):
239
242
shape = tf .shape (x )
0 commit comments