Skip to content

Commit b8204d6

Browse files
Update phishing_email_detection_gpt2.py
...
1 parent 116b888 commit b8204d6

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

phishing_email_detection_gpt2.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
## GPT2 + Cerebros for Phishing email detection
1010
1111
Initialization
12-
"""
12+
def rotate_half(x):
13+
x = split_alternate(x)
14+
rotated_x = tf.concat([-x[..., x.shape[-1]//2:], x[..., :x.shape[-1]//2]], axis=-1)
15+
return tf.reshape(rotated_x, tf.shape(x))"""
1316

1417
import tensorflow as tf
1518
import tensorflow_text
@@ -255,11 +258,8 @@ def split_alternate(x):
255258

256259
def rotate_half(x):
257260
x = split_alternate(x)
258-
d = tf.shape(x)[-1]
259-
x1 = x[..., :d//2]
260-
x2 = x[..., d//2:]
261-
rotated_x = tf.concat([-x2, x1], axis=-1)
262-
return tf.reshape(rotated_x, tf.shape(x)[:-2] + [-1])
261+
rotated_x = tf.concat([-x[..., x.shape[-1]//2:], x[..., :x.shape[-1]//2]], axis=-1)
262+
return tf.reshape(rotated_x, tf.shape(x))
263263

264264

265265
def apply_rotary_pos_emb(x, sin, cos):
@@ -268,6 +268,7 @@ def apply_rotary_pos_emb(x, sin, cos):
268268
x_rotated = x * cos + rotate_half(x) * sin
269269
return x_rotated
270270

271+
271272
class InterleavedRoPE(tf.keras.layers.Layer):
272273
def __init__(self, dim, max_seq_len=1024, **kwargs):
273274
super().__init__(**kwargs)

0 commit comments

Comments
 (0)