diff --git a/.github/workflows/pr_pytest.yml b/.github/workflows/pr_pytest.yml
index fcdc665778..5240ce85ec 100644
--- a/.github/workflows/pr_pytest.yml
+++ b/.github/workflows/pr_pytest.yml
@@ -3,7 +3,7 @@ name: PR pytest
 on:
   push:
     branches:
-      - main
+      - tcn_net
   pull_request:
     paths:
       - "aeon/**"
diff --git a/aeon/networks/__init__.py b/aeon/networks/__init__.py
index d774abe102..aed37be7e7 100644
--- a/aeon/networks/__init__.py
+++ b/aeon/networks/__init__.py
@@ -19,6 +19,7 @@
     "AEBiGRUNetwork",
     "DisjointCNNNetwork",
     "RecurrentNetwork",
+    "TCNNetwork",
 ]
 from aeon.networks._ae_abgru import AEAttentionBiGRUNetwork
 from aeon.networks._ae_bgru import AEBiGRUNetwork
@@ -36,4 +37,5 @@
 from aeon.networks._mlp import MLPNetwork
 from aeon.networks._resnet import ResNetNetwork
 from aeon.networks._rnn import RecurrentNetwork
+from aeon.networks._tcn import TCNNetwork
 from aeon.networks.base import BaseDeepLearningNetwork
diff --git a/aeon/networks/_tcn.py b/aeon/networks/_tcn.py
new file mode 100644
index 0000000000..834b5865e7
--- /dev/null
+++ b/aeon/networks/_tcn.py
@@ -0,0 +1,341 @@
+"""Implementation of Temporal Convolutional Network (TCN)."""
+
+__maintainer__ = []
+
+from aeon.networks.base import BaseDeepLearningNetwork
+
+
+class TCNNetwork(BaseDeepLearningNetwork):
+    """Temporal Convolutional Network (TCN) for sequence modeling.
+
+    A generic convolutional architecture for sequence modeling that combines:
+
+    - Dilated convolutions for exponentially large receptive fields
+    - Residual connections for training stability
+
+    The TCN can take sequences of any length and map them to output sequences
+    of the same length, making it suitable for autoregressive prediction tasks.
+
+    Parameters
+    ----------
+    n_blocks : list of int
+        List specifying the number of output channels for each layer.
+        The length determines the depth of the network.
+    kernel_size : int, default=2
+        Size of the convolutional kernel. Larger kernels can capture
+        more local context but require more parameters.
+    dropout : float, default=0.2
+        Dropout rate applied after each convolutional layer for regularization.
+
+    Notes
+    -----
+    The receptive field size grows exponentially with network depth due to
+    dilated convolutions with dilation factors of 2^i for layer i.
+
+    References
+    ----------
+    .. [1] Bai, S., Kolter, J. Z., & Koltun, V. (2018). An empirical evaluation
+       of generic convolutional and recurrent networks for sequence modeling.
+       arXiv preprint arXiv:1803.01271.
+
+    Examples
+    --------
+    >>> from aeon.networks._tcn import TCNNetwork
+    >>> from aeon.testing.data_generation import make_example_3d_numpy
+    >>> import tensorflow as tf
+    >>> X, y = make_example_3d_numpy(n_cases=8, n_channels=4, n_timepoints=150,
+    ...                              return_y=True, regression_target=True,
+    ...                              random_state=42)
+    >>> network = TCNNetwork(n_blocks=[8, 8])
+    >>> input_layer, output = network.build_network(input_shape=(4, 150))
+    >>> model = tf.keras.Model(inputs=input_layer, outputs=output)
+    >>> model.compile(optimizer="adam", loss="mse")
+    >>> model.fit(X, y, epochs=2, batch_size=2, verbose=0)  # doctest: +SKIP
+
+    """
+
+    _config = {
+        "python_dependencies": ["tensorflow"],
+        "python_version": "<3.13",
+        "structure": "encoder",
+    }
+
+    def __init__(
+        self,
+        n_blocks: list = [16] * 3,
+        kernel_size: int = 2,
+        dropout: float = 0.2,
+    ):
+        """Initialize the TCN architecture.
+
+        Parameters
+        ----------
+        n_blocks : list of int
+            Number of output channels for each temporal block.
+        kernel_size : int, default=2
+            Size of convolutional kernels.
+        dropout : float, default=0.2
+            Dropout rate for regularization.
+        """
+        super().__init__()
+        self.n_blocks = n_blocks
+        self.kernel_size = kernel_size
+        self.dropout = dropout
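+
+    # A rough receptive-field calculation (an illustration only, not used by
+    # the network): each temporal block applies two convolutions with dilation
+    # 2**i, so a stack of L blocks with kernel size k can see about
+    #
+    #     1 + 2 * (k - 1) * (2**L - 1)
+    #
+    # past time steps. With the defaults (kernel_size=2, three blocks) this
+    # gives 1 + 2 * 1 * 7 = 15 steps.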
+
+    def _conv1d_with_variable_padding(
+        self,
+        input_tensor,
+        n_filters: int,
+        kernel_size: int,
+        padding_value: int,
+        strides: int = 1,
+        dilation_rate: int = 1,
+    ):
+        """Apply 1D convolution with variable padding for causal convolutions.
+
+        Parameters
+        ----------
+        input_tensor : tf.Tensor
+            Input tensor of shape (batch_size, channels, sequence_length).
+        n_filters : int
+            Number of output filters.
+        kernel_size : int
+            Size of the convolutional kernel.
+        padding_value : int
+            Amount of padding to apply.
+        strides : int, default=1
+            Stride of the convolution.
+        dilation_rate : int, default=1
+            Dilation rate for dilated convolutions.
+
+        Returns
+        -------
+        tf.Tensor
+            Output tensor after convolution.
+        """
+        import tensorflow as tf
+
+        # Transpose to Keras format (batch, sequence, channels)
+        x_keras_format = tf.keras.layers.Permute((2, 1))(input_tensor)
+
+        # Apply padding in the sequence dimension
+        padded_x = tf.keras.layers.ZeroPadding1D(padding=padding_value)(x_keras_format)
+
+        # Create and apply the convolution layer
+        conv_layer = tf.keras.layers.Conv1D(
+            filters=n_filters,
+            kernel_size=kernel_size,
+            strides=strides,
+            dilation_rate=dilation_rate,
+            padding="valid",
+        )
+        out = conv_layer(padded_x)
+
+        # Transpose back to channels-first (PyTorch-style) format
+        # (batch, channels, sequence)
+        return tf.keras.layers.Permute((2, 1))(out)
+
+    def _chomp(self, input_tensor, chomp_size: int):
+        """Remove padding from the end of sequences to maintain causality.
+
+        This operation ensures that the output at time t only depends on
+        inputs from times 0 to t, preventing information leakage from the
+        future.
+
+        Parameters
+        ----------
+        input_tensor : tf.Tensor
+            Input tensor of shape (batch_size, channels, sequence_length).
+        chomp_size : int
+            Number of time steps to remove from the end.
+
+        Returns
+        -------
+        tf.Tensor
+            Chomped tensor with reduced sequence length.
+        """
+        return input_tensor[:, :, :-chomp_size]
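+
+    # A worked trace of the padding/chomp pairing (illustration only): with
+    # kernel_size=2 and dilation_rate=4, padding_value = (2 - 1) * 4 = 4, so
+    # ZeroPadding1D pads four zeros on both ends of the sequence. The "valid"
+    # convolution then produces a sequence four steps too long, and _chomp
+    # drops the trailing four steps, so output[t] depends only on
+    # input[t - 4] and input[t] -- never on future positions.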
+ """ + import tensorflow as tf + + # First convolution block + out = self._conv1d_with_variable_padding( + input_tensor, n_filters, kernel_size, padding_value, strides, dilation_rate + ) + out = self._chomp(out, padding_value) + out = tf.keras.layers.ReLU()(out) + out = tf.keras.layers.Dropout(dropout)(out, training=training) + + # Second convolution block + out = self._conv1d_with_variable_padding( + out, n_filters, kernel_size, padding_value, strides, dilation_rate + ) + out = self._chomp(out, padding_value) + out = tf.keras.layers.ReLU()(out) + out = tf.keras.layers.Dropout(dropout)(out, training=training) + + # Residual connection with optional dimension matching + if n_inputs != n_filters: + res = self._conv1d_with_variable_padding( + input_tensor=input_tensor, + n_filters=n_filters, + kernel_size=1, + padding_value=0, + strides=1, + dilation_rate=1, + ) + else: + res = input_tensor + + # Add residual and apply final ReLU + result = tf.keras.layers.Add()([out, res]) + return tf.keras.layers.ReLU()(result) + + def _temporal_conv_net( + self, + input_tensor, + n_inputs: int, + n_blocks: list, + kernel_size: int = 2, + dropout: float = 0.2, + training: bool = None, + ): + """Apply the complete Temporal Convolutional Network. + + Stacks multiple temporal blocks with exponentially increasing dilation + factors to achieve a large receptive field efficiently. + + Parameters + ---------- + input_tensor : tf.Tensor + Input tensor of shape (batch_size, channels, sequence_length). + n_inputs : int + Number of input channels. + n_blocks : list of int + Number of output channels for each temporal block. + kernel_size : int, default=2 + Size of convolutional kernels. + dropout : float, default=0.2 + Dropout rate for regularization. + training : bool, optional + Whether the model is in training mode. + + Returns + ------- + tf.Tensor + Output tensor after applying all temporal blocks. + """ + num_levels = len(n_blocks) + for i in range(num_levels): + dilation_rate = 2**i + in_channels = n_inputs if i == 0 else n_blocks[i - 1] + out_channels = n_blocks[i] + padding_value = (kernel_size - 1) * dilation_rate + + input_tensor = self._temporal_block( + input_tensor, + n_inputs=in_channels, + n_filters=out_channels, + kernel_size=kernel_size, + strides=1, + dilation_rate=dilation_rate, + padding_value=padding_value, + dropout=dropout, + training=training, + ) + + return input_tensor + + def build_network(self, input_shape: tuple, **kwargs) -> tuple: + """Build the complete TCN architecture. + + Constructs a series of temporal blocks with exponentially increasing + dilation factors to achieve a large receptive field efficiently. + + Parameters + ---------- + input_shape : tuple + Shape of input data (n_channels, n_timepoints). + **kwargs + Additional keyword arguments (unused). + + Returns + ------- + tuple + A tuple containing (input_layer, output_tensor) representing + the complete network architecture. + + Notes + ----- + The dilation factor for layer i is 2^i, which ensures exponential + growth of the receptive field while maintaining computational + efficiency. 
+ """ + import tensorflow as tf + + # Create input layer + input_layer = tf.keras.layers.Input(shape=input_shape) + + # Transpose input to match the expected format (batch, channels, seq) + x = input_layer + n_inputs = input_shape[0] + + # Apply TCN using the private function + x = self._temporal_conv_net( + x, + n_inputs=n_inputs, + n_blocks=self.n_blocks, + kernel_size=self.kernel_size, + dropout=self.dropout, + ) + + x = tf.keras.layers.Dense(input_shape[0])(x[:, -1, :]) + output = tf.keras.layers.Dense(1)(x) + return input_layer, output diff --git a/aeon/networks/tests/test_tcn.py b/aeon/networks/tests/test_tcn.py new file mode 100644 index 0000000000..94495e3c41 --- /dev/null +++ b/aeon/networks/tests/test_tcn.py @@ -0,0 +1,192 @@ +"""Tests for the TCNNetwork.""" + +import pytest + +from aeon.networks import TCNNetwork +from aeon.utils.validation._dependencies import _check_soft_dependencies + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +def test_tcn_network_basic(): + """Test basic TCN network creation and build_network functionality.""" + import tensorflow as tf + + input_shape = (100, 5) + n_blocks = [32, 64] + + tcn_network = TCNNetwork(n_blocks=n_blocks) + input_layer, output_layer = tcn_network.build_network(input_shape) + + # Check that layers are created correctly + assert hasattr(input_layer, "shape"), "Input layer should have a shape attribute" + assert hasattr(output_layer, "shape"), "Output layer should have a shape attribute" + assert input_layer.dtype == tf.float32 + assert output_layer.dtype == tf.float32 + + # Create a model to test the network structure + model = tf.keras.Model(inputs=input_layer, outputs=output_layer) + assert model is not None, "Model should be created successfully" + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("n_blocks", [[32], [32, 64], [16, 32, 64], [64, 32, 16]]) +def test_tcn_network_different_channels(n_blocks): + """Test TCN network with different channel configurations.""" + import tensorflow as tf + + input_shape = (50, 3) + + tcn_network = TCNNetwork(n_blocks=n_blocks) + input_layer, output_layer = tcn_network.build_network(input_shape) + + # Create a model and verify it works + model = tf.keras.Model(inputs=input_layer, outputs=output_layer) + assert model is not None + + # Test with dummy data + import numpy as np + + dummy_input = np.random.random((8,) + input_shape) + output = model(dummy_input) + assert output is not None, "Model should produce output" + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("kernel_size", [2, 3, 5]) +def test_tcn_network_kernel_sizes(kernel_size): + """Test TCN network with different kernel sizes.""" + import tensorflow as tf + + input_shape = (80, 4) + n_blocks = [32, 64] + + tcn_network = TCNNetwork( + n_blocks=n_blocks, + kernel_size=kernel_size, + ) + input_layer, output_layer = tcn_network.build_network(input_shape) + + # Verify network builds successfully + model = tf.keras.Model(inputs=input_layer, outputs=output_layer) + assert model is not None + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("dropout", [0.0, 0.1, 0.3, 0.5]) +def 
+
+
+@pytest.mark.skipif(
+    not _check_soft_dependencies(["tensorflow"], severity="none"),
+    reason="Tensorflow soft dependency unavailable.",
+)
+@pytest.mark.parametrize("dropout", [0.0, 0.1, 0.3, 0.5])
+def test_tcn_network_dropout_rates(dropout):
+    """Test TCN network with different dropout rates."""
+    import tensorflow as tf
+
+    input_shape = (60, 2)
+    n_blocks = [16, 32]
+
+    tcn_network = TCNNetwork(n_blocks=n_blocks, dropout=dropout)
+    input_layer, output_layer = tcn_network.build_network(input_shape)
+
+    # Verify the network builds successfully
+    model = tf.keras.Model(inputs=input_layer, outputs=output_layer)
+    assert model is not None
+
+
+@pytest.mark.skipif(
+    not _check_soft_dependencies(["tensorflow"], severity="none"),
+    reason="Tensorflow soft dependency unavailable.",
+)
+def test_tcn_network_output_shape():
+    """Test TCN network output shapes."""
+    import numpy as np
+    import tensorflow as tf
+
+    input_shape = (40, 6)
+    batch_size = 16
+    n_blocks = [32, 64]
+
+    tcn_network = TCNNetwork(n_blocks=n_blocks)
+    input_layer, output_layer = tcn_network.build_network(input_shape)
+    model = tf.keras.Model(inputs=input_layer, outputs=output_layer)
+
+    # Create dummy input and test the output shape
+    dummy_input = np.random.random((batch_size,) + input_shape)
+    output = model(dummy_input)
+
+    # The regression head reduces each case to a single output value
+    expected_shape = (batch_size, 1)
+    assert (
+        output.shape == expected_shape
+    ), f"Expected shape {expected_shape}, got {output.shape}"
+
+
+@pytest.mark.skipif(
+    not _check_soft_dependencies(["tensorflow"], severity="none"),
+    reason="Tensorflow soft dependency unavailable.",
+)
+def test_tcn_network_config():
+    """Test TCN network configuration attributes."""
+    tcn_network = TCNNetwork(n_blocks=[16, 32])
+
+    # Check _config attributes
+    assert "python_dependencies" in tcn_network._config
+    assert "tensorflow" in tcn_network._config["python_dependencies"]
+    assert "python_version" in tcn_network._config
+    assert "structure" in tcn_network._config
+    assert tcn_network._config["structure"] == "encoder"
+
+
+@pytest.mark.skipif(
+    not _check_soft_dependencies(["tensorflow"], severity="none"),
+    reason="Tensorflow soft dependency unavailable.",
+)
+def test_tcn_network_parameter_initialization():
+    """Test TCN network parameter initialization."""
+    n_blocks = [32, 64, 128]
+    kernel_size = 3
+    dropout = 0.2
+
+    tcn_network = TCNNetwork(
+        n_blocks=n_blocks,
+        kernel_size=kernel_size,
+        dropout=dropout,
+    )
+
+    # Check that parameters are set correctly
+    assert tcn_network.n_blocks == n_blocks
+    assert tcn_network.kernel_size == kernel_size
+    assert tcn_network.dropout == dropout
+
+
+@pytest.mark.skipif(
+    not _check_soft_dependencies(["tensorflow"], severity="none"),
+    reason="Tensorflow soft dependency unavailable.",
+)
+def test_tcn_network_single_layer():
+    """Test TCN network with a single temporal block."""
+    import numpy as np
+    import tensorflow as tf
+
+    input_shape = (30, 2)
+    n_blocks = [16]  # Single layer
+
+    tcn_network = TCNNetwork(n_blocks=n_blocks)
+    input_layer, output_layer = tcn_network.build_network(input_shape)
+
+    # Verify the single-layer network works
+    model = tf.keras.Model(inputs=input_layer, outputs=output_layer)
+    assert model is not None
+
+    # Test with dummy data
+    dummy_input = np.random.random((4,) + input_shape)
+    output = model(dummy_input)
+    assert output.shape == (4, 1)
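+
+
+# A minimal end-to-end training sketch (an illustrative addition mirroring the
+# docstring example, not part of the original test suite): a small TCN should
+# fit random regression targets without raising.
+@pytest.mark.skipif(
+    not _check_soft_dependencies(["tensorflow"], severity="none"),
+    reason="Tensorflow soft dependency unavailable.",
+)
+def test_tcn_network_fit_smoke_sketch():
+    """Sketch: end-to-end compile/fit on random data."""
+    import numpy as np
+    import tensorflow as tf
+
+    rng = np.random.default_rng(42)
+    X = rng.random((8, 4, 30)).astype("float32")
+    y = rng.random((8,)).astype("float32")
+
+    network = TCNNetwork(n_blocks=[8, 8])
+    input_layer, output_layer = network.build_network(input_shape=(4, 30))
+    model = tf.keras.Model(inputs=input_layer, outputs=output_layer)
+    model.compile(optimizer="adam", loss="mse")
+    history = model.fit(X, y, epochs=1, batch_size=4, verbose=0)
+    assert "loss" in history.history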