Skip to content

Commit 0b67f88

Browse files
Update phishing_email_detection_gpt2.py
Re-corrected the metrics BinaryAccuracy to correct AI introduced error.
1 parent 014b3c3 commit 0b67f88

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

phishing_email_detection_gpt2.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070

7171
# Training data for baseline model
7272
baseline_train_x = tf.constant(X_train)
73-
baseline_train_y = tf.constant(y_train)
73+
baseline_train_y = tf.constant(y_train, dtype=tf.int8)
7474

7575
# Packaged for Cerebros (multimodal, takes inputs as a list)
7676
training_x = [baseline_train_x]
@@ -142,7 +142,10 @@ def from_config(cls, config):
142142
gpt_baseline_model.compile(
143143
optimizer=Adam(learning_rate=1e-4), # Small LR since we're fine-tuning GPT
144144
loss='binary_crossentropy',
145-
metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
145+
# metrics=['accuracy', tf.keras.metrics.AUC(name='auc')]
146+
metrics=[tf.keras.metrics.BinaryAccuracy(),
147+
tf.keras.metrics.Precision(),
148+
tf.keras.metrics.Recall()]
146149
)
147150
148151
gpt_t0 = time.time()
@@ -303,9 +306,9 @@ def from_config(cls, config):
303306
num_lateral_connection_tries_per_unit=num_lateral_connection_tries_per_unit,
304307
learning_rate=learning_rate,
305308
loss=tf.keras.losses.CategoricalHinge(),
306-
metrics=[tf.keras.metrics.Accuracy(),
307-
tf.keras.metrics.Precision(),
308-
tf.keras.metrics.Recall()],
309+
metrics=[tf.keras.metrics.BinaryAccuracy(),
310+
tf.keras.metrics.Precision(),
311+
tf.keras.metrics.Recall()],
309312
epochs=epochs,
310313
project_name=f"{PROJECT_NAME}_meta_{meta_trial_number}",
311314
model_graphs='model_graphs',

0 commit comments

Comments
 (0)