Skip to content

Commit 014b3c3

Browse files
Update phishing_email_detection_gpt2.py
Fix errors from trying to work too fast ...
1 parent d8db0f1 commit 014b3c3

File tree

1 file changed

+4
-26
lines changed

1 file changed

+4
-26
lines changed

phishing_email_detection_gpt2.py

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -186,44 +186,22 @@ def from_config(cls, config):
186186
class TokenizerLayer(tf.keras.layers.Layer):
    """Keras layer that tokenizes raw text with the KerasNLP GPT-2 tokenizer.

    Wraps ``GPT2Tokenizer`` + ``GPT2Preprocessor`` (preset
    ``"gpt2_extra_large_en"``) so that string inputs are mapped to
    fixed-length token-id tensors inside a model graph.  The layer is
    serializable through ``get_config`` / ``from_config``.
    """

    def __init__(self, max_seq_length, **kwargs):
        """Build the tokenizer/preprocessor pair.

        Args:
            max_seq_length: Sequence length the preprocessor pads or
                truncates every input to.
            **kwargs: Standard ``tf.keras.layers.Layer`` keyword arguments.
        """
        # NOTE(review): an earlier revision called super(GPT2Layer, self)
        # here, which fails at construction time — the enclosing class is
        # TokenizerLayer; this commit fixed that.
        super(TokenizerLayer, self).__init__(**kwargs)
        self.tokenizer = GPT2Tokenizer.from_preset("gpt2_extra_large_en")
        self.preprocessor = GPT2Preprocessor(self.tokenizer,
                                             sequence_length=max_seq_length)
        # Retained only so get_config() can serialize the layer.
        self.max_seq_length = max_seq_length

    def call(self, inputs):
        """Tokenize ``inputs`` and return the token-id tensor.

        The preprocessor returns a dict; only ``'token_ids'`` is propagated
        downstream (the padding mask is intentionally discarded).
        """
        prep = self.preprocessor([inputs])
        return prep['token_ids']

    def get_config(self):
        """Return the config dict, including ``max_seq_length``, so the
        layer round-trips through Keras model serialization."""
        config = super(TokenizerLayer, self).get_config()
        config.update({'max_seq_length': self.max_seq_length})
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer from a config produced by ``get_config``."""
        return cls(max_seq_length=config['max_seq_length'])
228206

229207
# GPT2 configurables

0 commit comments

Comments
 (0)