diff --git a/.github/workflows/pr_pytest.yml b/.github/workflows/pr_pytest.yml index fcdc665778..82dbdc0f14 100644 --- a/.github/workflows/pr_pytest.yml +++ b/.github/workflows/pr_pytest.yml @@ -3,7 +3,7 @@ name: PR pytest on: push: branches: - - main + - tcn_fst pull_request: paths: - "aeon/**" diff --git a/aeon/forecasting/deep_learning/__init__.py b/aeon/forecasting/deep_learning/__init__.py new file mode 100644 index 0000000000..8e3bac6a86 --- /dev/null +++ b/aeon/forecasting/deep_learning/__init__.py @@ -0,0 +1,9 @@ +"""Initialization for aeon forecasting deep learning module.""" + +__all__ = [ + "BaseDeepForecaster", + "TCNForecaster", +] + +from aeon.forecasting.deep_learning._tcn import TCNForecaster +from aeon.forecasting.deep_learning.base import BaseDeepForecaster diff --git a/aeon/forecasting/deep_learning/_tcn.py b/aeon/forecasting/deep_learning/_tcn.py new file mode 100644 index 0000000000..d3d06bf0fe --- /dev/null +++ b/aeon/forecasting/deep_learning/_tcn.py @@ -0,0 +1,150 @@ +"""TCNForecaster module for deep learning forecasting in aeon.""" + +from __future__ import annotations + +__maintainer__ = [] + +__all__ = ["TCNForecaster"] + +from typing import Any + +from aeon.forecasting.deep_learning.base import BaseDeepForecaster +from aeon.networks._tcn import TCNNetwork + + +class TCNForecaster(BaseDeepForecaster): + """A deep learning forecaster using Temporal Convolutional Network (TCN). + + It leverages the `TCNNetwork` from aeon's network module + to build the architecture suitable for forecasting tasks. + + Parameters + ---------- + horizon : int, default=1 + Forecasting horizon, the number of steps ahead to predict. + window : int, default=10 + The window size for creating input sequences. + batch_size : int, default=32 + Batch size for training the model. + epochs : int, default=100 + Number of epochs to train the model. + verbose : int, default=0 + Verbosity mode (0, 1, or 2). + optimizer : str or tf.keras.optimizers.Optimizer, default='adam' + Optimizer to use for training. + loss : str or tf.keras.losses.Loss, default='mse' + Loss function for training. + random_state : int, default=None + Seed for random number generators. + axis : int, default=0 + Axis along which to apply the forecaster. + n_blocks : list of int, default=[16, 16, 16] + List specifying the number of output channels for each layer of the + TCN. The length determines the depth of the network. + kernel_size : int, default=2 + Size of the convolutional kernel in the TCN. + dropout : float, default=0.2 + Dropout rate applied after each convolutional layer for + regularization. + """ + + _tags = { + "python_dependencies": ["tensorflow"], + "capability:horizon": True, + "capability:multivariate": True, + "capability:exogenous": False, + "capability:univariate": True, + } + + def __init__( + self, + horizon=1, + window=10, + batch_size=32, + epochs=100, + verbose=0, + optimizer="adam", + loss="mse", + random_state=None, + axis=0, + n_blocks=None, + kernel_size=2, + dropout=0.2, + ): + super().__init__( + horizon=horizon, + window=window, + batch_size=batch_size, + epochs=epochs, + verbose=verbose, + optimizer=optimizer, + random_state=random_state, + axis=axis, + loss=loss, + ) + self.n_blocks = n_blocks + self.kernel_size = kernel_size + self.dropout = dropout + + def _build_model(self, input_shape): + """Build the TCN model for forecasting. + + Parameters + ---------- + input_shape : tuple + Shape of input data, typically (window, num_inputs). + + Returns + ------- + model : tf.keras.Model + Compiled Keras model with TCN architecture. + """ + import tensorflow as tf + + # Initialize the TCN network with the updated parameters + network = TCNNetwork( + n_blocks=self.n_blocks if self.n_blocks is not None else [16, 16, 16], + kernel_size=self.kernel_size, + dropout=self.dropout, + ) + + # Build the network with the given input shape + input_layer, output = network.build_network(input_shape=input_shape) + + # Create the final model + model = tf.keras.Model(inputs=input_layer, outputs=output) + return model + + # Added to handle __name__ in tests (class-level access) + @classmethod + def _get_test_params( + cls, parameter_set: str = "default" + ) -> dict[str, Any] | list[dict[str, Any]]: + """ + Return testing parameter settings for the estimator. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + For forecasters, a "default" set of parameters should be provided for + general testing, and a "results_comparison" set for comparing against + previously recorded results if the general set does not produce suitable + probabilities to compare against. + + Returns + ------- + params : dict or list of dict, default={} + Parameters to create testing instances of the class. + Each dict are parameters to construct an "interesting" test instance, i.e., + `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance. + """ + param = { + "epochs": 10, + "batch_size": 4, + "n_blocks": [8, 8], + "kernel_size": 2, + "dropout": 0.1, + } + return [param] diff --git a/aeon/forecasting/deep_learning/base.py b/aeon/forecasting/deep_learning/base.py new file mode 100644 index 0000000000..40c7b3e212 --- /dev/null +++ b/aeon/forecasting/deep_learning/base.py @@ -0,0 +1,234 @@ +"""Base class module for deep learning forecasters in aeon. + +This module defines the `BaseDeepForecaster` class, an abstract base class for +deep learning-based forecasting models within the aeon toolkit. +""" + +from __future__ import annotations + +__maintainer__ = [] +__all__ = ["BaseDeepForecaster"] + +from abc import abstractmethod + +import numpy as np +import pandas as pd + +from aeon.forecasting.base import BaseForecaster + + +class BaseDeepForecaster(BaseForecaster): + """Base class for deep learning forecasters in aeon. + + This class provides a foundation for deep learning-based forecasting models, + handling data preprocessing, model training, and prediction. + + Parameters + ---------- + horizon : int, default=1 + Forecasting horizon, the number of steps ahead to predict. + window : int, default=10 + The window size for creating input sequences. + batch_size : int, default=32 + Batch size for training the model. + epochs : int, default=100 + Number of epochs to train the model. + verbose : int, default=0 + Verbosity mode (0, 1, or 2). + optimizer : str or tf.keras.optimizers.Optimizer, default='adam' + Optimizer to use for training. + loss : str or tf.keras.losses.Loss, default='mse' + Loss function for training. + random_state : int, default=None + Seed for random number generators. + axis : int, default=0 + Axis along which to apply the forecaster. + Default is 0 for univariate time series. + """ + + def __init__( + self, + horizon=1, + window=10, + batch_size=32, + epochs=100, + verbose=0, + optimizer="adam", + loss="mse", + random_state=None, + axis=0, + ): + self.horizon = horizon + self.window = window + self.batch_size = batch_size + self.epochs = epochs + self.verbose = verbose + self.optimizer = optimizer + self.loss = loss + self.random_state = random_state + self.axis = axis + self.model_ = None + self.last_window_ = None + + # Pass horizon and axis to BaseForecaster + super().__init__(horizon=horizon, axis=axis) + + def _fit(self, y, X=None): + """Fit the forecaster to training data. + + Parameters + ---------- + y : np.ndarray or pd.Series + Target time series to which to fit the forecaster. + X : np.ndarray or pd.DataFrame, default=None + Exogenous variables. + + Returns + ------- + self : BaseDeepForecaster + Returns an instance of self. + """ + import tensorflow as tf + + # Set random seed for reproducibility + if self.random_state is not None: + np.random.seed(self.random_state) + tf.random.set_seed(self.random_state) + + # Convert input data to numpy array + y_inner = self._convert_input(y) + if y_inner.shape[0] < self.window + self.horizon: + raise ValueError( + f"Data length ({y_inner.shape[0]}) is insufficient" + f"({self.window}) and horizon ({self.horizon})." + ) + + # Create sequences for training + X_train, y_train = self._create_sequences(y_inner) + + if X_train.shape[0] == 0: + raise ValueError("No training sequences could be created.") + + # Build and compile the model + input_shape = X_train.shape[1:] + self.model_ = self._build_model(input_shape) + self.model_.compile(optimizer=self.optimizer, loss=self.loss) + + # Train the model + self.model_.fit( + X_train, + y_train, + batch_size=self.batch_size, + epochs=self.epochs, + verbose=self.verbose, + ) + self.last_window_ = y_inner[-self.window :] + return self + + def _predict(self, y=None, X=None): + """Make forecasts for y. + + Parameters + ---------- + y : np.ndarray or pd.Series, default=None + Series to predict from. If None, uses last fitted window. + X : np.ndarray or pd.DataFrame, default=None + Exogenous variables (not supported by default). + + Returns + ------- + predictions : np.ndarray + Predicted values for the specified horizon. + """ + if y is None: + if not hasattr(self, "last_window_"): + raise ValueError("No fitted data available for prediction.") + y_inner = self.last_window_ + else: + y_inner = self._convert_input(y) + if len(y_inner) < self.window: + raise ValueError( + f"Input data length ({len(y_inner)}) is less than the window size " + f"({self.window})." + ) + y_inner = y_inner[-self.window :] + + last_window = y_inner.reshape(1, self.window, 1) + predictions = [] + current_window = last_window + for _ in range(self.horizon): + pred = self.model_.predict(current_window, verbose=0) + predictions.append(pred[0, 0]) + current_window = np.roll(current_window, -1, axis=1) + current_window[0, -1, 0] = pred[0, 0] + return np.array(predictions) + + def _convert_input(self, y): + """Convert input data to numpy array. + + Parameters + ---------- + y : np.ndarray or pd.Series + Input time series. + + Returns + ------- + y_inner : np.ndarray + Converted numpy array. + """ + if isinstance(y, pd.Series) or isinstance(y, pd.DataFrame): + y_inner = y.values + else: + y_inner = y + + # Ensure 1D array + if len(y_inner.shape) > 1: + y_inner = y_inner.flatten() + + return y_inner + + def _create_sequences(self, data): + """Create input sequences and target values for training. + + Parameters + ---------- + data : np.ndarray + Time series data. + + Returns + ------- + X : np.ndarray + Input sequences. + y : np.ndarray + Target values. + """ + if len(data) < self.window + self.horizon: + raise ValueError( + f"Data length ({len(data)}) is insufficient for window " + f"({self.window}) and horizon ({self.horizon})." + ) + + X, y = [], [] + for i in range(len(data) - self.window - self.horizon + 1): + X.append(data[i : (i + self.window)]) + y.append(data[i + self.window : (i + self.window + self.horizon)]) + + X = np.array(X).reshape(-1, self.window, 1) + y = np.array(y).reshape(-1, self.horizon) + return X, y + + @abstractmethod + def _build_model(self, input_shape): + """Build the deep learning model. + + Parameters + ---------- + input_shape : tuple + Shape of input data. + + Returns + ------- + model : tf.keras.Model + Compiled Keras model. + """ + pass diff --git a/aeon/forecasting/deep_learning/tests/__init__.py b/aeon/forecasting/deep_learning/tests/__init__.py new file mode 100644 index 0000000000..3dda9d25ea --- /dev/null +++ b/aeon/forecasting/deep_learning/tests/__init__.py @@ -0,0 +1 @@ +"""Deep Learning Forecasting Tests File.""" diff --git a/aeon/forecasting/deep_learning/tests/test_base.py b/aeon/forecasting/deep_learning/tests/test_base.py new file mode 100644 index 0000000000..21e90e4a68 --- /dev/null +++ b/aeon/forecasting/deep_learning/tests/test_base.py @@ -0,0 +1,69 @@ +"""Test for BaseDeepForecaster class in aeon.""" + +import numpy as np +import pytest + +from aeon.forecasting.deep_learning import BaseDeepForecaster +from aeon.utils.validation._dependencies import _check_soft_dependencies + + +@pytest.mark.skipif( + not _check_soft_dependencies("tensorflow", severity="none"), + reason="skip test if required soft dependency not available", +) +class SimpleDeepForecaster(BaseDeepForecaster): + """A simple concrete implementation of BaseDeepForecaster for testing.""" + + def __init__(self, horizon=1, window=5, epochs=1, verbose=0): + super().__init__(horizon=horizon, window=window, epochs=epochs, verbose=verbose) + + def _build_model(self, input_shape): + import tensorflow as tf + + model = tf.keras.Sequential( + [ + tf.keras.layers.Flatten(input_shape=input_shape), + tf.keras.layers.Dense(10, activation="relu"), + tf.keras.layers.Dense(self.horizon), + ] + ) + return model + + +@pytest.mark.skipif( + not _check_soft_dependencies("tensorflow", severity="none"), + reason="skip test if required soft dependency not available", +) +def test_base_deep_forecaster_fit_predict(): + """Test fitting and predicting with BaseDeepForecaster implementation.""" + # Generate synthetic data + np.random.seed(42) + data = np.random.randn(50) + + # Initialize forecaster + forecaster = SimpleDeepForecaster(horizon=2, window=5, epochs=1, verbose=0) + + # Fit the model + forecaster.fit(data) + + # Predict + predictions = forecaster.predict() + + # Validate output shape + assert ( + len(predictions) == 2 + ), f"Expected predictions of length 2, got {len(predictions)}" + assert isinstance(predictions, np.ndarray), "Predictions should be a numpy array" + + +@pytest.mark.skipif( + not _check_soft_dependencies("tensorflow", severity="none"), + reason="skip test if required soft dependency not available", +) +def test_base_deep_forecaster_insufficient_data(): + """Test error handling for insufficient data.""" + data = np.random.randn(5) + forecaster = SimpleDeepForecaster(horizon=2, window=5, epochs=1, verbose=0) + + with pytest.raises(ValueError, match="Data length.*insufficient"): + forecaster.fit(data) diff --git a/aeon/forecasting/deep_learning/tests/test_tcn.py b/aeon/forecasting/deep_learning/tests/test_tcn.py new file mode 100644 index 0000000000..2717eaf4b4 --- /dev/null +++ b/aeon/forecasting/deep_learning/tests/test_tcn.py @@ -0,0 +1,37 @@ +"""Test TCN.""" + +__maintainer__ = [] +__all__ = [] + +import pytest + +from aeon.datasets import load_airline +from aeon.forecasting.deep_learning._tcn import TCNForecaster +from aeon.utils.validation._dependencies import _check_soft_dependencies + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("horizon,window,epochs", [(1, 10, 2), (3, 12, 3), (5, 15, 2)]) +def test_tcn_forecaster(horizon, window, epochs): + """Test TCNForecaster with different parameter combinations.""" + import tensorflow as tf + + # Load airline dataset + y = load_airline() + + # Initialize TCNForecaster + forecaster = TCNForecaster( + horizon=horizon, window=window, epochs=epochs, batch_size=16, verbose=0 + ) + + # Fit and predict + forecaster.fit(y) + prediction = forecaster.predict(y) + + # Basic assertions + assert prediction is not None + if isinstance(prediction, tf.Tensor): + assert not tf.math.is_nan(prediction).numpy() diff --git a/aeon/networks/__init__.py b/aeon/networks/__init__.py index d774abe102..aed37be7e7 100644 --- a/aeon/networks/__init__.py +++ b/aeon/networks/__init__.py @@ -19,6 +19,7 @@ "AEBiGRUNetwork", "DisjointCNNNetwork", "RecurrentNetwork", + "TCNNetwork", ] from aeon.networks._ae_abgru import AEAttentionBiGRUNetwork from aeon.networks._ae_bgru import AEBiGRUNetwork @@ -36,4 +37,5 @@ from aeon.networks._mlp import MLPNetwork from aeon.networks._resnet import ResNetNetwork from aeon.networks._rnn import RecurrentNetwork +from aeon.networks._tcn import TCNNetwork from aeon.networks.base import BaseDeepLearningNetwork diff --git a/aeon/networks/_tcn.py b/aeon/networks/_tcn.py new file mode 100644 index 0000000000..834b5865e7 --- /dev/null +++ b/aeon/networks/_tcn.py @@ -0,0 +1,341 @@ +"""Implementation of Temporal Convolutional Network (TCN).""" + +__maintainer__ = [] + +from aeon.networks.base import BaseDeepLearningNetwork + + +class TCNNetwork(BaseDeepLearningNetwork): + """Temporal Convolutional Network (TCN) for sequence modeling. + + A generic convolutional architecture for sequence modeling that combines: + - Dilated convolutions for exponentially large receptive fields + - Residual connections for training stability + + The TCN can take sequences of any length and map them to output sequences + of the same length, making it suitable for autoregressive prediction tasks. + + Parameters + ---------- + n_blocks : list of int + List specifying the number of output channels for each layer. + The length determines the depth of the network. + kernel_size : int, default=2 + Size of the convolutional kernel. Larger kernels can capture + more local context but require more parameters. + dropout : float, default=0.2 + Dropout rate applied after each convolutional layer for regularization. + + Notes + ----- + The receptive field size grows exponentially with network depth due to + dilated convolutions with dilation factors of 2^i for layer i. + + References + ---------- + .. [1] Bai, S., Kolter, J. Z., & Koltun, V. (2018). An empirical evaluation of + generic convolutional and recurrent networks for sequence modeling. + arXiv preprint arXiv:1803.01271. + + Examples + -------- + >>> from aeon.networks._tcn import TCNNetwork + >>> from aeon.testing.data_generation import make_example_3d_numpy + >>> import tensorflow as tf + >>> X, y = make_example_3d_numpy(n_cases=8, n_channels=4, n_timepoints=150, + ... return_y=True, regression_target=True, + ... random_state=42) + >>> network = TCNNetwork(n_blocks=[8, 8]) + >>> input_layer, output = network.build_network(input_shape=(4, 150)) + >>> model = tf.keras.Model(inputs=input_layer, outputs=output) + >>> model.compile(optimizer="adam", loss="mse") + >>> model.fit(X, y, epochs=2, batch_size=2, verbose=0) # doctest: +SKIP + + """ + + _config = { + "python_dependencies": ["tensorflow"], + "python_version": "<3.13", + "structure": "encoder", + } + + def __init__( + self, + n_blocks: list = [16] * 3, + kernel_size: int = 2, + dropout: float = 0.2, + ): + """Initialize the TCN architecture. + + Parameters + ---------- + num_inputs : int + Number of input channels/features. + n_blocks : list of int + Number of output channels for each temporal block. + kernel_size : int, default=2 + Size of convolutional kernels. + dropout : float, default=0.2 + Dropout rate for regularization. + """ + super().__init__() + self.n_blocks = n_blocks + self.kernel_size = kernel_size + self.dropout = dropout + + def _conv1d_with_variable_padding( + self, + input_tensor, + n_filters: int, + kernel_size: int, + padding_value: int, + strides: int = 1, + dilation_rate: int = 1, + ): + """Apply 1D convolution with variable padding for causal convolutions. + + Parameters + ---------- + input_tensor : tf.Tensor + Input tensor of shape (batch_size, channels, sequence_length). + n_filters : int + Number of output filters. + kernel_size : int + Size of the convolutional kernel. + padding_value : int + Amount of padding to apply. + strides : int, default=1 + Stride of the convolution. + dilation_rate : int, default=1 + Dilation rate for dilated convolutions. + + Returns + ------- + tf.Tensor + Output tensor after convolution. + """ + import tensorflow as tf + + # Transpose to Keras format (batch, sequence, channels) + x_keras_format = tf.keras.layers.Permute((2, 1))(input_tensor) + + # Apply padding in sequence dimension + padded_x = tf.keras.layers.ZeroPadding1D(padding=padding_value)(x_keras_format) + + # Create and apply convolution layer + conv_layer = tf.keras.layers.Conv1D( + filters=n_filters, + kernel_size=kernel_size, + strides=strides, + dilation_rate=dilation_rate, + padding="valid", + ) + + # Apply convolution + out = conv_layer(padded_x) + + # Transpose back to PyTorch format (batch, channels, sequence) + return tf.keras.layers.Permute((2, 1))(out) + + def _chomp(self, input_tensor, chomp_size: int): + """Remove padding from the end of sequences to maintain causality. + + This operation ensures that the output at time t only depends on + inputs from times 0 to t, preventing information leakage from future. + + Parameters + ---------- + input_tensor : tf.Tensor + Input tensor of shape (batch_size, channels, sequence_length). + chomp_size : int + Number of time steps to remove from the end. + + Returns + ------- + tf.Tensor + Chomped tensor with reduced sequence length. + """ + return input_tensor[:, :, :-chomp_size] + + def _temporal_block( + self, + input_tensor, + n_inputs: int, + n_filters: int, + kernel_size: int, + strides: int, + dilation_rate: int, + padding_value: int, + dropout: float = 0.2, + training: bool = None, + ): + """Create a temporal block with dilated causal convolutions. + + Each temporal block consists of: + 1. Two dilated causal convolutions + 2. ReLU activations and dropout for regularization + 3. Residual connection with optional 1x1 convolution for dimension + matching + + Parameters + ---------- + input_tensor : tf.Tensor + Input tensor of shape (batch_size, channels, sequence_length). + n_inputs : int + Number of input channels. + n_filters : int + Number of output filters. + kernel_size : int + Size of convolutional kernels. + strides : int + Stride of convolutions (typically 1). + dilation_rate : int + Dilation factor for dilated convolutions. + padding_value : int + Padding size to be chomped off. + dropout : float, default=0.2 + Dropout rate for regularization. + training : bool, optional + Whether the model is in training mode. + + Returns + ------- + tf.Tensor + Output tensor of shape (batch_size, n_filters, sequence_length). + """ + import tensorflow as tf + + # First convolution block + out = self._conv1d_with_variable_padding( + input_tensor, n_filters, kernel_size, padding_value, strides, dilation_rate + ) + out = self._chomp(out, padding_value) + out = tf.keras.layers.ReLU()(out) + out = tf.keras.layers.Dropout(dropout)(out, training=training) + + # Second convolution block + out = self._conv1d_with_variable_padding( + out, n_filters, kernel_size, padding_value, strides, dilation_rate + ) + out = self._chomp(out, padding_value) + out = tf.keras.layers.ReLU()(out) + out = tf.keras.layers.Dropout(dropout)(out, training=training) + + # Residual connection with optional dimension matching + if n_inputs != n_filters: + res = self._conv1d_with_variable_padding( + input_tensor=input_tensor, + n_filters=n_filters, + kernel_size=1, + padding_value=0, + strides=1, + dilation_rate=1, + ) + else: + res = input_tensor + + # Add residual and apply final ReLU + result = tf.keras.layers.Add()([out, res]) + return tf.keras.layers.ReLU()(result) + + def _temporal_conv_net( + self, + input_tensor, + n_inputs: int, + n_blocks: list, + kernel_size: int = 2, + dropout: float = 0.2, + training: bool = None, + ): + """Apply the complete Temporal Convolutional Network. + + Stacks multiple temporal blocks with exponentially increasing dilation + factors to achieve a large receptive field efficiently. + + Parameters + ---------- + input_tensor : tf.Tensor + Input tensor of shape (batch_size, channels, sequence_length). + n_inputs : int + Number of input channels. + n_blocks : list of int + Number of output channels for each temporal block. + kernel_size : int, default=2 + Size of convolutional kernels. + dropout : float, default=0.2 + Dropout rate for regularization. + training : bool, optional + Whether the model is in training mode. + + Returns + ------- + tf.Tensor + Output tensor after applying all temporal blocks. + """ + num_levels = len(n_blocks) + for i in range(num_levels): + dilation_rate = 2**i + in_channels = n_inputs if i == 0 else n_blocks[i - 1] + out_channels = n_blocks[i] + padding_value = (kernel_size - 1) * dilation_rate + + input_tensor = self._temporal_block( + input_tensor, + n_inputs=in_channels, + n_filters=out_channels, + kernel_size=kernel_size, + strides=1, + dilation_rate=dilation_rate, + padding_value=padding_value, + dropout=dropout, + training=training, + ) + + return input_tensor + + def build_network(self, input_shape: tuple, **kwargs) -> tuple: + """Build the complete TCN architecture. + + Constructs a series of temporal blocks with exponentially increasing + dilation factors to achieve a large receptive field efficiently. + + Parameters + ---------- + input_shape : tuple + Shape of input data (n_channels, n_timepoints). + **kwargs + Additional keyword arguments (unused). + + Returns + ------- + tuple + A tuple containing (input_layer, output_tensor) representing + the complete network architecture. + + Notes + ----- + The dilation factor for layer i is 2^i, which ensures exponential + growth of the receptive field while maintaining computational + efficiency. + """ + import tensorflow as tf + + # Create input layer + input_layer = tf.keras.layers.Input(shape=input_shape) + + # Transpose input to match the expected format (batch, channels, seq) + x = input_layer + n_inputs = input_shape[0] + + # Apply TCN using the private function + x = self._temporal_conv_net( + x, + n_inputs=n_inputs, + n_blocks=self.n_blocks, + kernel_size=self.kernel_size, + dropout=self.dropout, + ) + + x = tf.keras.layers.Dense(input_shape[0])(x[:, -1, :]) + output = tf.keras.layers.Dense(1)(x) + return input_layer, output diff --git a/aeon/networks/tests/test_tcn.py b/aeon/networks/tests/test_tcn.py new file mode 100644 index 0000000000..94495e3c41 --- /dev/null +++ b/aeon/networks/tests/test_tcn.py @@ -0,0 +1,192 @@ +"""Tests for the TCNNetwork.""" + +import pytest + +from aeon.networks import TCNNetwork +from aeon.utils.validation._dependencies import _check_soft_dependencies + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +def test_tcn_network_basic(): + """Test basic TCN network creation and build_network functionality.""" + import tensorflow as tf + + input_shape = (100, 5) + n_blocks = [32, 64] + + tcn_network = TCNNetwork(n_blocks=n_blocks) + input_layer, output_layer = tcn_network.build_network(input_shape) + + # Check that layers are created correctly + assert hasattr(input_layer, "shape"), "Input layer should have a shape attribute" + assert hasattr(output_layer, "shape"), "Output layer should have a shape attribute" + assert input_layer.dtype == tf.float32 + assert output_layer.dtype == tf.float32 + + # Create a model to test the network structure + model = tf.keras.Model(inputs=input_layer, outputs=output_layer) + assert model is not None, "Model should be created successfully" + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("n_blocks", [[32], [32, 64], [16, 32, 64], [64, 32, 16]]) +def test_tcn_network_different_channels(n_blocks): + """Test TCN network with different channel configurations.""" + import tensorflow as tf + + input_shape = (50, 3) + + tcn_network = TCNNetwork(n_blocks=n_blocks) + input_layer, output_layer = tcn_network.build_network(input_shape) + + # Create a model and verify it works + model = tf.keras.Model(inputs=input_layer, outputs=output_layer) + assert model is not None + + # Test with dummy data + import numpy as np + + dummy_input = np.random.random((8,) + input_shape) + output = model(dummy_input) + assert output is not None, "Model should produce output" + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("kernel_size", [2, 3, 5]) +def test_tcn_network_kernel_sizes(kernel_size): + """Test TCN network with different kernel sizes.""" + import tensorflow as tf + + input_shape = (80, 4) + n_blocks = [32, 64] + + tcn_network = TCNNetwork( + n_blocks=n_blocks, + kernel_size=kernel_size, + ) + input_layer, output_layer = tcn_network.build_network(input_shape) + + # Verify network builds successfully + model = tf.keras.Model(inputs=input_layer, outputs=output_layer) + assert model is not None + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("dropout", [0.0, 0.1, 0.3, 0.5]) +def test_tcn_network_dropout_rates(dropout): + """Test TCN network with different dropout rates.""" + import tensorflow as tf + + input_shape = (60, 2) + n_blocks = [16, 32] + + tcn_network = TCNNetwork(n_blocks=n_blocks, dropout=dropout) + input_layer, output_layer = tcn_network.build_network(input_shape) + + # Verify network builds successfully + model = tf.keras.Model(inputs=input_layer, outputs=output_layer) + assert model is not None + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +def test_tcn_network_output_shape(): + """Test TCN network output shapes.""" + import numpy as np + import tensorflow as tf + + input_shape = (40, 6) + batch_size = 16 + n_blocks = [32, 64] + + tcn_network = TCNNetwork(n_blocks=n_blocks) + input_layer, output_layer = tcn_network.build_network(input_shape) + model = tf.keras.Model(inputs=input_layer, outputs=output_layer) + + # Create dummy input and test output shape + dummy_input = np.random.random((batch_size,) + input_shape) + output = model(dummy_input) + + # Output should maintain sequence length and have final channel dimension + expected_shape = (batch_size, 1) + assert ( + output.shape == expected_shape + ), f"Expected shape {expected_shape}, got {output.shape}" + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +def test_tcn_network_config(): + """Test TCN network configuration attributes.""" + tcn_network = TCNNetwork(n_blocks=[16, 32]) + + # Check _config attributes + assert "python_dependencies" in tcn_network._config + assert "tensorflow" in tcn_network._config["python_dependencies"] + assert "python_version" in tcn_network._config + assert "structure" in tcn_network._config + assert tcn_network._config["structure"] == "encoder" + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +def test_tcn_network_parameter_initialization(): + """Test TCN network parameter initialization.""" + n_blocks = [32, 64, 128] + kernel_size = 3 + dropout = 0.2 + + tcn_network = TCNNetwork( + n_blocks=n_blocks, + kernel_size=kernel_size, + dropout=dropout, + ) + + # Check that parameters are set correctly + assert tcn_network.n_blocks == n_blocks + assert tcn_network.kernel_size == kernel_size + assert tcn_network.dropout == dropout + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +def test_tcn_network_single_layer(): + """Test TCN network with single temporal block.""" + import tensorflow as tf + + input_shape = (30, 2) + n_blocks = [16] # Single layer + + tcn_network = TCNNetwork(n_blocks=n_blocks) + input_layer, output_layer = tcn_network.build_network(input_shape) + + # Verify single layer network works + model = tf.keras.Model(inputs=input_layer, outputs=output_layer) + assert model is not None + + # Test with dummy data + import numpy as np + + dummy_input = np.random.random((4,) + input_shape) + output = model(dummy_input) + assert output.shape == (4, 1)