Skip to content

Commit 1566908

Browse files
Update phishing_email_detection_gpt2.py
1 parent 0340045 commit 1566908

File tree

1 file changed

+4
-44
lines changed

1 file changed

+4
-44
lines changed

phishing_email_detection_gpt2.py

Lines changed: 4 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
import zero_7_exp_decay, zero_95_exp_decay, simple_sigmoid
3131
from ast import literal_eval
3232

33+
from custom.custom import GPT2Layer
34+
3335
#
3436
# Load the email data
3537
#
@@ -73,50 +75,8 @@
7375
INPUT_SHAPES = [()]
7476
OUTPUT_SHAPES = [1]
7577

76-
"""### A custom GPT2 encoder layer for text embedding"""
77-
78-
class GPT2Layer(tf.keras.layers.Layer):
79-
80-
def __init__(self, max_seq_length, **kwargs):
81-
#
82-
super(GPT2Layer, self).__init__(**kwargs)
83-
#
84-
# Load the GPT2 tokenizer, preprocessor and model
85-
self.tokenizer = GPT2Tokenizer.from_preset("gpt2_base_en")
86-
self.preprocessor = GPT2Preprocessor(self.tokenizer,
87-
sequence_length=max_seq_length)
88-
self.encoder = GPT2Backbone.from_preset("gpt2_base_en")
89-
#
90-
# Set whether the GPT2 model's layers are trainable
91-
#self.encoder.trainable = False
92-
for layer in self.encoder.layers:
93-
layer.trainable = False
94-
#
95-
self.encoder.layers[-2].trainable = True
96-
#
97-
# Set the maximum sequence length for tokenization
98-
self.max_seq_length = max_seq_length
99-
100-
def call(self, inputs):
101-
#
102-
# Output the GPT2 embedding
103-
prep = self.preprocessor([inputs])
104-
embedding = self.encoder(prep)
105-
avg_pool = tf.reduce_mean(embedding, axis=1)
106-
#
107-
return avg_pool
108-
109-
def get_config(self):
110-
#
111-
config = super(GPT2Layer, self).get_config()
112-
config.update({'max_seq_length': self.max_seq_length})
113-
#
114-
return config
115-
116-
@classmethod
117-
def from_config(cls, config):
118-
#
119-
return cls(max_seq_length=config['max_seq_length'])
78+
79+
12080

12181
# GPT2 configurables
12282
max_seq_length = 96

0 commit comments

Comments (0)