Skip to content

Commit 5398ce7

Browse files
Update phishing_email_detection_gpt2.py
More dimensionality debugging...
1 parent 4368259 commit 5398ce7

File tree

1 file changed

+7 -4 lines changed

1 file changed

+7 -4 lines changed

phishing_email_detection_gpt2.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -211,7 +211,7 @@ def from_config(cls, config):
211 211
212 212
213 213
214     - class RotaryEmbedding(tf.keras.layers.Layer):
    214 + def RotaryEmbedding(tf.keras.layers.Layer):
215 215       def __init__(self, dim, max_seq_len=1024, temperature=10000.0, **kwargs):
216 216           super().__init__(**kwargs)
217 217           self.dim = dim
@@ -225,15 +225,18 @@ def build(self, input_shape):
225 225           sinusoid = tf.einsum("i,j->ij", position, inv_freq)
226 226           sin = tf.sin(sinusoid)
227 227           cos = tf.cos(sinusoid)
228     -         self.sin_cache = tf.concat([sin, sin], axis=-1)
229     -         self.cos_cache = tf.concat([cos, cos], axis=-1)
    228 +         self.sin_cache = sin
    229 +         self.cos_cache = cos
230 230
231 231       def call(self, x, seq_len=None):
232 232           batch_size = tf.shape(x)[0]
233 233           seq_len = tf.shape(x)[1] if seq_len is None else seq_len
234 234           sin = self.sin_cache[:seq_len]
235 235           cos = self.cos_cache[:seq_len]
236     -         return tf.cast(sin, x.dtype), tf.cast(cos, x.dtype)
    236 +         sin = tf.cast(tf.repeat(sin[..., tf.newaxis], self.dim // 2, axis=-1), x.dtype)
    237 +         cos = tf.cast(tf.repeat(cos[..., tf.newaxis], self.dim // 2, axis=-1), x.dtype)
    238 +         return sin, cos
    239 +
237 240
238 241   def split_alternate(x):
239 242       shape = tf.shape(x)

0 commit comments

Comments (0)