From d43921310d6fed31212d1db605c390c319cdf0d7 Mon Sep 17 00:00:00 2001 From: Divya Tiwari Date: Tue, 1 Oct 2024 12:25:14 +0530 Subject: [PATCH 01/13] LSTM-AD --- aeon/anomaly_detection/__init__.py | 2 + aeon/anomaly_detection/_lstm_ad.py | 332 +++++++++++++++++++++++++++++ 2 files changed, 334 insertions(+) create mode 100644 aeon/anomaly_detection/_lstm_ad.py diff --git a/aeon/anomaly_detection/__init__.py b/aeon/anomaly_detection/__init__.py index 5acb9e3921..3e278c5bea 100644 --- a/aeon/anomaly_detection/__init__.py +++ b/aeon/anomaly_detection/__init__.py @@ -3,6 +3,7 @@ __all__ = [ "DWT_MLEAD", "KMeansAD", + "LSTM_AD", "MERLIN", "STRAY", "PyODAdapter", @@ -11,6 +12,7 @@ from aeon.anomaly_detection._dwt_mlead import DWT_MLEAD from aeon.anomaly_detection._kmeans import KMeansAD +from aeon.anomaly_detection._lstm_ad import LSTM_AD from aeon.anomaly_detection._merlin import MERLIN from aeon.anomaly_detection._pyodadapter import PyODAdapter from aeon.anomaly_detection._stomp import STOMP diff --git a/aeon/anomaly_detection/_lstm_ad.py b/aeon/anomaly_detection/_lstm_ad.py new file mode 100644 index 0000000000..720fd0317e --- /dev/null +++ b/aeon/anomaly_detection/_lstm_ad.py @@ -0,0 +1,332 @@ +"""LSTM-AD Anomaly Detector.""" + +__all__ = ["LSTM_AD"] + +import numpy as np +from scipy.stats import multivariate_normal +from sklearn.covariance import EmpiricalCovariance +from sklearn.metrics import fbeta_score +from sklearn.model_selection import train_test_split +from tensorflow.keras import layers, models +from tensorflow.keras.callbacks import EarlyStopping + +from aeon.anomaly_detection.base import BaseAnomalyDetector + +# from aeon.utils.windowing import sliding_windows + + +class LSTM_AD(BaseAnomalyDetector): + """LSTM-AD anomaly detector. + + The LSTM-AD uses stacked LSTM network for anomaly detection in time series. A + network is trained over non-anomalous data and used as a predictor over a + number of time steps. The resulting prediction errors are modeled as a + multivariate Gaussian distribution, which is used to assess the likelihood of + anomalous behavior. + + ``LSTMAD`` supports univariate and multivariate time series. It can also be + fitted on a clean reference time series and used to detect anomalies in a different + target time series with the same number of dimensions. + + .. list-table:: Capabilities + :stub-columns: 1 + + * - Input data format + - univariate and multivariate + * - Output data format + - binary classification + * - Learning Type + - supervised + + Parameters + ---------- + n_layers : int, default=2 + The number of LSTM layers to be stacked. + + n_nodes : int, default=64 + The number of LSTM units in each layer. + + window_size : int, default=20 + The size of the sliding window used to split the time series into windows. The + bigger the window size, the bigger the anomaly context is. If it is too big, + however, the detector marks points anomalous that are not. If it is too small, + the detector might not detect larger anomalies or contextual anomalies at all. + If ``window_size`` is smaller than the anomaly, the detector might detect only + the transitions between normal data and the anomalous subsequence. + + prediction_horizon : int, default=1 + The prediction horizon is the number of time steps in the future predicted by + the LSTM. default value is ``1``, which means the the LSTM will take + ``window_size`` time steps as input and predict ``1`` time step in the future. + + batch_size : int, default=32 + The number of time steps per gradient update. + + n_epochs: int, default = 1500 + The number of epochs to train the model. + + patience: int, default = 5 + The number of epochs to watch before early stopping. + + verbose : boolean, default = False + whether to output extra information + + Notes + ----- + This implementation is inspired by [1]_. + + References + ---------- + .. [1] Malhotra Pankaj, Lovekesh Vig, Gautam Shroff, and Puneet Agarwal. + Long Short Term Memory Networks for Anomaly Detection in Time Series. In Proceedings + of the European Symposium on Artificial Neural Networks, Computational Intelligence + and Machine Learning (ESANN), Vol. 23, 2015. + https://www.esann.org/sites/default/files/proceedings/legacy/es2015-56.pdf + + Examples + -------- + >>> import numpy as np + >>> from aeon.anomaly_detection import LSTMAD + >>> X = np.array([1, 2, 3, 4, 1, 2, 3, 3, 2, 8, 9, 8, 1, 2, 3, 4], dtype=np.float_) + >>> y = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0]) + >>> detector = LSTM_AD(window_size=4, n_epochs=10, batch_size = 4) + >>> detector.fit_predict(X) + array([1.97827709, 2.45374147, 2.51929879, 2.36979677, 2.34826601, + 2.05075554, 2.57611912, 2.87642119, 3.18400743, 3.65060425, + 3.36402514, 3.94053744, 3.65448197, 3.6707922 , 3.70341266, + 1.97827709]) + + """ + + _tags = { + "capability:univariate": True, + "capability:multivariate": True, + "capability:missing_values": False, + "fit_is_empty": False, + "python_dependencies": "tensorflow", + } + + def __init__( + self, + n_layers: int = 2, + n_nodes: int = 64, + window_size: int = 20, + prediction_horizon: int = 1, + batch_size: int = 32, + n_epochs: int = 1500, + patience: int = 5, + verbose: bool = False, + ): + self.n_layers = n_layers + self.n_nodes = n_nodes + self.window_size = window_size + self.prediction_horizon = prediction_horizon + self.batch_size = batch_size + self.n_epochs = n_epochs + self.patience = patience + self.verbose = verbose + + super().__init__(axis=0) + + def _fit(self, X: np.array, y: np.array): + """Fit the model on the data. + + Parameters + ---------- + X: np.ndarray of shape (n_timepoints, n_channels) + The training time series, maybe with anomalies. + y: np.ndarray of shape (n_timepoints,) or (n_timepoints, 1) + Anomaly annotations for the training time series with values 0 or 1. + """ + self._check_params(X) + + if X.ndim == 1: + self.n_channels = 1 + else: + self.n_channels = X.shape[1] + + # Create normal time series if not present + if len(np.unique(y)) == 2: + X_normal = X[y == 0] + y_normal = y[y == 0] + X_anomaly = X[y == 1] + else: + raise ValueError( + "The training time series must have anomaly annotations with values" + "0 for normal and 1 for anomaly." + ) + + # Divide the normal time series into train set and two validation sets for lstm + X_train, X_val, y_train, y_val = train_test_split( + X_normal, y_normal, test_size=0.2, shuffle=False + ) + X_val1, X_val2, y_val1, y_val2 = train_test_split( + X_val, y_val, test_size=0.5, shuffle=False + ) + X_train_n, y_train_n = self._create_sequences( + X_train, self.window_size, self.prediction_horizon + ) + X_val_1, y_val_1 = self._create_sequences( + X_val1, self.window_size, self.prediction_horizon + ) + X_val_2, y_val_2 = self._create_sequences( + X_val2, self.window_size, self.prediction_horizon + ) + + # Create a stacked LSTM model and fit on the training data + self.model = self.build_model( + self.n_layers, + self.n_nodes, + self.n_channels, + self.prediction_horizon, + self.window_size, + ) + self.model_summary_ = self.model.summary() + early_stopping = EarlyStopping( + monitor="val_loss", patience=self.patience, restore_best_weights=True + ) + self.history = self.model.fit( + X_train_n, + y_train_n, + validation_data=(X_val_1, y_val_1), + epochs=self.n_epochs, + batch_size=self.batch_size, + callbacks=[early_stopping], + verbose=self.verbose, + ) + + # Prediction errors on validation set 1 to calculate error vector + predicted_vN1 = self.model.predict(X_val_1) + errors_vN1 = y_val_1 - predicted_vN1 + + # Fit the error vectors to a Gaussian distribution + cov_estimator = EmpiricalCovariance() + cov_estimator.fit(errors_vN1) + + # Mean and covariance matrix of the error distribution + mu = cov_estimator.location_ + cov_matrix = cov_estimator.covariance_ + + # Create a Gaussian Normal Distribution + self.distribution = multivariate_normal(mean=mu, cov=cov_matrix) + + X_anomalies, y_anomalies = self._create_sequences( + X_anomaly, self.window_size, self.prediction_horizon + ) + + predicted_vN2 = self.model.predict(X_val_2) + predicted_vA = self.model.predict(X_anomalies) + + errors_vN2 = y_val_2 - predicted_vN2 + errors_vA = y_anomalies - predicted_vA + + # Estimate the likelihood of the errors: + p_vN2 = self.distribution.pdf(errors_vN2) + p_vA = self.distribution.pdf(errors_vA) + + # Combine likelihoods and labels + likelihoods = np.concatenate([p_vN2, p_vA]) + true_labels = np.concatenate( + [np.zeros_like(p_vN2), np.ones_like(p_vA)] + ) # 0 for normal, 1 for anomalous + + # Experiment with different thresholds and calculate Fβ-score + self.best_tau = None + self.best_fbeta = -1 + + # Loop over different thresholds + for tau in np.linspace(min(likelihoods), max(likelihoods), 100): + # Classify as anomalous if likelihood < tau + predictions = (likelihoods < tau).astype(int) + + # Calculate Fβ-score (arbitrarily use beta=1.0 for F1-score) + fbeta = fbeta_score(true_labels, predictions, beta=1.0) + + # Track the best threshold and Fβ-score + if fbeta > self.best_fbeta: + self.best_tau = tau + self.best_fbeta = fbeta + + def _predict(self, X): + X_, y_ = self._create_sequences(X, self.window_size, self.prediction_horizon) + predict_test = self.model.predict(X_) + errors = y_ - predict_test + likelihoods = multivariate_normal.pdf(errors) + anomalies = (likelihoods < self.best_tau).astype(int) + return np.concatenate([np.zeros(self.window_size + 1), anomalies]) + + def build_model( + self, n_layers, n_nodes, n_channels, window_size, prediction_horizon + ): + """Construct a compiled, un-trained, keras model that is ready for training. + + In aeon, time series are stored in numpy arrays of shape (d,m), where d + is the number of dimensions, m is the series length. Keras/tensorflow assume + data is in shape (m,d). This method also assumes (m,d). Transpose should + happen in fit. + + Parameters + ---------- + n_layers : int + The number of layers in the LSTM model. + n_nodes : int + The number of LSTM units in each layer. + n_channels : int + It is basically d, the number of dimesions. + window_size : int + Tie number of time steps fed to the model. + prediction_horizon : int + The number of time steps to be predicted by the model. + + Returns + ------- + output : a compiled Keras Model + """ + model = models.Sequential() + model.add(layers.Input(shape=(window_size, n_channels))) + model.add(layers.LSTM(n_nodes, return_sequences=True)) # First LSTM layer + for _ in range(n_layers - 1): + model.add(layers.LSTM(n_nodes)) # Stacked LSTM layers + model.add(layers.Dense(n_channels * prediction_horizon)) + model.compile(optimizer="adam", loss="mse") + return model + + def _check_params(self, X: np.ndarray) -> None: + if self.window_size < 1 or self.window_size > X.shape[0]: + raise ValueError( + "The window size must be at least 1 and at most the length of the " + "time series." + ) + if self.batch_size < 1 or self.batch_size > X.shape[0]: + raise ValueError( + "The batch size must be at least 1 and at most the length of the " + "time series." + ) + + # Create input and output sequences for lstm using sliding window + def _create_sequences(self, data, window_size, prediction_horizon): + """Create input and output sequences using sliding window to train LSTM. + + Parameters + ---------- + data: np.dnarray + The time series of shape (n_timepoints, n_channels). + window_size: int + The length of the sliding window. + prediction_horizon: int + The number of time steps in future that would be predicted by the model. + + Returns + ------- + X: np.ndarray + The array of input sequences of shape + (n_timepoints - window_size - 1, n_channels). + y: np.ndarray + The array of output sequences of shape + (n_timepoints - window_size - 1, window_size). + """ + X, y = [], [] + for i in range(len(data) - window_size - prediction_horizon + 1): + X.append(data[i : (i + window_size)]) + y.append(data[(i + window_size) : (i + window_size + prediction_horizon)]) + return np.array(X), np.array(y) From 4a48f234be3b2d193a80394ba294d70569c2e8ef Mon Sep 17 00:00:00 2001 From: Divya Tiwari Date: Wed, 2 Oct 2024 22:07:47 +0530 Subject: [PATCH 02/13] Fixing minor issues --- aeon/anomaly_detection/_lstm_ad.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/aeon/anomaly_detection/_lstm_ad.py b/aeon/anomaly_detection/_lstm_ad.py index 720fd0317e..1bf732d367 100644 --- a/aeon/anomaly_detection/_lstm_ad.py +++ b/aeon/anomaly_detection/_lstm_ad.py @@ -86,7 +86,7 @@ class LSTM_AD(BaseAnomalyDetector): Examples -------- >>> import numpy as np - >>> from aeon.anomaly_detection import LSTMAD + >>> from aeon.anomaly_detection import LSTM_AD >>> X = np.array([1, 2, 3, 4, 1, 2, 3, 3, 2, 8, 9, 8, 1, 2, 3, 4], dtype=np.float_) >>> y = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0]) >>> detector = LSTM_AD(window_size=4, n_epochs=10, batch_size = 4) @@ -140,11 +140,6 @@ def _fit(self, X: np.array, y: np.array): """ self._check_params(X) - if X.ndim == 1: - self.n_channels = 1 - else: - self.n_channels = X.shape[1] - # Create normal time series if not present if len(np.unique(y)) == 2: X_normal = X[y == 0] @@ -166,12 +161,15 @@ def _fit(self, X: np.array, y: np.array): X_train_n, y_train_n = self._create_sequences( X_train, self.window_size, self.prediction_horizon ) + y_train_n = y_train_n.reshape(-1, self.prediction_horizon * self.n_channels) X_val_1, y_val_1 = self._create_sequences( X_val1, self.window_size, self.prediction_horizon ) + y_val_1 = y_val_1.reshape(-1, self.prediction_horizon * self.n_channels) X_val_2, y_val_2 = self._create_sequences( X_val2, self.window_size, self.prediction_horizon ) + y_val_2 = y_val_2.reshape(-1, self.prediction_horizon * self.n_channels) # Create a stacked LSTM model and fit on the training data self.model = self.build_model( @@ -213,6 +211,7 @@ def _fit(self, X: np.array, y: np.array): X_anomalies, y_anomalies = self._create_sequences( X_anomaly, self.window_size, self.prediction_horizon ) + y_anomalies = y_anomalies.reshape(-1, self.prediction_horizon * self.n_channels) predicted_vN2 = self.model.predict(X_val_2) predicted_vA = self.model.predict(X_anomalies) @@ -249,6 +248,7 @@ def _fit(self, X: np.array, y: np.array): def _predict(self, X): X_, y_ = self._create_sequences(X, self.window_size, self.prediction_horizon) + y_ = y_.reshape(-1, self.prediction_horizon * self.n_channels) predict_test = self.model.predict(X_) errors = y_ - predict_test likelihoods = multivariate_normal.pdf(errors) @@ -292,6 +292,16 @@ def build_model( return model def _check_params(self, X: np.ndarray) -> None: + if X.ndim == 1: + self.n_channels = 1 + elif X.ndim == 2: + self.n_channels = X.shape[1] + else: + raise ValueError( + "The training time series must be of shape (n_timepoints,) or " + "(n_timepoints, n_channels)." + ) + if self.window_size < 1 or self.window_size > X.shape[0]: raise ValueError( "The window size must be at least 1 and at most the length of the " From 81b7ebf808942464e60dac6cbae46392ca99abf3 Mon Sep 17 00:00:00 2001 From: Divya Tiwari Date: Mon, 7 Oct 2024 17:36:31 +0530 Subject: [PATCH 03/13] Fixing bugs --- aeon/anomaly_detection/_lstm_ad.py | 102 ++++++++++++++--------------- 1 file changed, 51 insertions(+), 51 deletions(-) diff --git a/aeon/anomaly_detection/_lstm_ad.py b/aeon/anomaly_detection/_lstm_ad.py index 1bf732d367..7893db9581 100644 --- a/aeon/anomaly_detection/_lstm_ad.py +++ b/aeon/anomaly_detection/_lstm_ad.py @@ -12,8 +12,6 @@ from aeon.anomaly_detection.base import BaseAnomalyDetector -# from aeon.utils.windowing import sliding_windows - class LSTM_AD(BaseAnomalyDetector): """LSTM-AD anomaly detector. @@ -86,15 +84,12 @@ class LSTM_AD(BaseAnomalyDetector): Examples -------- >>> import numpy as np - >>> from aeon.anomaly_detection import LSTM_AD - >>> X = np.array([1, 2, 3, 4, 1, 2, 3, 3, 2, 8, 9, 8, 1, 2, 3, 4], dtype=np.float_) - >>> y = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0]) + >>> from aeon.datasets import load_anomaly_detection + >>> X, y = load_anomaly_detection( + name=("KDD-TSAD", "001_UCR_Anomaly_DISTORTED1sddb40") + ) >>> detector = LSTM_AD(window_size=4, n_epochs=10, batch_size = 4) - >>> detector.fit_predict(X) - array([1.97827709, 2.45374147, 2.51929879, 2.36979677, 2.34826601, - 2.05075554, 2.57611912, 2.87642119, 3.18400743, 3.65060425, - 3.36402514, 3.94053744, 3.65448197, 3.6707922 , 3.70341266, - 1.97827709]) + >>> detector.fit(X) """ @@ -158,26 +153,31 @@ def _fit(self, X: np.array, y: np.array): X_val1, X_val2, y_val1, y_val2 = train_test_split( X_val, y_val, test_size=0.5, shuffle=False ) - X_train_n, y_train_n = self._create_sequences( + X_train_n, y_train_n = _create_sequences( X_train, self.window_size, self.prediction_horizon ) y_train_n = y_train_n.reshape(-1, self.prediction_horizon * self.n_channels) - X_val_1, y_val_1 = self._create_sequences( + X_val_1, y_val_1 = _create_sequences( X_val1, self.window_size, self.prediction_horizon ) y_val_1 = y_val_1.reshape(-1, self.prediction_horizon * self.n_channels) - X_val_2, y_val_2 = self._create_sequences( + X_val_2, y_val_2 = _create_sequences( X_val2, self.window_size, self.prediction_horizon ) y_val_2 = y_val_2.reshape(-1, self.prediction_horizon * self.n_channels) + X_anomalies, y_anomalies = _create_sequences( + X_anomaly, self.window_size, self.prediction_horizon + ) + y_anomalies = y_anomalies.reshape(-1, self.prediction_horizon * self.n_channels) + # Create a stacked LSTM model and fit on the training data - self.model = self.build_model( + self.model = self._build_model( self.n_layers, self.n_nodes, self.n_channels, - self.prediction_horizon, self.window_size, + self.prediction_horizon, ) self.model_summary_ = self.model.summary() early_stopping = EarlyStopping( @@ -208,11 +208,6 @@ def _fit(self, X: np.array, y: np.array): # Create a Gaussian Normal Distribution self.distribution = multivariate_normal(mean=mu, cov=cov_matrix) - X_anomalies, y_anomalies = self._create_sequences( - X_anomaly, self.window_size, self.prediction_horizon - ) - y_anomalies = y_anomalies.reshape(-1, self.prediction_horizon * self.n_channels) - predicted_vN2 = self.model.predict(X_val_2) predicted_vA = self.model.predict(X_anomalies) @@ -247,15 +242,16 @@ def _fit(self, X: np.array, y: np.array): self.best_fbeta = fbeta def _predict(self, X): - X_, y_ = self._create_sequences(X, self.window_size, self.prediction_horizon) + X_, y_ = _create_sequences(X, self.window_size, self.prediction_horizon) y_ = y_.reshape(-1, self.prediction_horizon * self.n_channels) predict_test = self.model.predict(X_) errors = y_ - predict_test - likelihoods = multivariate_normal.pdf(errors) + likelihoods = self.distribution.pdf(errors) anomalies = (likelihoods < self.best_tau).astype(int) - return np.concatenate([np.zeros(self.window_size + 1), anomalies]) + prediction = np.concatenate([np.zeros(self.window_size + 1), anomalies]) + return np.array(prediction, dtype=int) - def build_model( + def _build_model( self, n_layers, n_nodes, n_channels, window_size, prediction_horizon ): """Construct a compiled, un-trained, keras model that is ready for training. @@ -285,8 +281,11 @@ def build_model( model = models.Sequential() model.add(layers.Input(shape=(window_size, n_channels))) model.add(layers.LSTM(n_nodes, return_sequences=True)) # First LSTM layer - for _ in range(n_layers - 1): - model.add(layers.LSTM(n_nodes)) # Stacked LSTM layers + if n_layers > 2: + for _ in range(1, n_layers - 1): + model.add(layers.LSTM(n_nodes, return_sequences=True)) + # Last LSTM layer, return_sequences=False + model.add(layers.LSTM(n_nodes, return_sequences=False)) model.add(layers.Dense(n_channels * prediction_horizon)) model.compile(optimizer="adam", loss="mse") return model @@ -313,30 +312,31 @@ def _check_params(self, X: np.ndarray) -> None: "time series." ) - # Create input and output sequences for lstm using sliding window - def _create_sequences(self, data, window_size, prediction_horizon): - """Create input and output sequences using sliding window to train LSTM. - Parameters - ---------- - data: np.dnarray - The time series of shape (n_timepoints, n_channels). - window_size: int - The length of the sliding window. - prediction_horizon: int - The number of time steps in future that would be predicted by the model. +# Create input and output sequences for lstm using sliding window +def _create_sequences(data, window_size, prediction_horizon): + """Create input and output sequences using sliding window to train LSTM. - Returns - ------- - X: np.ndarray - The array of input sequences of shape - (n_timepoints - window_size - 1, n_channels). - y: np.ndarray - The array of output sequences of shape - (n_timepoints - window_size - 1, window_size). - """ - X, y = [], [] - for i in range(len(data) - window_size - prediction_horizon + 1): - X.append(data[i : (i + window_size)]) - y.append(data[(i + window_size) : (i + window_size + prediction_horizon)]) - return np.array(X), np.array(y) + Parameters + ---------- + data: np.dnarray + The time series of shape (n_timepoints, n_channels). + window_size: int + The length of the sliding window. + prediction_horizon: int + The number of time steps in future that would be predicted by the model. + + Returns + ------- + X: np.ndarray + The array of input sequences of shape + (n_timepoints - window_size - 1, n_channels). + y: np.ndarray + The array of output sequences of shape + (n_timepoints - window_size - 1, window_size). + """ + X, y = [], [] + for i in range(len(data) - window_size - prediction_horizon + 1): + X.append(data[i : (i + window_size)]) + y.append(data[(i + window_size) : (i + window_size + prediction_horizon)]) + return np.array(X), np.array(y) From 69dbabcce7eb772d891f05764c9327b017eaef4d Mon Sep 17 00:00:00 2001 From: Divya Tiwari Date: Mon, 7 Oct 2024 23:16:57 +0530 Subject: [PATCH 04/13] Unit test for LSTM_AD --- aeon/anomaly_detection/_lstm_ad.py | 9 ++-- aeon/anomaly_detection/tests/test_lstm_ad.py | 48 ++++++++++++++++++++ 2 files changed, 53 insertions(+), 4 deletions(-) create mode 100644 aeon/anomaly_detection/tests/test_lstm_ad.py diff --git a/aeon/anomaly_detection/_lstm_ad.py b/aeon/anomaly_detection/_lstm_ad.py index 7893db9581..6d2b0c6d23 100644 --- a/aeon/anomaly_detection/_lstm_ad.py +++ b/aeon/anomaly_detection/_lstm_ad.py @@ -88,9 +88,9 @@ class LSTM_AD(BaseAnomalyDetector): >>> X, y = load_anomaly_detection( name=("KDD-TSAD", "001_UCR_Anomaly_DISTORTED1sddb40") ) - >>> detector = LSTM_AD(window_size=4, n_epochs=10, batch_size = 4) - >>> detector.fit(X) - + >>> detector = LSTM_AD(n_layers=4, n_nodes=64, window_size=10, prediction_horizon=2) + >>> detector.fit(X, axis=0) + >>> anomaly_pred = detector.predict(X, axis=0) """ _tags = { @@ -248,7 +248,8 @@ def _predict(self, X): errors = y_ - predict_test likelihoods = self.distribution.pdf(errors) anomalies = (likelihoods < self.best_tau).astype(int) - prediction = np.concatenate([np.zeros(self.window_size + 1), anomalies]) + padding = np.zeros(X.shape[0] - len(anomalies)) + prediction = np.concatenate([padding, anomalies]) return np.array(prediction, dtype=int) def _build_model( diff --git a/aeon/anomaly_detection/tests/test_lstm_ad.py b/aeon/anomaly_detection/tests/test_lstm_ad.py new file mode 100644 index 0000000000..094271596d --- /dev/null +++ b/aeon/anomaly_detection/tests/test_lstm_ad.py @@ -0,0 +1,48 @@ +"""Tests for the LSTM_AD class.""" + +import numpy as np + +from aeon.anomaly_detection import LSTM_AD +from aeon.testing.data_generation._legacy import make_series + + +def test_lstmad_univariate(): + """Test LSTM_AD univariate output.""" + series = make_series(n_timepoints=1000, return_numpy=True, random_state=42) + labels = np.zeros(1000) + + # Create anomalies + anomaly_indices = np.random.choice(1000, 20, replace=False) + series[anomaly_indices] += np.random.normal(loc=0, scale=4, size=(20, 1)) + labels[anomaly_indices] = 1 + labels = np.array(labels, dtype=int) + + ad = LSTM_AD( + n_layers=4, n_nodes=16, window_size=10, prediction_horizon=1, n_epochs=1 + ) + pred = ad.fit(series, labels, axis=0) + + assert pred.shape == (1000,) + assert pred.dtype == np.int_ + + +def test_lstmad_multivariate(): + """Test LSTM_AD multivariate output.""" + series = make_series( + n_timepoints=1000, n_columns=3, return_numpy=True, random_state=42 + ) + labels = np.zeros(1000) + + # Create anomalies + anomaly_indices = np.random.choice(1000, 20, replace=False) + series[anomaly_indices] += np.random.normal(loc=0, scale=4, size=(20, 1)) + labels[anomaly_indices] = 1 + labels = np.array(labels, dtype=int) + + ad = LSTM_AD( + n_layers=4, n_nodes=16, window_size=10, prediction_horizon=1, n_epochs=1 + ) + pred = ad.fit(series, labels, axis=0) + + assert pred.shape == (1000,) + assert pred.dtype == np.int_ From 22693a1f0c454b288c295ae9fc25f1b768861fc7 Mon Sep 17 00:00:00 2001 From: Divya Tiwari Date: Sat, 12 Oct 2024 09:54:35 +0530 Subject: [PATCH 05/13] LSTM_AD in init --- aeon/anomaly_detection/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/aeon/anomaly_detection/__init__.py b/aeon/anomaly_detection/__init__.py index e11fd1d283..91d12fb367 100644 --- a/aeon/anomaly_detection/__init__.py +++ b/aeon/anomaly_detection/__init__.py @@ -8,11 +8,13 @@ "PyODAdapter", "STOMP", "LeftSTAMPi", + "LSTM_AD", ] from aeon.anomaly_detection._dwt_mlead import DWT_MLEAD from aeon.anomaly_detection._kmeans import KMeansAD from aeon.anomaly_detection._left_stampi import LeftSTAMPi +from aeon.anomaly_detection._lstm_ad import LSTM_AD from aeon.anomaly_detection._merlin import MERLIN from aeon.anomaly_detection._pyodadapter import PyODAdapter from aeon.anomaly_detection._stomp import STOMP From 4dd304fffda77bf21ed1d2afd83c5b707ba2cf89 Mon Sep 17 00:00:00 2001 From: Divya Tiwari Date: Sat, 12 Oct 2024 12:42:22 +0530 Subject: [PATCH 06/13] lstm_ad tests --- aeon/anomaly_detection/_lstm_ad.py | 24 ++++++++++++-------- aeon/anomaly_detection/tests/test_lstm_ad.py | 18 +++++++-------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/aeon/anomaly_detection/_lstm_ad.py b/aeon/anomaly_detection/_lstm_ad.py index 6d2b0c6d23..f4c4381a0a 100644 --- a/aeon/anomaly_detection/_lstm_ad.py +++ b/aeon/anomaly_detection/_lstm_ad.py @@ -7,8 +7,6 @@ from sklearn.covariance import EmpiricalCovariance from sklearn.metrics import fbeta_score from sklearn.model_selection import train_test_split -from tensorflow.keras import layers, models -from tensorflow.keras.callbacks import EarlyStopping from aeon.anomaly_detection.base import BaseAnomalyDetector @@ -133,6 +131,8 @@ def _fit(self, X: np.array, y: np.array): y: np.ndarray of shape (n_timepoints,) or (n_timepoints, 1) Anomaly annotations for the training time series with values 0 or 1. """ + import tensorflow as tf + self._check_params(X) # Create normal time series if not present @@ -179,8 +179,8 @@ def _fit(self, X: np.array, y: np.array): self.window_size, self.prediction_horizon, ) - self.model_summary_ = self.model.summary() - early_stopping = EarlyStopping( + # self.model_summary_ = self.model.summary() + early_stopping = tf.keras.callbacks.EarlyStopping( monitor="val_loss", patience=self.patience, restore_best_weights=True ) self.history = self.model.fit( @@ -279,15 +279,19 @@ def _build_model( ------- output : a compiled Keras Model """ - model = models.Sequential() - model.add(layers.Input(shape=(window_size, n_channels))) - model.add(layers.LSTM(n_nodes, return_sequences=True)) # First LSTM layer + import tensorflow as tf + + model = tf.keras.models.Sequential() + model.add(tf.keras.layers.Input(shape=(window_size, n_channels))) + model.add( + tf.keras.layers.LSTM(n_nodes, return_sequences=True) + ) # First LSTM layer if n_layers > 2: for _ in range(1, n_layers - 1): - model.add(layers.LSTM(n_nodes, return_sequences=True)) + model.add(tf.keras.layers.LSTM(n_nodes, return_sequences=True)) # Last LSTM layer, return_sequences=False - model.add(layers.LSTM(n_nodes, return_sequences=False)) - model.add(layers.Dense(n_channels * prediction_horizon)) + model.add(tf.keras.layers.LSTM(n_nodes, return_sequences=False)) + model.add(tf.keras.layers.Dense(n_channels * prediction_horizon)) model.compile(optimizer="adam", loss="mse") return model diff --git a/aeon/anomaly_detection/tests/test_lstm_ad.py b/aeon/anomaly_detection/tests/test_lstm_ad.py index 094271596d..0c64ef8f88 100644 --- a/aeon/anomaly_detection/tests/test_lstm_ad.py +++ b/aeon/anomaly_detection/tests/test_lstm_ad.py @@ -9,18 +9,18 @@ def test_lstmad_univariate(): """Test LSTM_AD univariate output.""" series = make_series(n_timepoints=1000, return_numpy=True, random_state=42) - labels = np.zeros(1000) + labels = np.zeros(1000).astype(int) # Create anomalies anomaly_indices = np.random.choice(1000, 20, replace=False) - series[anomaly_indices] += np.random.normal(loc=0, scale=4, size=(20, 1)) + series[anomaly_indices] += np.random.normal(loc=0, scale=4, size=(20,)) labels[anomaly_indices] = 1 - labels = np.array(labels, dtype=int) ad = LSTM_AD( n_layers=4, n_nodes=16, window_size=10, prediction_horizon=1, n_epochs=1 ) - pred = ad.fit(series, labels, axis=0) + ad.fit(series, labels, axis=0) + pred = ad.predict(series, axis=0) assert pred.shape == (1000,) assert pred.dtype == np.int_ @@ -31,18 +31,18 @@ def test_lstmad_multivariate(): series = make_series( n_timepoints=1000, n_columns=3, return_numpy=True, random_state=42 ) - labels = np.zeros(1000) + labels = np.zeros(1000).astype(int) # Create anomalies - anomaly_indices = np.random.choice(1000, 20, replace=False) - series[anomaly_indices] += np.random.normal(loc=0, scale=4, size=(20, 1)) + anomaly_indices = np.random.choice(1000, 50, replace=False) + series[anomaly_indices] += np.random.normal(loc=0, scale=4, size=(50, 3)) labels[anomaly_indices] = 1 - labels = np.array(labels, dtype=int) ad = LSTM_AD( n_layers=4, n_nodes=16, window_size=10, prediction_horizon=1, n_epochs=1 ) - pred = ad.fit(series, labels, axis=0) + ad.fit(series, labels, axis=0) + pred = ad.predict(series, axis=0) assert pred.shape == (1000,) assert pred.dtype == np.int_ From 8a9cc0ce282066a8d1a52d075b0209ee9312bef0 Mon Sep 17 00:00:00 2001 From: Divya Tiwari Date: Sat, 12 Oct 2024 15:21:19 +0530 Subject: [PATCH 07/13] soft dependancy and docstring --- aeon/anomaly_detection/_lstm_ad.py | 5 +++-- aeon/anomaly_detection/tests/test_lstm_ad.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/aeon/anomaly_detection/_lstm_ad.py b/aeon/anomaly_detection/_lstm_ad.py index f4c4381a0a..51f69779f0 100644 --- a/aeon/anomaly_detection/_lstm_ad.py +++ b/aeon/anomaly_detection/_lstm_ad.py @@ -83,9 +83,10 @@ class LSTM_AD(BaseAnomalyDetector): -------- >>> import numpy as np >>> from aeon.datasets import load_anomaly_detection + >>> from aeon.anomaly_detection import LSTM_AD >>> X, y = load_anomaly_detection( - name=("KDD-TSAD", "001_UCR_Anomaly_DISTORTED1sddb40") - ) + ... name=("KDD-TSAD", "001_UCR_Anomaly_DISTORTED1sddb40") + ... ) >>> detector = LSTM_AD(n_layers=4, n_nodes=64, window_size=10, prediction_horizon=2) >>> detector.fit(X, axis=0) >>> anomaly_pred = detector.predict(X, axis=0) diff --git a/aeon/anomaly_detection/tests/test_lstm_ad.py b/aeon/anomaly_detection/tests/test_lstm_ad.py index 0c64ef8f88..8dbe3e9b00 100644 --- a/aeon/anomaly_detection/tests/test_lstm_ad.py +++ b/aeon/anomaly_detection/tests/test_lstm_ad.py @@ -1,11 +1,17 @@ """Tests for the LSTM_AD class.""" import numpy as np +import pytest from aeon.anomaly_detection import LSTM_AD from aeon.testing.data_generation._legacy import make_series +from aeon.utils.validation._dependencies import _check_soft_dependencies +@pytest.mark.skipif( + not _check_soft_dependencies("tensorflow", severity="none"), + reason="skip test if required soft dependency not available", +) def test_lstmad_univariate(): """Test LSTM_AD univariate output.""" series = make_series(n_timepoints=1000, return_numpy=True, random_state=42) @@ -26,6 +32,10 @@ def test_lstmad_univariate(): assert pred.dtype == np.int_ +@pytest.mark.skipif( + not _check_soft_dependencies("tensorflow", severity="none"), + reason="skip test if required soft dependency not available", +) def test_lstmad_multivariate(): """Test LSTM_AD multivariate output.""" series = make_series( From ca44d3dc9a632482b10807488da6ec9c5469c247 Mon Sep 17 00:00:00 2001 From: Divya Tiwari Date: Sat, 12 Oct 2024 15:40:18 +0530 Subject: [PATCH 08/13] soft-dep tests --- aeon/anomaly_detection/_lstm_ad.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/aeon/anomaly_detection/_lstm_ad.py b/aeon/anomaly_detection/_lstm_ad.py index 51f69779f0..00a4ea6beb 100644 --- a/aeon/anomaly_detection/_lstm_ad.py +++ b/aeon/anomaly_detection/_lstm_ad.py @@ -87,9 +87,11 @@ class LSTM_AD(BaseAnomalyDetector): >>> X, y = load_anomaly_detection( ... name=("KDD-TSAD", "001_UCR_Anomaly_DISTORTED1sddb40") ... ) - >>> detector = LSTM_AD(n_layers=4, n_nodes=64, window_size=10, prediction_horizon=2) - >>> detector.fit(X, axis=0) - >>> anomaly_pred = detector.predict(X, axis=0) + >>> detector = LSTM_AD( + ... n_layers=4, n_nodes=64, window_size=10, prediction_horizon=2 + ... ) # doctest: +SKIP + >>> detector.fit(X, axis=0) # doctest: +SKIP + LSTM_AD(...) """ _tags = { From 2ab090a0a1385a76cbe486669d15b19b114bc165 Mon Sep 17 00:00:00 2001 From: Divya Tiwari Date: Mon, 14 Oct 2024 22:32:21 +0530 Subject: [PATCH 09/13] Deep Learning Submodule --- aeon/anomaly_detection/__init__.py | 2 +- .../deep_learning/__init__.py | 1 + .../{ => deep_learning}/_lstm_ad.py | 212 ++++++++++++------ aeon/anomaly_detection/deep_learning/base.py | 131 +++++++++++ aeon/networks/__init__.py | 2 + aeon/networks/_lstm.py | 62 +++++ 6 files changed, 345 insertions(+), 65 deletions(-) create mode 100644 aeon/anomaly_detection/deep_learning/__init__.py rename aeon/anomaly_detection/{ => deep_learning}/_lstm_ad.py (72%) create mode 100644 aeon/anomaly_detection/deep_learning/base.py create mode 100644 aeon/networks/_lstm.py diff --git a/aeon/anomaly_detection/__init__.py b/aeon/anomaly_detection/__init__.py index 91d12fb367..2ceb8ee942 100644 --- a/aeon/anomaly_detection/__init__.py +++ b/aeon/anomaly_detection/__init__.py @@ -14,8 +14,8 @@ from aeon.anomaly_detection._dwt_mlead import DWT_MLEAD from aeon.anomaly_detection._kmeans import KMeansAD from aeon.anomaly_detection._left_stampi import LeftSTAMPi -from aeon.anomaly_detection._lstm_ad import LSTM_AD from aeon.anomaly_detection._merlin import MERLIN from aeon.anomaly_detection._pyodadapter import PyODAdapter from aeon.anomaly_detection._stomp import STOMP from aeon.anomaly_detection._stray import STRAY +from aeon.anomaly_detection.deep_learning._lstm_ad import LSTM_AD diff --git a/aeon/anomaly_detection/deep_learning/__init__.py b/aeon/anomaly_detection/deep_learning/__init__.py new file mode 100644 index 0000000000..71d11d8423 --- /dev/null +++ b/aeon/anomaly_detection/deep_learning/__init__.py @@ -0,0 +1 @@ +"""Deep learning based anomaly detector.""" diff --git a/aeon/anomaly_detection/_lstm_ad.py b/aeon/anomaly_detection/deep_learning/_lstm_ad.py similarity index 72% rename from aeon/anomaly_detection/_lstm_ad.py rename to aeon/anomaly_detection/deep_learning/_lstm_ad.py index 00a4ea6beb..e49977b131 100644 --- a/aeon/anomaly_detection/_lstm_ad.py +++ b/aeon/anomaly_detection/deep_learning/_lstm_ad.py @@ -2,16 +2,19 @@ __all__ = ["LSTM_AD"] +import time + import numpy as np from scipy.stats import multivariate_normal from sklearn.covariance import EmpiricalCovariance from sklearn.metrics import fbeta_score from sklearn.model_selection import train_test_split -from aeon.anomaly_detection.base import BaseAnomalyDetector +from aeon.anomaly_detection.deep_learning.base import BaseDeepAnomalyDetector +from aeon.networks import LSTMNetwork -class LSTM_AD(BaseAnomalyDetector): +class LSTM_AD(BaseDeepAnomalyDetector): """LSTM-AD anomaly detector. The LSTM-AD uses stacked LSTM network for anomaly detection in time series. A @@ -58,6 +61,8 @@ class LSTM_AD(BaseAnomalyDetector): batch_size : int, default=32 The number of time steps per gradient update. + optimizer : keras.optimizer, default=keras.optimizers.Adadelta() + n_epochs: int, default = 1500 The number of epochs to train the model. @@ -67,6 +72,38 @@ class LSTM_AD(BaseAnomalyDetector): verbose : boolean, default = False whether to output extra information + file_path : str, default = "./" + file_path when saving model_Checkpoint callback + + save_best_model : bool, default = False + Whether or not to save the best model, if the + modelcheckpoint callback is used by default, + this condition, if True, will prevent the + automatic deletion of the best saved model from + file and the user can choose the file name + + save_last_model : bool, default = False + Whether or not to save the last model, last + epoch trained, using the base class method + save_last_model_to_file + + save_init_model : bool, default = False + Whether to save the initialization of the model. + + best_file_name : str, default = "best_model" + The name of the file of the best model, if + save_best_model is set to False, this parameter + is discarded + + last_file_name : str, default = "last_model" + The name of the file of the last model, if + save_last_model is set to False, this parameter + is discarded + + init_file_name : str, default = "init_model" + The name of the file of the init model, if save_init_model is set to False, + this parameter is discarded. + Notes ----- This implementation is inspired by [1]_. @@ -99,6 +136,7 @@ class LSTM_AD(BaseAnomalyDetector): "capability:multivariate": True, "capability:missing_values": False, "fit_is_empty": False, + "requires_y": True, "python_dependencies": "tensorflow", } @@ -112,6 +150,15 @@ def __init__( n_epochs: int = 1500, patience: int = 5, verbose: bool = False, + loss="mse", + optimizer=None, + file_path="./", + save_best_model=False, + save_last_model=False, + save_init_model=False, + best_file_name="best_model", + last_file_name="last_model", + init_file_name="init_model", ): self.n_layers = n_layers self.n_nodes = n_nodes @@ -121,8 +168,76 @@ def __init__( self.n_epochs = n_epochs self.patience = patience self.verbose = verbose + self.loss = loss + self.optimizer = optimizer + self.file_path = file_path + self.save_best_model = save_best_model + self.save_last_model = save_last_model + self.save_init_model = save_init_model + self.best_file_name = best_file_name + self.last_file_name = last_file_name + self.init_file_name = init_file_name + + self.history = None + + super().__init__() + + self._network = LSTMNetwork(self.n_nodes, self.n_layers) + + def build_model(self, **kwargs): + """Construct a compiled, un-trained, keras model that is ready for training. + + In aeon, time series are stored in numpy arrays of shape (d,m), where d + is the number of dimensions, m is the series length. Keras/tensorflow assume + data is in shape (m,d). This method also assumes (m,d). Transpose should + happen in fit. + + Parameters + ---------- + n_layers : int + The number of layers in the LSTM model. + n_nodes : int + The number of LSTM units in each layer. + n_channels : int + It is basically d, the number of dimesions. + window_size : int + Tie number of time steps fed to the model. + prediction_horizon : int + The number of time steps to be predicted by the model. + + Returns + ------- + output : a compiled Keras Model + """ + import tensorflow as tf + + # input_layer, output_layer = self._network.build_network(input_shape, + # prediction_horizon, **kwargs) + # Input layer for the LSTM model + input_layer = tf.keras.layers.Input(shape=(self.window_size, self.n_channels)) + + # Build the LSTM layers + x = input_layer + for _ in range(self.n_layers - 1): + x = tf.keras.layers.LSTM(self.n_nodes, return_sequences=True)(x) + + # Last LSTM layer with return_sequences=False to output final representation + x = tf.keras.layers.LSTM(self.n_nodes, return_sequences=False)(x) + + # Output Dense layer + output_layer = tf.keras.layers.Dense(self.n_channels * self.prediction_horizon)( + x + ) + + self.optimizer_ = ( + tf.keras.optimizers.Adam() if self.optimizer is None else self.optimizer + ) - super().__init__(axis=0) + model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer) + + model.compile(optimizer=self.optimizer_, loss=self.loss) + + return model def _fit(self, X: np.array, y: np.array): """Fit the model on the data. @@ -156,6 +271,7 @@ def _fit(self, X: np.array, y: np.array): X_val1, X_val2, y_val1, y_val2 = train_test_split( X_val, y_val, test_size=0.5, shuffle=False ) + X_train_n, y_train_n = _create_sequences( X_train, self.window_size, self.prediction_horizon ) @@ -174,30 +290,39 @@ def _fit(self, X: np.array, y: np.array): ) y_anomalies = y_anomalies.reshape(-1, self.prediction_horizon * self.n_channels) - # Create a stacked LSTM model and fit on the training data - self.model = self._build_model( - self.n_layers, - self.n_nodes, - self.n_channels, - self.window_size, - self.prediction_horizon, - ) - # self.model_summary_ = self.model.summary() - early_stopping = tf.keras.callbacks.EarlyStopping( - monitor="val_loss", patience=self.patience, restore_best_weights=True + # Fit LSTM model on the normal train set + # input_shape = (self.window_size, self.n_channels) + + self.training_model_ = self.build_model() + + if self.save_init_model: + self.training_model_.save(self.file_path + self.init_file_name + ".keras") + + if self.verbose: + self.training_model_.summary() + + self.file_name_ = ( + self.best_file_name if self.save_best_model else str(time.time_ns()) ) - self.history = self.model.fit( + + self.callbacks_ = [ + tf.keras.callbacks.EarlyStopping( + monitor="val_loss", patience=self.patience, restore_best_weights=True + ) + ] + + self.history = self.training_model_.fit( X_train_n, y_train_n, validation_data=(X_val_1, y_val_1), - epochs=self.n_epochs, batch_size=self.batch_size, - callbacks=[early_stopping], + epochs=self.n_epochs, verbose=self.verbose, + callbacks=self.callbacks_, ) # Prediction errors on validation set 1 to calculate error vector - predicted_vN1 = self.model.predict(X_val_1) + predicted_vN1 = self.training_model_.predict(X_val_1) errors_vN1 = y_val_1 - predicted_vN1 # Fit the error vectors to a Gaussian distribution @@ -211,8 +336,8 @@ def _fit(self, X: np.array, y: np.array): # Create a Gaussian Normal Distribution self.distribution = multivariate_normal(mean=mu, cov=cov_matrix) - predicted_vN2 = self.model.predict(X_val_2) - predicted_vA = self.model.predict(X_anomalies) + predicted_vN2 = self.training_model_.predict(X_val_2) + predicted_vA = self.training_model_.predict(X_anomalies) errors_vN2 = y_val_2 - predicted_vN2 errors_vA = y_anomalies - predicted_vA @@ -244,10 +369,12 @@ def _fit(self, X: np.array, y: np.array): self.best_tau = tau self.best_fbeta = fbeta + return self + def _predict(self, X): X_, y_ = _create_sequences(X, self.window_size, self.prediction_horizon) y_ = y_.reshape(-1, self.prediction_horizon * self.n_channels) - predict_test = self.model.predict(X_) + predict_test = self.training_model_.predict(X_) errors = y_ - predict_test likelihoods = self.distribution.pdf(errors) anomalies = (likelihoods < self.best_tau).astype(int) @@ -255,49 +382,6 @@ def _predict(self, X): prediction = np.concatenate([padding, anomalies]) return np.array(prediction, dtype=int) - def _build_model( - self, n_layers, n_nodes, n_channels, window_size, prediction_horizon - ): - """Construct a compiled, un-trained, keras model that is ready for training. - - In aeon, time series are stored in numpy arrays of shape (d,m), where d - is the number of dimensions, m is the series length. Keras/tensorflow assume - data is in shape (m,d). This method also assumes (m,d). Transpose should - happen in fit. - - Parameters - ---------- - n_layers : int - The number of layers in the LSTM model. - n_nodes : int - The number of LSTM units in each layer. - n_channels : int - It is basically d, the number of dimesions. - window_size : int - Tie number of time steps fed to the model. - prediction_horizon : int - The number of time steps to be predicted by the model. - - Returns - ------- - output : a compiled Keras Model - """ - import tensorflow as tf - - model = tf.keras.models.Sequential() - model.add(tf.keras.layers.Input(shape=(window_size, n_channels))) - model.add( - tf.keras.layers.LSTM(n_nodes, return_sequences=True) - ) # First LSTM layer - if n_layers > 2: - for _ in range(1, n_layers - 1): - model.add(tf.keras.layers.LSTM(n_nodes, return_sequences=True)) - # Last LSTM layer, return_sequences=False - model.add(tf.keras.layers.LSTM(n_nodes, return_sequences=False)) - model.add(tf.keras.layers.Dense(n_channels * prediction_horizon)) - model.compile(optimizer="adam", loss="mse") - return model - def _check_params(self, X: np.ndarray) -> None: if X.ndim == 1: self.n_channels = 1 diff --git a/aeon/anomaly_detection/deep_learning/base.py b/aeon/anomaly_detection/deep_learning/base.py new file mode 100644 index 0000000000..7eead0c1e2 --- /dev/null +++ b/aeon/anomaly_detection/deep_learning/base.py @@ -0,0 +1,131 @@ +""" +Abstract base class for the Keras neural network anomaly detectors. + +The reason for this class between BaseAnomalyDetector and deep_learning +anomaly detectors is because we can generalise tags, _predict and _predict_proba +""" + +__all__ = ["BaseDeepAnomalyDetector"] + +from abc import ABC, abstractmethod + +from aeon.anomaly_detection.base import BaseAnomalyDetector + + +class BaseDeepAnomalyDetector(BaseAnomalyDetector, ABC): + """Abstract base class for deep learning time series anomaly detection. + + The base anomaly detector provides a deep learning default method for + _predict and _predict_proba, and provides a new abstract method for building a + model. + + Parameters + ---------- + batch_size : int, default = 40 + training batch size for the model + last_file_name : str, default = "last_model" + The name of the file of the last model, used + only if save_last_model_to_file is used + + Arguments + --------- + self.model = None + + """ + + _tags = { + "capability:multivariate": True, + "algorithm_type": "deeplearning", + "non_deterministic": True, + "cant_pickle": True, + "python_dependencies": "tensorflow", + } + + def __init__( + self, + last_file_name="last_model", + ): + self.last_file_name = last_file_name + self.model_ = None + + super().__init__(axis=0) + + @abstractmethod + def build_model(self, input_shape): + """Construct a compiled, un-trained, keras model that is ready for training. + + Parameters + ---------- + input_shape : tuple + The shape of the data fed into the input layer + + Returns + ------- + A compiled Keras Model + """ + ... + + def summary(self): + """ + Summary function to return the losses/metrics for model fit. + + Returns + ------- + history : dict or None, + Dictionary containing model's train/validation losses and metrics + + """ + return self.history.history if self.history is not None else None + + def save_last_model_to_file(self, file_path="./"): + """Save the last epoch of the trained deep learning model. + + Parameters + ---------- + file_path : str, default = "./" + The directory where the model will be saved + + Returns + ------- + None + """ + self.model_.save(file_path + self.last_file_name + ".keras") + + def load_model(self, model_path, classes): + """Load a pre-trained keras model instead of fitting. + + When calling this function, all functionalities can be used + such as predict, predict_proba etc. with the loaded model. + + Parameters + ---------- + model_path : str (path including model name and extension) + The directory where the model will be loaded from including the model + name with a ".keras" extension. + Example: model_path="path/to/file/best_model.keras" + + Returns + ------- + None + """ + import tensorflow as tf + + self.model_ = tf.keras.models.load_model(model_path) + self.is_fitted = True + + self.classes_ = classes + self.n_classes_ = len(self.classes_) + + def _get_model_checkpoint_callback(self, callbacks, file_path, file_name): + import tensorflow as tf + + model_checkpoint_ = tf.keras.callbacks.ModelCheckpoint( + filepath=file_path + file_name + ".keras", + monitor="loss", + save_best_only=True, + ) + + if isinstance(callbacks, list): + return callbacks + [model_checkpoint_] + else: + return [callbacks] + [model_checkpoint_] diff --git a/aeon/networks/__init__.py b/aeon/networks/__init__.py index dd7b90f8e8..217fc60334 100644 --- a/aeon/networks/__init__.py +++ b/aeon/networks/__init__.py @@ -17,6 +17,7 @@ "AEAttentionBiGRUNetwork", "AEDRNNNetwork", "AEBiGRUNetwork", + "LSTMNetwork", ] from aeon.networks._ae_abgru import AEAttentionBiGRUNetwork @@ -29,6 +30,7 @@ from aeon.networks._fcn import FCNNetwork from aeon.networks._inception import InceptionNetwork from aeon.networks._lite import LITENetwork +from aeon.networks._lstm import LSTMNetwork from aeon.networks._mlp import MLPNetwork from aeon.networks._resnet import ResNetNetwork from aeon.networks._tapnet import TapNetNetwork diff --git a/aeon/networks/_lstm.py b/aeon/networks/_lstm.py new file mode 100644 index 0000000000..bb11b7faa6 --- /dev/null +++ b/aeon/networks/_lstm.py @@ -0,0 +1,62 @@ +"""Long Short Term Memory Network (LSTMNetwork).""" + +from aeon.networks.base import BaseDeepLearningNetwork + + +class LSTMNetwork(BaseDeepLearningNetwork): + """Establish the network structure for an LSTM. + + Inspired by _[1]. + + References + ---------- + .. [1] Malhotra Pankaj, Lovekesh Vig, Gautam Shroff, and Puneet Agarwal. + Long Short Term Memory Networks for Anomaly Detection in Time Series. In Proceedings + of the European Symposium on Artificial Neural Networks, Computational Intelligence + and Machine Learning (ESANN), Vol. 23, 2015. + https://www.esann.org/sites/default/files/proceedings/legacy/es2015-56.pdf + """ + + def __init__( + self, + n_nodes, + n_layers, + ): + self.n_nodes = n_nodes + self.n_layers = n_layers + super().__init__() + + def build_network(self, input_shape, prediction_horizon, **kwargs): + """Construct an LSTM network and return its input and output layers. + + Parameters + ---------- + input_shape : tuple of shape = (window_size (w), n_channels (d)) + The shape of the data fed into the input layer + n_nodes : int, optional (default=64) + The number of LSTM units in each layer + n_layers : int, optional (default=2) + The number of LSTM layers + + Returns + ------- + input_layer : a keras layer + output_layer : a keras layer + """ + import tensorflow as tf + + # Input layer for the LSTM model + input_layer = tf.keras.layers.Input(shape=input_shape) + + # Build the LSTM layers + x = input_layer + for _ in range(self.n_layers - 1): + x = tf.keras.layers.LSTM(self.n_nodes, return_sequences=True)(x) + + # Last LSTM layer with return_sequences=False to output final representation + x = tf.keras.layers.LSTM(self.n_nodes, return_sequences=False)(x) + + # Output Dense layer + output_layer = tf.keras.layers.Dense(prediction_horizon * input_shape[1])(x) + + return input_layer, output_layer From 64aa97e019d2362a8d43bb67d539bdf89329432a Mon Sep 17 00:00:00 2001 From: Divya Tiwari Date: Tue, 15 Oct 2024 16:04:35 +0530 Subject: [PATCH 10/13] base deep learning class for AD --- .../deep_learning/_lstm_ad.py | 22 +++++-------------- aeon/networks/_lstm.py | 12 ++++++---- aeon/networks/tests/test_all_networks.py | 5 ++++- 3 files changed, 17 insertions(+), 22 deletions(-) diff --git a/aeon/anomaly_detection/deep_learning/_lstm_ad.py b/aeon/anomaly_detection/deep_learning/_lstm_ad.py index e49977b131..577cd9bd74 100644 --- a/aeon/anomaly_detection/deep_learning/_lstm_ad.py +++ b/aeon/anomaly_detection/deep_learning/_lstm_ad.py @@ -182,7 +182,9 @@ def __init__( super().__init__() - self._network = LSTMNetwork(self.n_nodes, self.n_layers) + self._network = LSTMNetwork( + self.n_nodes, self.n_layers, self.prediction_horizon + ) def build_model(self, **kwargs): """Construct a compiled, un-trained, keras model that is ready for training. @@ -211,22 +213,8 @@ def build_model(self, **kwargs): """ import tensorflow as tf - # input_layer, output_layer = self._network.build_network(input_shape, - # prediction_horizon, **kwargs) - # Input layer for the LSTM model - input_layer = tf.keras.layers.Input(shape=(self.window_size, self.n_channels)) - - # Build the LSTM layers - x = input_layer - for _ in range(self.n_layers - 1): - x = tf.keras.layers.LSTM(self.n_nodes, return_sequences=True)(x) - - # Last LSTM layer with return_sequences=False to output final representation - x = tf.keras.layers.LSTM(self.n_nodes, return_sequences=False)(x) - - # Output Dense layer - output_layer = tf.keras.layers.Dense(self.n_channels * self.prediction_horizon)( - x + input_layer, output_layer = self._network.build_network( + (self.window_size, self.n_channels), **kwargs ) self.optimizer_ = ( diff --git a/aeon/networks/_lstm.py b/aeon/networks/_lstm.py index bb11b7faa6..dbabd3c482 100644 --- a/aeon/networks/_lstm.py +++ b/aeon/networks/_lstm.py @@ -19,14 +19,16 @@ class LSTMNetwork(BaseDeepLearningNetwork): def __init__( self, - n_nodes, - n_layers, + n_nodes=64, + n_layers=2, + prediction_horizon=1, ): self.n_nodes = n_nodes self.n_layers = n_layers + self.prediction_horizon = prediction_horizon super().__init__() - def build_network(self, input_shape, prediction_horizon, **kwargs): + def build_network(self, input_shape, **kwargs): """Construct an LSTM network and return its input and output layers. Parameters @@ -57,6 +59,8 @@ def build_network(self, input_shape, prediction_horizon, **kwargs): x = tf.keras.layers.LSTM(self.n_nodes, return_sequences=False)(x) # Output Dense layer - output_layer = tf.keras.layers.Dense(prediction_horizon * input_shape[1])(x) + output_layer = tf.keras.layers.Dense(input_shape[1] * self.prediction_horizon)( + x + ) return input_layer, output_layer diff --git a/aeon/networks/tests/test_all_networks.py b/aeon/networks/tests/test_all_networks.py index 106a5b8b4f..c06c5d6759 100644 --- a/aeon/networks/tests/test_all_networks.py +++ b/aeon/networks/tests/test_all_networks.py @@ -39,7 +39,10 @@ def test_all_networks_functionality(network): if _check_soft_dependencies( network._config["python_dependencies"], severity="none" ) and _check_python_version(network._config["python_version"], severity="none"): - my_network = network() + if network.__name__ == "LSTMNetwork": + my_network = network(n_nodes=50, n_layers=2, prediction_horizon=1) + else: + my_network = network() if network._config["structure"] == "auto-encoder": encoder, decoder = my_network.build_network(input_shape=input_shape) From ce0c2133d5e051236bde4640734e24e6c04f2e19 Mon Sep 17 00:00:00 2001 From: Divya Tiwari Date: Tue, 15 Oct 2024 16:20:17 +0530 Subject: [PATCH 11/13] update init --- aeon/anomaly_detection/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/aeon/anomaly_detection/__init__.py b/aeon/anomaly_detection/__init__.py index 294d18b16c..db02a9aaa6 100644 --- a/aeon/anomaly_detection/__init__.py +++ b/aeon/anomaly_detection/__init__.py @@ -9,6 +9,7 @@ "STOMP", "LeftSTAMPi", "IsolationForest", + "LSTM_AD", ] from aeon.anomaly_detection._dwt_mlead import DWT_MLEAD @@ -19,3 +20,4 @@ from aeon.anomaly_detection._pyodadapter import PyODAdapter from aeon.anomaly_detection._stomp import STOMP from aeon.anomaly_detection._stray import STRAY +from aeon.anomaly_detection.deep_learning._lstm_ad import LSTM_AD From 04e03964248651be4694bdfc1896b250f9a06b15 Mon Sep 17 00:00:00 2001 From: Divya Tiwari Date: Tue, 15 Oct 2024 21:16:29 +0530 Subject: [PATCH 12/13] Save and load model --- aeon/anomaly_detection/__init__.py | 2 -- .../deep_learning/__init__.py | 4 +++ .../deep_learning/_lstm_ad.py | 30 +++++++++++-------- aeon/anomaly_detection/deep_learning/base.py | 5 ++-- aeon/anomaly_detection/tests/test_lstm_ad.py | 2 +- 5 files changed, 24 insertions(+), 19 deletions(-) diff --git a/aeon/anomaly_detection/__init__.py b/aeon/anomaly_detection/__init__.py index db02a9aaa6..294d18b16c 100644 --- a/aeon/anomaly_detection/__init__.py +++ b/aeon/anomaly_detection/__init__.py @@ -9,7 +9,6 @@ "STOMP", "LeftSTAMPi", "IsolationForest", - "LSTM_AD", ] from aeon.anomaly_detection._dwt_mlead import DWT_MLEAD @@ -20,4 +19,3 @@ from aeon.anomaly_detection._pyodadapter import PyODAdapter from aeon.anomaly_detection._stomp import STOMP from aeon.anomaly_detection._stray import STRAY -from aeon.anomaly_detection.deep_learning._lstm_ad import LSTM_AD diff --git a/aeon/anomaly_detection/deep_learning/__init__.py b/aeon/anomaly_detection/deep_learning/__init__.py index 71d11d8423..47325d4253 100644 --- a/aeon/anomaly_detection/deep_learning/__init__.py +++ b/aeon/anomaly_detection/deep_learning/__init__.py @@ -1 +1,5 @@ """Deep learning based anomaly detector.""" + +__all__ = ["LSTM_AD"] + +from aeon.anomaly_detection.deep_learning._lstm_ad import LSTM_AD diff --git a/aeon/anomaly_detection/deep_learning/_lstm_ad.py b/aeon/anomaly_detection/deep_learning/_lstm_ad.py index 577cd9bd74..6eb1d06821 100644 --- a/aeon/anomaly_detection/deep_learning/_lstm_ad.py +++ b/aeon/anomaly_detection/deep_learning/_lstm_ad.py @@ -2,7 +2,10 @@ __all__ = ["LSTM_AD"] +import gc +import os import time +from copy import deepcopy import numpy as np from scipy.stats import multivariate_normal @@ -194,19 +197,6 @@ def build_model(self, **kwargs): data is in shape (m,d). This method also assumes (m,d). Transpose should happen in fit. - Parameters - ---------- - n_layers : int - The number of layers in the LSTM model. - n_nodes : int - The number of LSTM units in each layer. - n_channels : int - It is basically d, the number of dimesions. - window_size : int - Tie number of time steps fed to the model. - prediction_horizon : int - The number of time steps to be predicted by the model. - Returns ------- output : a compiled Keras Model @@ -357,6 +347,20 @@ def _fit(self, X: np.array, y: np.array): self.best_tau = tau self.best_fbeta = fbeta + try: + if self.save_best_model: + self.model_ = tf.keras.models.load_model( + self.file_path + self.file_name_ + ".keras", compile=False + ) + else: + os.remove(self.file_path + self.file_name_ + ".keras") + except FileNotFoundError: + self.model_ = deepcopy(self.training_model_) + + if self.save_last_model: + self.save_last_model_to_file(file_path=self.file_path) + + gc.collect() return self def _predict(self, X): diff --git a/aeon/anomaly_detection/deep_learning/base.py b/aeon/anomaly_detection/deep_learning/base.py index 7eead0c1e2..9053c1a568 100644 --- a/aeon/anomaly_detection/deep_learning/base.py +++ b/aeon/anomaly_detection/deep_learning/base.py @@ -43,8 +43,10 @@ class BaseDeepAnomalyDetector(BaseAnomalyDetector, ABC): def __init__( self, + batch_size=40, last_file_name="last_model", ): + self.batch_size = batch_size self.last_file_name = last_file_name self.model_ = None @@ -113,9 +115,6 @@ def load_model(self, model_path, classes): self.model_ = tf.keras.models.load_model(model_path) self.is_fitted = True - self.classes_ = classes - self.n_classes_ = len(self.classes_) - def _get_model_checkpoint_callback(self, callbacks, file_path, file_name): import tensorflow as tf diff --git a/aeon/anomaly_detection/tests/test_lstm_ad.py b/aeon/anomaly_detection/tests/test_lstm_ad.py index 8dbe3e9b00..1256c5c25a 100644 --- a/aeon/anomaly_detection/tests/test_lstm_ad.py +++ b/aeon/anomaly_detection/tests/test_lstm_ad.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from aeon.anomaly_detection import LSTM_AD +from aeon.anomaly_detection.deep_learning import LSTM_AD from aeon.testing.data_generation._legacy import make_series from aeon.utils.validation._dependencies import _check_soft_dependencies From 3a293ef52901ec708b7eeb5753aaecf3488f4846 Mon Sep 17 00:00:00 2001 From: Divya Tiwari Date: Tue, 15 Oct 2024 21:23:56 +0530 Subject: [PATCH 13/13] docs --- aeon/anomaly_detection/deep_learning/_lstm_ad.py | 2 +- docs/api_reference/anomaly_detection.rst | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/aeon/anomaly_detection/deep_learning/_lstm_ad.py b/aeon/anomaly_detection/deep_learning/_lstm_ad.py index 6eb1d06821..566f154bbd 100644 --- a/aeon/anomaly_detection/deep_learning/_lstm_ad.py +++ b/aeon/anomaly_detection/deep_learning/_lstm_ad.py @@ -123,7 +123,7 @@ class LSTM_AD(BaseDeepAnomalyDetector): -------- >>> import numpy as np >>> from aeon.datasets import load_anomaly_detection - >>> from aeon.anomaly_detection import LSTM_AD + >>> from aeon.anomaly_detection.deep_learning import LSTM_AD >>> X, y = load_anomaly_detection( ... name=("KDD-TSAD", "001_UCR_Anomaly_DISTORTED1sddb40") ... ) diff --git a/docs/api_reference/anomaly_detection.rst b/docs/api_reference/anomaly_detection.rst index 665f6f8ff8..e1bc7598a5 100644 --- a/docs/api_reference/anomaly_detection.rst +++ b/docs/api_reference/anomaly_detection.rst @@ -77,3 +77,4 @@ Detectors PyODAdapter STRAY STOMP + LSTM_AD