Validation data during training for early stopping #2934
-
Hi, I'm using InceptionTimeClassifier() from aeon and I'm tuning hyperparameters with Optuna. I'd like to implement early stopping based on a validation set. However, I noticed that .fit() only accepts X_train and y_train; there's no validation_data parameter like in Keras. Is there a recommended way to use early stopping based on validation loss with aeon's classifiers? Thanks in advance!
Replies: 1 comment
-
Hello, thank you for your comment. For now aeon does not support fit taking validation data for deep learners, as we want to keep consistency with the other classifiers, where fit just takes X and y. A possible future addition could be a parameter on BaseDeepClassifier that performs a validation split for anyone using a callback, though that might not be a very pretty solution.

One thing you can do, given that any callback can be passed through the callbacks parameter, is to implement a custom Keras callback that does the validation split itself and then whatever kind of checking you need, such as early stopping. Note that you would have to give the callback the training data X and y yourself. A possible code snippet could be this:
import numpy as np
import tensorflow as tf


class CustomEarlyStopping(tf.keras.callbacks.Callback):
    """Early stopping on a validation split held out from the data given to the callback.

    x and y must be in the format the underlying Keras model expects
    (e.g. one-hot encoded labels), which may differ from what you pass to fit.
    """

    def __init__(self, x, y, val_split=0.2, patience=3, monitor='val_loss'):
        super().__init__()
        # Manual split: hold out the last val_split fraction for validation.
        split_at = int(len(x) * (1 - val_split))
        self.x_val = x[split_at:]
        self.y_val = y[split_at:]
        self.patience = patience
        self.monitor = monitor
        self.wait = 0
        # Lower is better for a loss, higher is better for any other metric.
        self.best = np.inf if monitor == 'val_loss' else -np.inf
        self.stopped_epoch = 0

    def on_epoch_end(self, epoch, logs=None):
        # Evaluate on the manual validation set. evaluate reports metric
        # names without the "val_" prefix, so strip it before the lookup.
        metrics = self.model.evaluate(
            self.x_val, self.y_val, verbose=0, return_dict=True
        )
        key = self.monitor[4:] if self.monitor.startswith('val_') else self.monitor
        current = metrics[key]
        improved = (
            current < self.best if self.monitor == 'val_loss' else current > self.best
        )
        if improved:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print(f"\nEarly stopping at epoch {epoch + 1}")

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print(f"Training stopped early at epoch {self.stopped_epoch + 1}")