From 22ac415c5c5f537807821b91e702126efd3e1e79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C5=A1per=20Petelin?= Date: Mon, 28 Apr 2025 15:28:05 +0200 Subject: [PATCH 1/9] Added minimum TS2Vec implementation --- .../collection/contrastive_based/__init__.py | 7 + .../collection/contrastive_based/_ts2vec.py | 567 ++++++++++++++++++ .../contrastive_based/tests/__init__.py | 1 + .../contrastive_based/tests/test_ts2vec.py | 29 + 4 files changed, 604 insertions(+) create mode 100644 aeon/transformations/collection/contrastive_based/__init__.py create mode 100644 aeon/transformations/collection/contrastive_based/_ts2vec.py create mode 100644 aeon/transformations/collection/contrastive_based/tests/__init__.py create mode 100644 aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py diff --git a/aeon/transformations/collection/contrastive_based/__init__.py b/aeon/transformations/collection/contrastive_based/__init__.py new file mode 100644 index 0000000000..55e8fe4ccc --- /dev/null +++ b/aeon/transformations/collection/contrastive_based/__init__.py @@ -0,0 +1,7 @@ +"""Contrastive learning transformers.""" + +__all__ = [ + "TS2Vec", +] + +from ._ts2vec import TS2Vec diff --git a/aeon/transformations/collection/contrastive_based/_ts2vec.py b/aeon/transformations/collection/contrastive_based/_ts2vec.py new file mode 100644 index 0000000000..d27c254286 --- /dev/null +++ b/aeon/transformations/collection/contrastive_based/_ts2vec.py @@ -0,0 +1,567 @@ +"""TS2Vec Transformer.""" + +__maintainer__ = ["GasperPetelin"] +__all__ = ["TS2Vec"] + +import numpy as np +from aeon.transformations.collection import BaseCollectionTransformer +from aeon.utils.validation import check_n_jobs +from aeon.utils.validation._dependencies import _check_soft_dependencies + +class TS2Vec(BaseCollectionTransformer): + _tags = { + "capability:multivariate": True, + "output_data_type": "Tabular", + "algorithm_type": "contrastive", + "python_dependencies": "torch", + } + + def __init__(self, output_dim=320, n_jobs=1): + self.output_dim = output_dim + self.n_jobs = n_jobs + super().__init__() + + def _transform(self, X, y=None): + return self._ts2vec.encode(X.transpose(0, 2, 1), encoding_window='full_series') + + def _fit(self, X, y=None): + self._ts2vec = _TS2Vec( + input_dims=X.shape[1], + output_dims=self.output_dim, + #device='cuda', + ) + self._ts2vec.fit(X.transpose(0, 2, 1), verbose=False) + return self + +if _check_soft_dependencies("torch", severity="none"): + import torch + import torch.nn.functional as F + from torch import nn + from torch.utils.data import TensorDataset, DataLoader + + class _TS2Vec(): + def __init__( + self, + input_dims, + output_dims=320, + hidden_dims=64, + depth=10, + device='cuda', + lr=0.001, + batch_size=16, + max_train_length=None, + temporal_unit=0, + after_iter_callback=None, + after_epoch_callback=None + ): + ''' Initialize a TS2Vec model. + + Args: + input_dims (int): The input dimension. For a univariate time series, this should be set to 1. + output_dims (int): The representation dimension. + hidden_dims (int): The hidden dimension of the encoder. + depth (int): The number of hidden residual blocks in the encoder. + device (int): The gpu used for training and inference. + lr (int): The learning rate. + batch_size (int): The batch size. + max_train_length (Union[int, NoneType]): The maximum allowed sequence length for training. For sequence with a length greater than , it would be cropped into some sequences, each of which has a length less than . 
+ temporal_unit (int): The minimum unit to perform temporal contrast. When training on a very long sequence, this param helps to reduce the cost of time and memory. + after_iter_callback (Union[Callable, NoneType]): A callback function that would be called after each iteration. + after_epoch_callback (Union[Callable, NoneType]): A callback function that would be called after each epoch. + ''' + + super().__init__() + self.device = device + self.lr = lr + self.batch_size = batch_size + self.max_train_length = max_train_length + self.temporal_unit = temporal_unit + + self._net = TSEncoder(input_dims=input_dims, output_dims=output_dims, hidden_dims=hidden_dims, depth=depth).to(self.device) + self.net = torch.optim.swa_utils.AveragedModel(self._net) + self.net.update_parameters(self._net) + + self.after_iter_callback = after_iter_callback + self.after_epoch_callback = after_epoch_callback + + self.n_epochs = 0 + self.n_iters = 0 + + @staticmethod + def pad_nan_to_target(array, target_length, axis=0, both_side=False): + assert array.dtype in [np.float16, np.float32, np.float64] + pad_size = target_length - array.shape[axis] + if pad_size <= 0: + return array + npad = [(0, 0)] * array.ndim + if both_side: + npad[axis] = (pad_size // 2, pad_size - pad_size//2) + else: + npad[axis] = (0, pad_size) + return np.pad(array, pad_width=npad, mode='constant', constant_values=np.nan) + + @staticmethod + def split_with_nan(x, sections, axis=0): + assert x.dtype in [np.float16, np.float32, np.float64] + arrs = np.array_split(x, sections, axis=axis) + target_length = arrs[0].shape[axis] + for i in range(len(arrs)): + arrs[i] = _TS2Vec.pad_nan_to_target(arrs[i], target_length, axis=axis) + return arrs + + @staticmethod + def take_per_row(A, indx, num_elem): + all_indx = indx[:,None] + np.arange(num_elem) + return A[torch.arange(all_indx.shape[0])[:,None], all_indx] + + @staticmethod + def centerize_vary_length_series(x): + prefix_zeros = np.argmax(~np.isnan(x).all(axis=-1), axis=1) + suffix_zeros = np.argmax(~np.isnan(x[:, ::-1]).all(axis=-1), axis=1) + offset = (prefix_zeros + suffix_zeros) // 2 - prefix_zeros + rows, column_indices = np.ogrid[:x.shape[0], :x.shape[1]] + offset[offset < 0] += x.shape[1] + column_indices = column_indices - offset[:, np.newaxis] + return x[rows, column_indices] + + @staticmethod + def instance_contrastive_loss(z1, z2): + B, T = z1.size(0), z1.size(1) + if B == 1: + return z1.new_tensor(0.) + z = torch.cat([z1, z2], dim=0) # 2B x T x C + z = z.transpose(0, 1) # T x 2B x C + sim = torch.matmul(z, z.transpose(1, 2)) # T x 2B x 2B + logits = torch.tril(sim, diagonal=-1)[:, :, :-1] # T x 2B x (2B-1) + logits += torch.triu(sim, diagonal=1)[:, :, 1:] + logits = -F.log_softmax(logits, dim=-1) + + i = torch.arange(B, device=z1.device) + loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2 + return loss + + @staticmethod + def temporal_contrastive_loss(z1, z2): + B, T = z1.size(0), z1.size(1) + if T == 1: + return z1.new_tensor(0.) 
+ z = torch.cat([z1, z2], dim=1) # B x 2T x C + sim = torch.matmul(z, z.transpose(1, 2)) # B x 2T x 2T + logits = torch.tril(sim, diagonal=-1)[:, :, :-1] # B x 2T x (2T-1) + logits += torch.triu(sim, diagonal=1)[:, :, 1:] + logits = -F.log_softmax(logits, dim=-1) + + t = torch.arange(T, device=z1.device) + loss = (logits[:, t, T + t - 1].mean() + logits[:, T + t, t].mean()) / 2 + return loss + + @staticmethod + def hierarchical_contrastive_loss(z1, z2, alpha=0.5, temporal_unit=0): + loss = torch.tensor(0., device=z1.device) + d = 0 + while z1.size(1) > 1: + if alpha != 0: + loss += alpha * _TS2Vec.instance_contrastive_loss(z1, z2) + if d >= temporal_unit: + if 1 - alpha != 0: + loss += (1 - alpha) * _TS2Vec.temporal_contrastive_loss(z1, z2) + d += 1 + z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2) + z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2) + if z1.size(1) == 1: + if alpha != 0: + loss += alpha * _TS2Vec.instance_contrastive_loss(z1, z2) + d += 1 + return loss / d + + def fit(self, train_data, n_epochs=None, n_iters=None, verbose=False): + ''' Training the TS2Vec model. + + Args: + train_data (numpy.ndarray): The training data. It should have a shape of (n_instance, n_timestamps, n_features). All missing data should be set to NaN. + n_epochs (Union[int, NoneType]): The number of epochs. When this reaches, the training stops. + n_iters (Union[int, NoneType]): The number of iterations. When this reaches, the training stops. If both n_epochs and n_iters are not specified, a default setting would be used that sets n_iters to 200 for a dataset with size <= 100000, 600 otherwise. + verbose (bool): Whether to print the training loss after each epoch. + + Returns: + loss_log: a list containing the training losses on each epoch. 
+ ''' + assert train_data.ndim == 3 + + if n_iters is None and n_epochs is None: + n_iters = 200 if train_data.size <= 100000 else 600 # default param for n_iters + + if self.max_train_length is not None: + sections = train_data.shape[1] // self.max_train_length + if sections >= 2: + train_data = np.concatenate(self.__class__.split_with_nan(train_data, sections, axis=1), axis=0) + + temporal_missing = np.isnan(train_data).all(axis=-1).any(axis=0) + if temporal_missing[0] or temporal_missing[-1]: + train_data = self.__class__.centerize_vary_length_series(train_data) + + train_data = train_data[~np.isnan(train_data).all(axis=2).all(axis=1)] + + train_dataset = TensorDataset(torch.from_numpy(train_data).to(torch.float)) + train_loader = DataLoader(train_dataset, batch_size=min(self.batch_size, len(train_dataset)), shuffle=True, drop_last=True) + + optimizer = torch.optim.AdamW(self._net.parameters(), lr=self.lr) + + loss_log = [] + + while True: + if n_epochs is not None and self.n_epochs >= n_epochs: + break + + cum_loss = 0 + n_epoch_iters = 0 + + interrupted = False + for batch in train_loader: + if n_iters is not None and self.n_iters >= n_iters: + interrupted = True + break + + x = batch[0] + if self.max_train_length is not None and x.size(1) > self.max_train_length: + window_offset = np.random.randint(x.size(1) - self.max_train_length + 1) + x = x[:, window_offset : window_offset + self.max_train_length] + x = x.to(self.device) + + ts_l = x.size(1) + crop_l = np.random.randint(low=2 ** (self.temporal_unit + 1), high=ts_l+1) + crop_left = np.random.randint(ts_l - crop_l + 1) + crop_right = crop_left + crop_l + crop_eleft = np.random.randint(crop_left + 1) + crop_eright = np.random.randint(low=crop_right, high=ts_l + 1) + crop_offset = np.random.randint(low=-crop_eleft, high=ts_l - crop_eright + 1, size=x.size(0)) + + optimizer.zero_grad() + + out1 = self._net(self.__class__.take_per_row(x, crop_offset + crop_eleft, crop_right - crop_eleft)) + out1 = out1[:, -crop_l:] + + out2 = self._net(self.__class__.take_per_row(x, crop_offset + crop_left, crop_eright - crop_left)) + out2 = out2[:, :crop_l] + + loss = _TS2Vec.hierarchical_contrastive_loss( + out1, + out2, + temporal_unit=self.temporal_unit + ) + + loss.backward() + optimizer.step() + self.net.update_parameters(self._net) + + cum_loss += loss.item() + n_epoch_iters += 1 + + self.n_iters += 1 + + if self.after_iter_callback is not None: + self.after_iter_callback(self, loss.item()) + + if interrupted: + break + + cum_loss /= n_epoch_iters + loss_log.append(cum_loss) + if verbose: + print(f"Epoch #{self.n_epochs}: loss={cum_loss}") + self.n_epochs += 1 + + if self.after_epoch_callback is not None: + self.after_epoch_callback(self, cum_loss) + + return loss_log + + def _eval_with_pooling(self, x, mask=None, slicing=None, encoding_window=None): + out = self.net(x.to(self.device, non_blocking=True), mask) + if encoding_window == 'full_series': + if slicing is not None: + out = out[:, slicing] + out = F.max_pool1d( + out.transpose(1, 2), + kernel_size = out.size(1), + ).transpose(1, 2) + + elif isinstance(encoding_window, int): + out = F.max_pool1d( + out.transpose(1, 2), + kernel_size = encoding_window, + stride = 1, + padding = encoding_window // 2 + ).transpose(1, 2) + if encoding_window % 2 == 0: + out = out[:, :-1] + if slicing is not None: + out = out[:, slicing] + + elif encoding_window == 'multiscale': + p = 0 + reprs = [] + while (1 << p) + 1 < out.size(1): + t_out = F.max_pool1d( + out.transpose(1, 2), + kernel_size = (1 << (p + 
1)) + 1, + stride = 1, + padding = 1 << p + ).transpose(1, 2) + if slicing is not None: + t_out = t_out[:, slicing] + reprs.append(t_out) + p += 1 + out = torch.cat(reprs, dim=-1) + + else: + if slicing is not None: + out = out[:, slicing] + + return out.cpu() + + def torch_pad_nan(arr, left=0, right=0, dim=0): + if left > 0: + padshape = list(arr.shape) + padshape[dim] = left + arr = torch.cat((torch.full(padshape, np.nan), arr), dim=dim) + if right > 0: + padshape = list(arr.shape) + padshape[dim] = right + arr = torch.cat((arr, torch.full(padshape, np.nan)), dim=dim) + return arr + + def encode(self, data, mask=None, encoding_window=None, causal=False, sliding_length=None, sliding_padding=0, batch_size=None): + ''' Compute representations using the model. + + Args: + data (numpy.ndarray): This should have a shape of (n_instance, n_timestamps, n_features). All missing data should be set to NaN. + mask (str): The mask used by encoder can be specified with this parameter. This can be set to 'binomial', 'continuous', 'all_true', 'all_false' or 'mask_last'. + encoding_window (Union[str, int]): When this param is specified, the computed representation would the max pooling over this window. This can be set to 'full_series', 'multiscale' or an integer specifying the pooling kernel size. + causal (bool): When this param is set to True, the future informations would not be encoded into representation of each timestamp. + sliding_length (Union[int, NoneType]): The length of sliding window. When this param is specified, a sliding inference would be applied on the time series. + sliding_padding (int): This param specifies the contextual data length used for inference every sliding windows. + batch_size (Union[int, NoneType]): The batch size used for inference. If not specified, this would be the same batch size as training. + + Returns: + repr: The representations for data. 
+ ''' + assert self.net is not None, 'please train or load a net first' + assert data.ndim == 3 + if batch_size is None: + batch_size = self.batch_size + n_samples, ts_l, _ = data.shape + + org_training = self.net.training + self.net.eval() + + dataset = TensorDataset(torch.from_numpy(data).to(torch.float)) + loader = DataLoader(dataset, batch_size=batch_size) + + with torch.no_grad(): + output = [] + for batch in loader: + x = batch[0] + if sliding_length is not None: + reprs = [] + if n_samples < batch_size: + calc_buffer = [] + calc_buffer_l = 0 + for i in range(0, ts_l, sliding_length): + l = i - sliding_padding + r = i + sliding_length + (sliding_padding if not causal else 0) + x_sliding = torch_pad_nan( + x[:, max(l, 0) : min(r, ts_l)], + left=-l if l<0 else 0, + right=r-ts_l if r>ts_l else 0, + dim=1 + ) + if n_samples < batch_size: + if calc_buffer_l + n_samples > batch_size: + out = self._eval_with_pooling( + torch.cat(calc_buffer, dim=0), + mask, + slicing=slice(sliding_padding, sliding_padding+sliding_length), + encoding_window=encoding_window + ) + reprs += torch.split(out, n_samples) + calc_buffer = [] + calc_buffer_l = 0 + calc_buffer.append(x_sliding) + calc_buffer_l += n_samples + else: + out = self._eval_with_pooling( + x_sliding, + mask, + slicing=slice(sliding_padding, sliding_padding+sliding_length), + encoding_window=encoding_window + ) + reprs.append(out) + + if n_samples < batch_size: + if calc_buffer_l > 0: + out = self._eval_with_pooling( + torch.cat(calc_buffer, dim=0), + mask, + slicing=slice(sliding_padding, sliding_padding+sliding_length), + encoding_window=encoding_window + ) + reprs += torch.split(out, n_samples) + calc_buffer = [] + calc_buffer_l = 0 + + out = torch.cat(reprs, dim=1) + if encoding_window == 'full_series': + out = F.max_pool1d( + out.transpose(1, 2).contiguous(), + kernel_size = out.size(1), + ).squeeze(1) + else: + out = self._eval_with_pooling(x, mask, encoding_window=encoding_window) + if encoding_window == 'full_series': + out = out.squeeze(1) + + output.append(out) + + output = torch.cat(output, dim=0) + + self.net.train(org_training) + return output.numpy() + + def save(self, fn): + ''' Save the model to a file. + + Args: + fn (str): filename. + ''' + torch.save(self.net.state_dict(), fn) + + def load(self, fn): + ''' Load the model from a file. + + Args: + fn (str): filename. 
+ ''' + state_dict = torch.load(fn, map_location=self.device) + self.net.load_state_dict(state_dict) + + class SamePadConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1): + super().__init__() + self.receptive_field = (kernel_size - 1) * dilation + 1 + padding = self.receptive_field // 2 + self.conv = nn.Conv1d( + in_channels, out_channels, kernel_size, + padding=padding, + dilation=dilation, + groups=groups + ) + self.remove = 1 if self.receptive_field % 2 == 0 else 0 + + def forward(self, x): + out = self.conv(x) + if self.remove > 0: + out = out[:, :, : -self.remove] + return out + + class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False): + super().__init__() + self.conv1 = SamePadConv(in_channels, out_channels, kernel_size, dilation=dilation) + self.conv2 = SamePadConv(out_channels, out_channels, kernel_size, dilation=dilation) + self.projector = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels or final else None + + def forward(self, x): + residual = x if self.projector is None else self.projector(x) + x = F.gelu(x) + x = self.conv1(x) + x = F.gelu(x) + x = self.conv2(x) + return x + residual + + class DilatedConvEncoder(nn.Module): + def __init__(self, in_channels, channels, kernel_size): + super().__init__() + self.net = nn.Sequential(*[ + ConvBlock( + channels[i-1] if i > 0 else in_channels, + channels[i], + kernel_size=kernel_size, + dilation=2**i, + final=(i == len(channels)-1) + ) + for i in range(len(channels)) + ]) + + def forward(self, x): + return self.net(x) + + class TSEncoder(nn.Module): + def __init__(self, input_dims, output_dims, hidden_dims=64, depth=10, mask_mode='binomial'): + super().__init__() + self.input_dims = input_dims + self.output_dims = output_dims + self.hidden_dims = hidden_dims + self.mask_mode = mask_mode + self.input_fc = nn.Linear(input_dims, hidden_dims) + self.feature_extractor = DilatedConvEncoder( + hidden_dims, + [hidden_dims] * depth + [output_dims], + kernel_size=3 + ) + self.repr_dropout = nn.Dropout(p=0.1) + + @staticmethod + def generate_binomial_mask(B, T, p=0.5): + return torch.from_numpy(np.random.binomial(1, p, size=(B, T))).to(torch.bool) + + @staticmethod + def generate_continuous_mask(B, T, n=5, l=0.1): + res = torch.full((B, T), True, dtype=torch.bool) + if isinstance(n, float): + n = int(n * T) + n = max(min(n, T // 2), 1) + + if isinstance(l, float): + l = int(l * T) + l = max(l, 1) + + for i in range(B): + for _ in range(n): + t = np.random.randint(T-l+1) + res[i, t:t+l] = False + return res + + def forward(self, x, mask=None): # x: B x T x input_dims + nan_mask = ~x.isnan().any(axis=-1) + x[~nan_mask] = 0 + x = self.input_fc(x) # B x T x Ch + + # generate & apply mask + if mask is None: + if self.training: + mask = self.mask_mode + else: + mask = 'all_true' + + if mask == 'binomial': + mask = self.__class__.generate_binomial_mask(x.size(0), x.size(1)).to(x.device) + elif mask == 'continuous': + mask = self.__class__.generate_continuous_mask(x.size(0), x.size(1)).to(x.device) + elif mask == 'all_true': + mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool) + elif mask == 'all_false': + mask = x.new_full((x.size(0), x.size(1)), False, dtype=torch.bool) + elif mask == 'mask_last': + mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool) + mask[:, -1] = False + + mask &= nan_mask + x[~mask] = 0 + + # conv encoder + x = x.transpose(1, 2) # B x Ch x T + x = 
self.repr_dropout(self.feature_extractor(x)) # B x Co x T + x = x.transpose(1, 2) # B x T x Co + + return x \ No newline at end of file diff --git a/aeon/transformations/collection/contrastive_based/tests/__init__.py b/aeon/transformations/collection/contrastive_based/tests/__init__.py new file mode 100644 index 0000000000..6c50744f89 --- /dev/null +++ b/aeon/transformations/collection/contrastive_based/tests/__init__.py @@ -0,0 +1 @@ +"""Contrastive learning unit tests.""" diff --git a/aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py b/aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py new file mode 100644 index 0000000000..44098cfa03 --- /dev/null +++ b/aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py @@ -0,0 +1,29 @@ +import numpy as np + +from aeon.transformations.collection.contrastive_based._ts2vec import TS2Vec + + + +def test_shape(): + expected_features = 200 + X = np.random.random(size=(10, 1, 100)) + transformer = TS2Vec(output_dim=expected_features) + transformer.fit(X) + X_trans = transformer.transform(X) + np.testing.assert_equal(X_trans.shape, (len(X), expected_features)) + +def test_shape2(): + expected_features = 500 + X = np.random.random(size=(10, 1, 100)) + transformer = TS2Vec(output_dim=expected_features) + transformer.fit(X) + X_trans = transformer.transform(X) + np.testing.assert_equal(X_trans.shape, (len(X), expected_features)) + +def test_shape3(): + expected_features = 200 + X = np.random.random(size=(10, 3, 100)) + transformer = TS2Vec(output_dim=expected_features) + transformer.fit(X) + X_trans = transformer.transform(X) + np.testing.assert_equal(X_trans.shape, (len(X), expected_features)) \ No newline at end of file From 2b27a7f040b639f33106d1161ae9178e44821780 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C5=A1per=20Petelin?= Date: Mon, 28 Apr 2025 19:47:08 +0200 Subject: [PATCH 2/9] Fixed TS2Vec tests --- .../_yield_regression_checks.py | 1 + .../collection/contrastive_based/_ts2vec.py | 455 +++++++++--------- aeon/utils/tags/_tags.py | 1 + 3 files changed, 237 insertions(+), 220 deletions(-) diff --git a/aeon/testing/estimator_checking/_yield_regression_checks.py b/aeon/testing/estimator_checking/_yield_regression_checks.py index 06fe479654..531721a459 100644 --- a/aeon/testing/estimator_checking/_yield_regression_checks.py +++ b/aeon/testing/estimator_checking/_yield_regression_checks.py @@ -169,6 +169,7 @@ def check_regressor_overrides_and_tags(estimator_class): "feature", "hybrid", "shapelet", + "contrastive", ] algorithm_type = estimator_class.get_class_tag("algorithm_type") if algorithm_type is not None: diff --git a/aeon/transformations/collection/contrastive_based/_ts2vec.py b/aeon/transformations/collection/contrastive_based/_ts2vec.py index d27c254286..a8d048085b 100644 --- a/aeon/transformations/collection/contrastive_based/_ts2vec.py +++ b/aeon/transformations/collection/contrastive_based/_ts2vec.py @@ -12,25 +12,40 @@ class TS2Vec(BaseCollectionTransformer): _tags = { "capability:multivariate": True, "output_data_type": "Tabular", + "capability:multithreading": True, "algorithm_type": "contrastive", "python_dependencies": "torch", + "non_deterministic": True, } - def __init__(self, output_dim=320, n_jobs=1): + def __init__(self, output_dim=320, device=None, n_jobs=1, verbose=False): self.output_dim = output_dim self.n_jobs = n_jobs + self.device = device + self.verbose = verbose super().__init__() def _transform(self, X, y=None): return self._ts2vec.encode(X.transpose(0, 2, 1), 
encoding_window='full_series') def _fit(self, X, y=None): + import torch + + n_jobs = check_n_jobs(self.n_jobs) + torch.set_num_threads(n_jobs) + + selected_device = None + if self.device is None: + selected_device = "cuda" if torch.cuda.is_available() else "cpu" + else: + selected_device = self.device + self._ts2vec = _TS2Vec( input_dims=X.shape[1], output_dims=self.output_dim, - #device='cuda', + device=selected_device, ) - self._ts2vec.fit(X.transpose(0, 2, 1), verbose=False) + self.loss_ = self._ts2vec.fit(X.transpose(0, 2, 1), verbose=self.verbose) return self if _check_soft_dependencies("torch", severity="none"): @@ -39,6 +54,216 @@ def _fit(self, X, y=None): from torch import nn from torch.utils.data import TensorDataset, DataLoader + class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False): + super().__init__() + self.conv1 = SamePadConv(in_channels, out_channels, kernel_size, dilation=dilation) + self.conv2 = SamePadConv(out_channels, out_channels, kernel_size, dilation=dilation) + self.projector = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels or final else None + + def forward(self, x): + residual = x if self.projector is None else self.projector(x) + x = F.gelu(x) + x = self.conv1(x) + x = F.gelu(x) + x = self.conv2(x) + return x + residual + + class DilatedConvEncoder(nn.Module): + def __init__(self, in_channels, channels, kernel_size): + super().__init__() + self.net = nn.Sequential(*[ + ConvBlock( + channels[i-1] if i > 0 else in_channels, + channels[i], + kernel_size=kernel_size, + dilation=2**i, + final=(i == len(channels)-1) + ) + for i in range(len(channels)) + ]) + + def forward(self, x): + return self.net(x) + + class SamePadConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1): + super().__init__() + self.receptive_field = (kernel_size - 1) * dilation + 1 + padding = self.receptive_field // 2 + self.conv = nn.Conv1d( + in_channels, out_channels, kernel_size, + padding=padding, + dilation=dilation, + groups=groups + ) + self.remove = 1 if self.receptive_field % 2 == 0 else 0 + + def forward(self, x): + out = self.conv(x) + if self.remove > 0: + out = out[:, :, : -self.remove] + return out + + class TSEncoder(nn.Module): + def __init__(self, input_dims, output_dims, hidden_dims=64, depth=10, mask_mode='binomial'): + super().__init__() + self.input_dims = input_dims + self.output_dims = output_dims + self.hidden_dims = hidden_dims + self.mask_mode = mask_mode + self.input_fc = nn.Linear(input_dims, hidden_dims) + self.feature_extractor = DilatedConvEncoder( + hidden_dims, + [hidden_dims] * depth + [output_dims], + kernel_size=3 + ) + self.repr_dropout = nn.Dropout(p=0.1) + + def forward(self, x, mask=None): # x: B x T x input_dims + nan_mask = ~x.isnan().any(axis=-1) + x[~nan_mask] = 0 + x = self.input_fc(x) # B x T x Ch + + # generate & apply mask + if mask is None: + if self.training: + mask = self.mask_mode + else: + mask = 'all_true' + + if mask == 'binomial': + mask = generate_binomial_mask(x.size(0), x.size(1)).to(x.device) + elif mask == 'continuous': + mask = generate_continuous_mask(x.size(0), x.size(1)).to(x.device) + elif mask == 'all_true': + mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool) + elif mask == 'all_false': + mask = x.new_full((x.size(0), x.size(1)), False, dtype=torch.bool) + elif mask == 'mask_last': + mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool) + mask[:, -1] = False + + mask &= 
nan_mask + x[~mask] = 0 + + # conv encoder + x = x.transpose(1, 2) # B x Ch x T + x = self.repr_dropout(self.feature_extractor(x)) # B x Co x T + x = x.transpose(1, 2) # B x T x Co + + return x + + def generate_binomial_mask(B, T, p=0.5): + return torch.from_numpy(np.random.binomial(1, p, size=(B, T))).to(torch.bool) + + def generate_continuous_mask(B, T, n=5, l=0.1): + res = torch.full((B, T), True, dtype=torch.bool) + if isinstance(n, float): + n = int(n * T) + n = max(min(n, T // 2), 1) + + if isinstance(l, float): + l = int(l * T) + l = max(l, 1) + + for i in range(B): + for _ in range(n): + t = np.random.randint(T-l+1) + res[i, t:t+l] = False + return res + + def pad_nan_to_target(array, target_length, axis=0, both_side=False): + assert array.dtype in [np.float16, np.float32, np.float64] + pad_size = target_length - array.shape[axis] + if pad_size <= 0: + return array + npad = [(0, 0)] * array.ndim + if both_side: + npad[axis] = (pad_size // 2, pad_size - pad_size//2) + else: + npad[axis] = (0, pad_size) + return np.pad(array, pad_width=npad, mode='constant', constant_values=np.nan) + + def take_per_row(A, indx, num_elem): + all_indx = indx[:,None] + np.arange(num_elem) + return A[torch.arange(all_indx.shape[0])[:,None], all_indx] + + def torch_pad_nan(arr, left=0, right=0, dim=0): + if left > 0: + padshape = list(arr.shape) + padshape[dim] = left + arr = torch.cat((torch.full(padshape, np.nan), arr), dim=dim) + if right > 0: + padshape = list(arr.shape) + padshape[dim] = right + arr = torch.cat((arr, torch.full(padshape, np.nan)), dim=dim) + return arr + + def instance_contrastive_loss(z1, z2): + B, T = z1.size(0), z1.size(1) + if B == 1: + return z1.new_tensor(0.) + z = torch.cat([z1, z2], dim=0) # 2B x T x C + z = z.transpose(0, 1) # T x 2B x C + sim = torch.matmul(z, z.transpose(1, 2)) # T x 2B x 2B + logits = torch.tril(sim, diagonal=-1)[:, :, :-1] # T x 2B x (2B-1) + logits += torch.triu(sim, diagonal=1)[:, :, 1:] + logits = -F.log_softmax(logits, dim=-1) + + i = torch.arange(B, device=z1.device) + loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2 + return loss + + def temporal_contrastive_loss(z1, z2): + B, T = z1.size(0), z1.size(1) + if T == 1: + return z1.new_tensor(0.) 
+ z = torch.cat([z1, z2], dim=1) # B x 2T x C + sim = torch.matmul(z, z.transpose(1, 2)) # B x 2T x 2T + logits = torch.tril(sim, diagonal=-1)[:, :, :-1] # B x 2T x (2T-1) + logits += torch.triu(sim, diagonal=1)[:, :, 1:] + logits = -F.log_softmax(logits, dim=-1) + + t = torch.arange(T, device=z1.device) + loss = (logits[:, t, T + t - 1].mean() + logits[:, T + t, t].mean()) / 2 + return loss + + def hierarchical_contrastive_loss(z1, z2, alpha=0.5, temporal_unit=0): + loss = torch.tensor(0., device=z1.device) + d = 0 + while z1.size(1) > 1: + if alpha != 0: + loss += alpha * instance_contrastive_loss(z1, z2) + if d >= temporal_unit: + if 1 - alpha != 0: + loss += (1 - alpha) * temporal_contrastive_loss(z1, z2) + d += 1 + z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2) + z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2) + if z1.size(1) == 1: + if alpha != 0: + loss += alpha * instance_contrastive_loss(z1, z2) + d += 1 + return loss / d + + def split_with_nan(x, sections, axis=0): + assert x.dtype in [np.float16, np.float32, np.float64] + arrs = np.array_split(x, sections, axis=axis) + target_length = arrs[0].shape[axis] + for i in range(len(arrs)): + arrs[i] = pad_nan_to_target(arrs[i], target_length, axis=axis) + return arrs + + def centerize_vary_length_series(x): + prefix_zeros = np.argmax(~np.isnan(x).all(axis=-1), axis=1) + suffix_zeros = np.argmax(~np.isnan(x[:, ::-1]).all(axis=-1), axis=1) + offset = (prefix_zeros + suffix_zeros) // 2 - prefix_zeros + rows, column_indices = np.ogrid[:x.shape[0], :x.shape[1]] + offset[offset < 0] += x.shape[1] + column_indices = column_indices - offset[:, np.newaxis] + return x[rows, column_indices] + class _TS2Vec(): def __init__( self, @@ -87,92 +312,7 @@ def __init__( self.n_epochs = 0 self.n_iters = 0 - @staticmethod - def pad_nan_to_target(array, target_length, axis=0, both_side=False): - assert array.dtype in [np.float16, np.float32, np.float64] - pad_size = target_length - array.shape[axis] - if pad_size <= 0: - return array - npad = [(0, 0)] * array.ndim - if both_side: - npad[axis] = (pad_size // 2, pad_size - pad_size//2) - else: - npad[axis] = (0, pad_size) - return np.pad(array, pad_width=npad, mode='constant', constant_values=np.nan) - - @staticmethod - def split_with_nan(x, sections, axis=0): - assert x.dtype in [np.float16, np.float32, np.float64] - arrs = np.array_split(x, sections, axis=axis) - target_length = arrs[0].shape[axis] - for i in range(len(arrs)): - arrs[i] = _TS2Vec.pad_nan_to_target(arrs[i], target_length, axis=axis) - return arrs - - @staticmethod - def take_per_row(A, indx, num_elem): - all_indx = indx[:,None] + np.arange(num_elem) - return A[torch.arange(all_indx.shape[0])[:,None], all_indx] - - @staticmethod - def centerize_vary_length_series(x): - prefix_zeros = np.argmax(~np.isnan(x).all(axis=-1), axis=1) - suffix_zeros = np.argmax(~np.isnan(x[:, ::-1]).all(axis=-1), axis=1) - offset = (prefix_zeros + suffix_zeros) // 2 - prefix_zeros - rows, column_indices = np.ogrid[:x.shape[0], :x.shape[1]] - offset[offset < 0] += x.shape[1] - column_indices = column_indices - offset[:, np.newaxis] - return x[rows, column_indices] - - @staticmethod - def instance_contrastive_loss(z1, z2): - B, T = z1.size(0), z1.size(1) - if B == 1: - return z1.new_tensor(0.) 
- z = torch.cat([z1, z2], dim=0) # 2B x T x C - z = z.transpose(0, 1) # T x 2B x C - sim = torch.matmul(z, z.transpose(1, 2)) # T x 2B x 2B - logits = torch.tril(sim, diagonal=-1)[:, :, :-1] # T x 2B x (2B-1) - logits += torch.triu(sim, diagonal=1)[:, :, 1:] - logits = -F.log_softmax(logits, dim=-1) - - i = torch.arange(B, device=z1.device) - loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2 - return loss - - @staticmethod - def temporal_contrastive_loss(z1, z2): - B, T = z1.size(0), z1.size(1) - if T == 1: - return z1.new_tensor(0.) - z = torch.cat([z1, z2], dim=1) # B x 2T x C - sim = torch.matmul(z, z.transpose(1, 2)) # B x 2T x 2T - logits = torch.tril(sim, diagonal=-1)[:, :, :-1] # B x 2T x (2T-1) - logits += torch.triu(sim, diagonal=1)[:, :, 1:] - logits = -F.log_softmax(logits, dim=-1) - - t = torch.arange(T, device=z1.device) - loss = (logits[:, t, T + t - 1].mean() + logits[:, T + t, t].mean()) / 2 - return loss - - @staticmethod - def hierarchical_contrastive_loss(z1, z2, alpha=0.5, temporal_unit=0): - loss = torch.tensor(0., device=z1.device) - d = 0 - while z1.size(1) > 1: - if alpha != 0: - loss += alpha * _TS2Vec.instance_contrastive_loss(z1, z2) - if d >= temporal_unit: - if 1 - alpha != 0: - loss += (1 - alpha) * _TS2Vec.temporal_contrastive_loss(z1, z2) - d += 1 - z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2) - z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2) - if z1.size(1) == 1: - if alpha != 0: - loss += alpha * _TS2Vec.instance_contrastive_loss(z1, z2) - d += 1 - return loss / d + def fit(self, train_data, n_epochs=None, n_iters=None, verbose=False): ''' Training the TS2Vec model. @@ -194,11 +334,11 @@ def fit(self, train_data, n_epochs=None, n_iters=None, verbose=False): if self.max_train_length is not None: sections = train_data.shape[1] // self.max_train_length if sections >= 2: - train_data = np.concatenate(self.__class__.split_with_nan(train_data, sections, axis=1), axis=0) + train_data = np.concatenate(split_with_nan(train_data, sections, axis=1), axis=0) temporal_missing = np.isnan(train_data).all(axis=-1).any(axis=0) if temporal_missing[0] or temporal_missing[-1]: - train_data = self.__class__.centerize_vary_length_series(train_data) + train_data = centerize_vary_length_series(train_data) train_data = train_data[~np.isnan(train_data).all(axis=2).all(axis=1)] @@ -238,13 +378,13 @@ def fit(self, train_data, n_epochs=None, n_iters=None, verbose=False): optimizer.zero_grad() - out1 = self._net(self.__class__.take_per_row(x, crop_offset + crop_eleft, crop_right - crop_eleft)) + out1 = self._net(take_per_row(x, crop_offset + crop_eleft, crop_right - crop_eleft)) out1 = out1[:, -crop_l:] - out2 = self._net(self.__class__.take_per_row(x, crop_offset + crop_left, crop_eright - crop_left)) + out2 = self._net(take_per_row(x, crop_offset + crop_left, crop_eright - crop_left)) out2 = out2[:, :crop_l] - loss = _TS2Vec.hierarchical_contrastive_loss( + loss = hierarchical_contrastive_loss( out1, out2, temporal_unit=self.temporal_unit @@ -320,16 +460,7 @@ def _eval_with_pooling(self, x, mask=None, slicing=None, encoding_window=None): return out.cpu() - def torch_pad_nan(arr, left=0, right=0, dim=0): - if left > 0: - padshape = list(arr.shape) - padshape[dim] = left - arr = torch.cat((torch.full(padshape, np.nan), arr), dim=dim) - if right > 0: - padshape = list(arr.shape) - padshape[dim] = right - arr = torch.cat((arr, torch.full(padshape, np.nan)), dim=dim) - return arr + def encode(self, data, mask=None, 
encoding_window=None, causal=False, sliding_length=None, sliding_padding=0, batch_size=None): ''' Compute representations using the model. @@ -445,123 +576,7 @@ def load(self, fn): state_dict = torch.load(fn, map_location=self.device) self.net.load_state_dict(state_dict) - class SamePadConv(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1): - super().__init__() - self.receptive_field = (kernel_size - 1) * dilation + 1 - padding = self.receptive_field // 2 - self.conv = nn.Conv1d( - in_channels, out_channels, kernel_size, - padding=padding, - dilation=dilation, - groups=groups - ) - self.remove = 1 if self.receptive_field % 2 == 0 else 0 - - def forward(self, x): - out = self.conv(x) - if self.remove > 0: - out = out[:, :, : -self.remove] - return out - - class ConvBlock(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False): - super().__init__() - self.conv1 = SamePadConv(in_channels, out_channels, kernel_size, dilation=dilation) - self.conv2 = SamePadConv(out_channels, out_channels, kernel_size, dilation=dilation) - self.projector = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels or final else None - - def forward(self, x): - residual = x if self.projector is None else self.projector(x) - x = F.gelu(x) - x = self.conv1(x) - x = F.gelu(x) - x = self.conv2(x) - return x + residual - - class DilatedConvEncoder(nn.Module): - def __init__(self, in_channels, channels, kernel_size): - super().__init__() - self.net = nn.Sequential(*[ - ConvBlock( - channels[i-1] if i > 0 else in_channels, - channels[i], - kernel_size=kernel_size, - dilation=2**i, - final=(i == len(channels)-1) - ) - for i in range(len(channels)) - ]) - - def forward(self, x): - return self.net(x) - class TSEncoder(nn.Module): - def __init__(self, input_dims, output_dims, hidden_dims=64, depth=10, mask_mode='binomial'): - super().__init__() - self.input_dims = input_dims - self.output_dims = output_dims - self.hidden_dims = hidden_dims - self.mask_mode = mask_mode - self.input_fc = nn.Linear(input_dims, hidden_dims) - self.feature_extractor = DilatedConvEncoder( - hidden_dims, - [hidden_dims] * depth + [output_dims], - kernel_size=3 - ) - self.repr_dropout = nn.Dropout(p=0.1) - @staticmethod - def generate_binomial_mask(B, T, p=0.5): - return torch.from_numpy(np.random.binomial(1, p, size=(B, T))).to(torch.bool) - @staticmethod - def generate_continuous_mask(B, T, n=5, l=0.1): - res = torch.full((B, T), True, dtype=torch.bool) - if isinstance(n, float): - n = int(n * T) - n = max(min(n, T // 2), 1) - - if isinstance(l, float): - l = int(l * T) - l = max(l, 1) - - for i in range(B): - for _ in range(n): - t = np.random.randint(T-l+1) - res[i, t:t+l] = False - return res - def forward(self, x, mask=None): # x: B x T x input_dims - nan_mask = ~x.isnan().any(axis=-1) - x[~nan_mask] = 0 - x = self.input_fc(x) # B x T x Ch - - # generate & apply mask - if mask is None: - if self.training: - mask = self.mask_mode - else: - mask = 'all_true' - - if mask == 'binomial': - mask = self.__class__.generate_binomial_mask(x.size(0), x.size(1)).to(x.device) - elif mask == 'continuous': - mask = self.__class__.generate_continuous_mask(x.size(0), x.size(1)).to(x.device) - elif mask == 'all_true': - mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool) - elif mask == 'all_false': - mask = x.new_full((x.size(0), x.size(1)), False, dtype=torch.bool) - elif mask == 'mask_last': - mask = x.new_full((x.size(0), x.size(1)), True, 
dtype=torch.bool) - mask[:, -1] = False - - mask &= nan_mask - x[~mask] = 0 - - # conv encoder - x = x.transpose(1, 2) # B x Ch x T - x = self.repr_dropout(self.feature_extractor(x)) # B x Co x T - x = x.transpose(1, 2) # B x T x Co - - return x \ No newline at end of file diff --git a/aeon/utils/tags/_tags.py b/aeon/utils/tags/_tags.py index a4f9e04152..31462707a1 100644 --- a/aeon/utils/tags/_tags.py +++ b/aeon/utils/tags/_tags.py @@ -75,6 +75,7 @@ class : identifier for the base class of objects this tag applies to "convolution", "shapelet", "deeplearning", + "contrastive", ], ), None, From a78b11caaa7aaf6dca8c52b176038f9d717e0b97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C5=A1per=20Petelin?= Date: Tue, 29 Apr 2025 12:51:07 +0200 Subject: [PATCH 3/9] Added TS2Vec doc --- .../collection/contrastive_based/_ts2vec.py | 543 +++++++++++------- 1 file changed, 325 insertions(+), 218 deletions(-) diff --git a/aeon/transformations/collection/contrastive_based/_ts2vec.py b/aeon/transformations/collection/contrastive_based/_ts2vec.py index a8d048085b..dc5a17b8e0 100644 --- a/aeon/transformations/collection/contrastive_based/_ts2vec.py +++ b/aeon/transformations/collection/contrastive_based/_ts2vec.py @@ -4,30 +4,118 @@ __all__ = ["TS2Vec"] import numpy as np + from aeon.transformations.collection import BaseCollectionTransformer from aeon.utils.validation import check_n_jobs from aeon.utils.validation._dependencies import _check_soft_dependencies + class TS2Vec(BaseCollectionTransformer): + """TS2Vec Transformer. + + Parameters + ---------- + output_dim : int, default=320 + The dimension of the output representation. + hidden_dim : int, default=64 + The dimension of the hidden layer in the encoder. + depth : int, default=10 + The number of hidden residual blocks in the encoder. + lr : float, default=0.001 + The learning rate for the optimizer. + batch_size : int, default=16 + The batch size for training. + max_train_length : None or int, default=None + The maximum allowed sequence length for training. + For sequences longer than this, they will be cropped into smaller sequences. + temporal_unit : int, default=0 + The minimum unit for temporal contrast. This helps reduce the cost of time + and memory when training on long sequences. + n_epochs : None or int, default=None + The number of epochs. When this reaches, the training stops. + n_iters : None or int, default=None + The number of iterations. When this reaches, the training stops. + If both n_epochs and n_iters are not specified, a default setting will be used + that sets n_iters to 200 for datasets with size <= 100000, and 600 otherwise. + verbose : bool, default=False + Whether to print the training loss after each epoch. + device : None or str, default=None + The device to use for training and inference. If None, it will automatically + select 'cuda' if available, otherwise 'cpu'. + + Notes + ----- + Inspired by the original implementation + https://github.com/zhihanyue/ts2vec + Copyright (c) 2022 Zhihan Yue + + References + ---------- + .. [1] Yue, Z., Wang, Y., Duan, J., Yang, T., Huang, C., Tong, Y. and Xu, B., + 2022, June. Ts2vec: Towards universal representation of time series. + In Proceedings of the AAAI conference on artificial intelligence + (Vol. 36, No. 8, pp. 8980-8987). 
+ + Examples + -------- + >>> from aeon.transformations.collection.convolution_based import HydraTransformer + >>> from aeon.testing.data_generation import make_example_3d_numpy + >>> X, _ = make_example_3d_numpy(n_cases=10, n_channels=1, n_timepoints=12, + ... random_state=0) + >>> clf = TS2Vec() # doctest: +SKIP + >>> clf.fit(X) # doctest: +SKIP + TS2Vec() + >>> clf.transform(X)[0] # doctest: +SKIP + tensor([0.0375 -0.003 -0.0953, ..., 0.0375, -0.0035, -0.0953]) + """ + _tags = { + "python_dependencies": "torch", "capability:multivariate": True, + "non_deterministic": True, "output_data_type": "Tabular", "capability:multithreading": True, "algorithm_type": "contrastive", - "python_dependencies": "torch", - "non_deterministic": True, + "cant_pickle": True, } - def __init__(self, output_dim=320, device=None, n_jobs=1, verbose=False): + def __init__( + self, + output_dim=320, + hidden_dim=64, + depth=10, + lr=0.001, + batch_size=16, + max_train_length=None, + temporal_unit=0, + n_epochs=None, + n_iters=None, + device=None, + n_jobs=1, + verbose=False, + ): self.output_dim = output_dim - self.n_jobs = n_jobs + self.hidden_dim = hidden_dim self.device = device + self.depth = depth + self.lr = lr + self.batch_size = batch_size + self.max_train_length = max_train_length + self.temporal_unit = temporal_unit + self.n_jobs = n_jobs + self.verbose = verbose + self.n_epochs = n_epochs + self.n_iters = n_iters super().__init__() def _transform(self, X, y=None): - return self._ts2vec.encode(X.transpose(0, 2, 1), encoding_window='full_series') - + return self._ts2vec.encode( + X.transpose(0, 2, 1), + encoding_window="full_series", + batch_size=self.batch_size, + ) + def _fit(self, X, y=None): import torch @@ -42,25 +130,47 @@ def _fit(self, X, y=None): self._ts2vec = _TS2Vec( input_dims=X.shape[1], - output_dims=self.output_dim, device=selected_device, + output_dims=self.output_dim, + hidden_dims=self.hidden_dim, + depth=self.depth, + lr=self.lr, + batch_size=self.batch_size, + max_train_length=self.max_train_length, + temporal_unit=self.temporal_unit, + ) + self.loss_ = self._ts2vec.fit( + X.transpose(0, 2, 1), + verbose=self.verbose, + n_epochs=self.n_epochs, + n_iters=self.n_iters, ) - self.loss_ = self._ts2vec.fit(X.transpose(0, 2, 1), verbose=self.verbose) return self - + + if _check_soft_dependencies("torch", severity="none"): import torch import torch.nn.functional as F from torch import nn - from torch.utils.data import TensorDataset, DataLoader + from torch.utils.data import DataLoader, TensorDataset class ConvBlock(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, dilation, final=False): + def __init__( + self, in_channels, out_channels, kernel_size, dilation, final=False + ): super().__init__() - self.conv1 = SamePadConv(in_channels, out_channels, kernel_size, dilation=dilation) - self.conv2 = SamePadConv(out_channels, out_channels, kernel_size, dilation=dilation) - self.projector = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels or final else None - + self.conv1 = SamePadConv( + in_channels, out_channels, kernel_size, dilation=dilation + ) + self.conv2 = SamePadConv( + out_channels, out_channels, kernel_size, dilation=dilation + ) + self.projector = ( + nn.Conv1d(in_channels, out_channels, 1) + if in_channels != out_channels or final + else None + ) + def forward(self, x): residual = x if self.projector is None else self.projector(x) x = F.gelu(x) @@ -72,33 +182,39 @@ def forward(self, x): class DilatedConvEncoder(nn.Module): def __init__(self, 
in_channels, channels, kernel_size): super().__init__() - self.net = nn.Sequential(*[ - ConvBlock( - channels[i-1] if i > 0 else in_channels, - channels[i], - kernel_size=kernel_size, - dilation=2**i, - final=(i == len(channels)-1) - ) - for i in range(len(channels)) - ]) - + self.net = nn.Sequential( + *[ + ConvBlock( + channels[i - 1] if i > 0 else in_channels, + channels[i], + kernel_size=kernel_size, + dilation=2**i, + final=(i == len(channels) - 1), + ) + for i in range(len(channels)) + ] + ) + def forward(self, x): return self.net(x) - + class SamePadConv(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1): + def __init__( + self, in_channels, out_channels, kernel_size, dilation=1, groups=1 + ): super().__init__() self.receptive_field = (kernel_size - 1) * dilation + 1 padding = self.receptive_field // 2 self.conv = nn.Conv1d( - in_channels, out_channels, kernel_size, + in_channels, + out_channels, + kernel_size, padding=padding, dilation=dilation, - groups=groups + groups=groups, ) self.remove = 1 if self.receptive_field % 2 == 0 else 0 - + def forward(self, x): out = self.conv(x) if self.remove > 0: @@ -106,7 +222,14 @@ def forward(self, x): return out class TSEncoder(nn.Module): - def __init__(self, input_dims, output_dims, hidden_dims=64, depth=10, mask_mode='binomial'): + def __init__( + self, + input_dims, + output_dims, + hidden_dims=64, + depth=10, + mask_mode="binomial", + ): super().__init__() self.input_dims = input_dims self.output_dims = output_dims @@ -114,9 +237,7 @@ def __init__(self, input_dims, output_dims, hidden_dims=64, depth=10, mask_mode= self.mask_mode = mask_mode self.input_fc = nn.Linear(input_dims, hidden_dims) self.feature_extractor = DilatedConvEncoder( - hidden_dims, - [hidden_dims] * depth + [output_dims], - kernel_size=3 + hidden_dims, [hidden_dims] * depth + [output_dims], kernel_size=3 ) self.repr_dropout = nn.Dropout(p=0.1) @@ -124,53 +245,53 @@ def forward(self, x, mask=None): # x: B x T x input_dims nan_mask = ~x.isnan().any(axis=-1) x[~nan_mask] = 0 x = self.input_fc(x) # B x T x Ch - + # generate & apply mask if mask is None: if self.training: mask = self.mask_mode else: - mask = 'all_true' - - if mask == 'binomial': + mask = "all_true" + + if mask == "binomial": mask = generate_binomial_mask(x.size(0), x.size(1)).to(x.device) - elif mask == 'continuous': + elif mask == "continuous": mask = generate_continuous_mask(x.size(0), x.size(1)).to(x.device) - elif mask == 'all_true': + elif mask == "all_true": mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool) - elif mask == 'all_false': + elif mask == "all_false": mask = x.new_full((x.size(0), x.size(1)), False, dtype=torch.bool) - elif mask == 'mask_last': + elif mask == "mask_last": mask = x.new_full((x.size(0), x.size(1)), True, dtype=torch.bool) mask[:, -1] = False - + mask &= nan_mask x[~mask] = 0 - + # conv encoder x = x.transpose(1, 2) # B x Ch x T x = self.repr_dropout(self.feature_extractor(x)) # B x Co x T x = x.transpose(1, 2) # B x T x Co - + return x def generate_binomial_mask(B, T, p=0.5): return torch.from_numpy(np.random.binomial(1, p, size=(B, T))).to(torch.bool) - def generate_continuous_mask(B, T, n=5, l=0.1): + def generate_continuous_mask(B, T, n=5, mask_length=0.1): res = torch.full((B, T), True, dtype=torch.bool) if isinstance(n, float): n = int(n * T) n = max(min(n, T // 2), 1) - - if isinstance(l, float): - l = int(l * T) - l = max(l, 1) - + + if isinstance(mask_length, float): + mask_length = int(mask_length * T) + 
mask_length = max(mask_length, 1) + for i in range(B): for _ in range(n): - t = np.random.randint(T-l+1) - res[i, t:t+l] = False + t = np.random.randint(T - mask_length + 1) + res[i, t : t + mask_length] = False return res def pad_nan_to_target(array, target_length, axis=0, both_side=False): @@ -180,14 +301,14 @@ def pad_nan_to_target(array, target_length, axis=0, both_side=False): return array npad = [(0, 0)] * array.ndim if both_side: - npad[axis] = (pad_size // 2, pad_size - pad_size//2) + npad[axis] = (pad_size // 2, pad_size - pad_size // 2) else: npad[axis] = (0, pad_size) - return np.pad(array, pad_width=npad, mode='constant', constant_values=np.nan) + return np.pad(array, pad_width=npad, mode="constant", constant_values=np.nan) def take_per_row(A, indx, num_elem): - all_indx = indx[:,None] + np.arange(num_elem) - return A[torch.arange(all_indx.shape[0])[:,None], all_indx] + all_indx = indx[:, None] + np.arange(num_elem) + return A[torch.arange(all_indx.shape[0])[:, None], all_indx] def torch_pad_nan(arr, left=0, right=0, dim=0): if left > 0: @@ -201,36 +322,36 @@ def torch_pad_nan(arr, left=0, right=0, dim=0): return arr def instance_contrastive_loss(z1, z2): - B, T = z1.size(0), z1.size(1) + B, _ = z1.size(0), z1.size(1) if B == 1: - return z1.new_tensor(0.) + return z1.new_tensor(0.0) z = torch.cat([z1, z2], dim=0) # 2B x T x C z = z.transpose(0, 1) # T x 2B x C sim = torch.matmul(z, z.transpose(1, 2)) # T x 2B x 2B - logits = torch.tril(sim, diagonal=-1)[:, :, :-1] # T x 2B x (2B-1) + logits = torch.tril(sim, diagonal=-1)[:, :, :-1] # T x 2B x (2B-1) logits += torch.triu(sim, diagonal=1)[:, :, 1:] logits = -F.log_softmax(logits, dim=-1) - + i = torch.arange(B, device=z1.device) loss = (logits[:, i, B + i - 1].mean() + logits[:, B + i, i].mean()) / 2 return loss def temporal_contrastive_loss(z1, z2): - B, T = z1.size(0), z1.size(1) + _, T = z1.size(0), z1.size(1) if T == 1: - return z1.new_tensor(0.) 
+        return z1.new_tensor(0.0)
     z = torch.cat([z1, z2], dim=1)  # B x 2T x C
     sim = torch.matmul(z, z.transpose(1, 2))  # B x 2T x 2T
-    logits = torch.tril(sim, diagonal=-1)[:, :, :-1] # B x 2T x (2T-1)
+    logits = torch.tril(sim, diagonal=-1)[:, :, :-1]  # B x 2T x (2T-1)
     logits += torch.triu(sim, diagonal=1)[:, :, 1:]
     logits = -F.log_softmax(logits, dim=-1)
-
+
     t = torch.arange(T, device=z1.device)
     loss = (logits[:, t, T + t - 1].mean() + logits[:, T + t, t].mean()) / 2
     return loss
 
 
 def hierarchical_contrastive_loss(z1, z2, alpha=0.5, temporal_unit=0):
-    loss = torch.tensor(0., device=z1.device)
+    loss = torch.tensor(0.0, device=z1.device)
     d = 0
     while z1.size(1) > 1:
         if alpha != 0:
@@ -254,230 +375,223 @@ def split_with_nan(x, sections, axis=0):
     for i in range(len(arrs)):
         arrs[i] = pad_nan_to_target(arrs[i], target_length, axis=axis)
     return arrs
-
+
 
 def centerize_vary_length_series(x):
     prefix_zeros = np.argmax(~np.isnan(x).all(axis=-1), axis=1)
     suffix_zeros = np.argmax(~np.isnan(x[:, ::-1]).all(axis=-1), axis=1)
     offset = (prefix_zeros + suffix_zeros) // 2 - prefix_zeros
-    rows, column_indices = np.ogrid[:x.shape[0], :x.shape[1]]
+    rows, column_indices = np.ogrid[: x.shape[0], : x.shape[1]]
     offset[offset < 0] += x.shape[1]
     column_indices = column_indices - offset[:, np.newaxis]
     return x[rows, column_indices]
 
 
-class _TS2Vec():
+class _TS2Vec:
     def __init__(
         self,
         input_dims,
         output_dims=320,
         hidden_dims=64,
         depth=10,
-        device='cuda',
+        device="cuda",
         lr=0.001,
         batch_size=16,
         max_train_length=None,
         temporal_unit=0,
         after_iter_callback=None,
-        after_epoch_callback=None
+        after_epoch_callback=None,
     ):
-        ''' Initialize a TS2Vec model.
-
-        Args:
-            input_dims (int): The input dimension. For a univariate time series, this should be set to 1.
-            output_dims (int): The representation dimension.
-            hidden_dims (int): The hidden dimension of the encoder.
-            depth (int): The number of hidden residual blocks in the encoder.
-            device (int): The gpu used for training and inference.
-            lr (int): The learning rate.
-            batch_size (int): The batch size.
-            max_train_length (Union[int, NoneType]): The maximum allowed sequence length for training. For sequence with a length greater than , it would be cropped into some sequences, each of which has a length less than .
-            temporal_unit (int): The minimum unit to perform temporal contrast. When training on a very long sequence, this param helps to reduce the cost of time and memory.
-            after_iter_callback (Union[Callable, NoneType]): A callback function that would be called after each iteration.
-            after_epoch_callback (Union[Callable, NoneType]): A callback function that would be called after each epoch.
-        '''
-
         super().__init__()
         self.device = device
         self.lr = lr
         self.batch_size = batch_size
         self.max_train_length = max_train_length
         self.temporal_unit = temporal_unit
-
-        self._net = TSEncoder(input_dims=input_dims, output_dims=output_dims, hidden_dims=hidden_dims, depth=depth).to(self.device)
+
+        self._net = TSEncoder(
+            input_dims=input_dims,
+            output_dims=output_dims,
+            hidden_dims=hidden_dims,
+            depth=depth,
+        ).to(self.device)
         self.net = torch.optim.swa_utils.AveragedModel(self._net)
         self.net.update_parameters(self._net)
-
+
         self.after_iter_callback = after_iter_callback
         self.after_epoch_callback = after_epoch_callback
-
+
         self.n_epochs = 0
         self.n_iters = 0
-
-
 
     def fit(self, train_data, n_epochs=None, n_iters=None, verbose=False):
-        ''' Training the TS2Vec model.
-
-        Args:
-            train_data (numpy.ndarray): The training data. It should have a shape of (n_instance, n_timestamps, n_features). All missing data should be set to NaN.
-            n_epochs (Union[int, NoneType]): The number of epochs. When this reaches, the training stops.
-            n_iters (Union[int, NoneType]): The number of iterations. When this reaches, the training stops. If both n_epochs and n_iters are not specified, a default setting would be used that sets n_iters to 200 for a dataset with size <= 100000, 600 otherwise.
-            verbose (bool): Whether to print the training loss after each epoch.
-
-        Returns:
-            loss_log: a list containing the training losses on each epoch.
-        '''
         assert train_data.ndim == 3
-
+
         if n_iters is None and n_epochs is None:
-            n_iters = 200 if train_data.size <= 100000 else 600  # default param for n_iters
-
+            n_iters = (
+                200 if train_data.size <= 100000 else 600
+            )  # default param for n_iters
+
         if self.max_train_length is not None:
             sections = train_data.shape[1] // self.max_train_length
             if sections >= 2:
-                train_data = np.concatenate(split_with_nan(train_data, sections, axis=1), axis=0)
+                train_data = np.concatenate(
+                    split_with_nan(train_data, sections, axis=1), axis=0
+                )
 
         temporal_missing = np.isnan(train_data).all(axis=-1).any(axis=0)
         if temporal_missing[0] or temporal_missing[-1]:
             train_data = centerize_vary_length_series(train_data)
-
+
         train_data = train_data[~np.isnan(train_data).all(axis=2).all(axis=1)]
-
+
         train_dataset = TensorDataset(torch.from_numpy(train_data).to(torch.float))
-        train_loader = DataLoader(train_dataset, batch_size=min(self.batch_size, len(train_dataset)), shuffle=True, drop_last=True)
-
+        train_loader = DataLoader(
+            train_dataset,
+            batch_size=min(self.batch_size, len(train_dataset)),
+            shuffle=True,
+            drop_last=True,
+        )
+
         optimizer = torch.optim.AdamW(self._net.parameters(), lr=self.lr)
-
+
         loss_log = []
-
+
         while True:
             if n_epochs is not None and self.n_epochs >= n_epochs:
                 break
-
+
             cum_loss = 0
             n_epoch_iters = 0
-
+
             interrupted = False
             for batch in train_loader:
                 if n_iters is not None and self.n_iters >= n_iters:
                     interrupted = True
                     break
-
+
                 x = batch[0]
-                if self.max_train_length is not None and x.size(1) > self.max_train_length:
-                    window_offset = np.random.randint(x.size(1) - self.max_train_length + 1)
+                if (
+                    self.max_train_length is not None
+                    and x.size(1) > self.max_train_length
+                ):
+                    window_offset = np.random.randint(
+                        x.size(1) - self.max_train_length + 1
+                    )
                     x = x[:, window_offset : window_offset + self.max_train_length]
                 x = x.to(self.device)
-
+
                 ts_l = x.size(1)
-                crop_l = np.random.randint(low=2 ** (self.temporal_unit + 1), high=ts_l+1)
+                crop_l = np.random.randint(
+                    low=2 ** (self.temporal_unit + 1), high=ts_l + 1
+                )
                 crop_left = np.random.randint(ts_l - crop_l + 1)
                 crop_right = crop_left + crop_l
                 crop_eleft = np.random.randint(crop_left + 1)
                 crop_eright = np.random.randint(low=crop_right, high=ts_l + 1)
-                crop_offset = np.random.randint(low=-crop_eleft, high=ts_l - crop_eright + 1, size=x.size(0))
-
+                crop_offset = np.random.randint(
+                    low=-crop_eleft, high=ts_l - crop_eright + 1, size=x.size(0)
+                )
+
                 optimizer.zero_grad()
-
-                out1 = self._net(take_per_row(x, crop_offset + crop_eleft, crop_right - crop_eleft))
+
+                out1 = self._net(
+                    take_per_row(
+                        x, crop_offset + crop_eleft, crop_right - crop_eleft
+                    )
+                )
                 out1 = out1[:, -crop_l:]
-
-                out2 = self._net(take_per_row(x, crop_offset + crop_left, crop_eright - crop_left))
+
+                out2 = self._net(
+                    take_per_row(
+                        x, crop_offset + crop_left, crop_eright - crop_left
+                    )
+                )
                 out2 = out2[:, :crop_l]
-
+
                 loss = hierarchical_contrastive_loss(
-                    out1,
-                    out2,
-                    temporal_unit=self.temporal_unit
+                    out1, out2, temporal_unit=self.temporal_unit
                 )
-
+
                 loss.backward()
                 optimizer.step()
                 self.net.update_parameters(self._net)
-
+
                 cum_loss += loss.item()
                 n_epoch_iters += 1
-
+
                 self.n_iters += 1
-
+
                 if self.after_iter_callback is not None:
                     self.after_iter_callback(self, loss.item())
-
+
             if interrupted:
                 break
-
+
             cum_loss /= n_epoch_iters
             loss_log.append(cum_loss)
             if verbose:
-                print(f"Epoch #{self.n_epochs}: loss={cum_loss}")
+                print(f"Epoch #{self.n_epochs}: loss={cum_loss}")  # noqa
             self.n_epochs += 1
-
+
             if self.after_epoch_callback is not None:
                 self.after_epoch_callback(self, cum_loss)
-
+
         return loss_log
-
+
     def _eval_with_pooling(self, x, mask=None, slicing=None, encoding_window=None):
         out = self.net(x.to(self.device, non_blocking=True), mask)
-        if encoding_window == 'full_series':
+        if encoding_window == "full_series":
             if slicing is not None:
                 out = out[:, slicing]
             out = F.max_pool1d(
                 out.transpose(1, 2),
-                kernel_size = out.size(1),
+                kernel_size=out.size(1),
             ).transpose(1, 2)
-
+
         elif isinstance(encoding_window, int):
             out = F.max_pool1d(
                 out.transpose(1, 2),
-                kernel_size = encoding_window,
-                stride = 1,
-                padding = encoding_window // 2
+                kernel_size=encoding_window,
+                stride=1,
+                padding=encoding_window // 2,
             ).transpose(1, 2)
             if encoding_window % 2 == 0:
                 out = out[:, :-1]
             if slicing is not None:
                 out = out[:, slicing]
-
-        elif encoding_window == 'multiscale':
+
+        elif encoding_window == "multiscale":
             p = 0
             reprs = []
             while (1 << p) + 1 < out.size(1):
                 t_out = F.max_pool1d(
                     out.transpose(1, 2),
-                    kernel_size = (1 << (p + 1)) + 1,
-                    stride = 1,
-                    padding = 1 << p
+                    kernel_size=(1 << (p + 1)) + 1,
+                    stride=1,
+                    padding=1 << p,
                 ).transpose(1, 2)
                 if slicing is not None:
                     t_out = t_out[:, slicing]
                 reprs.append(t_out)
                 p += 1
             out = torch.cat(reprs, dim=-1)
-
+
         else:
             if slicing is not None:
                 out = out[:, slicing]
-
+
         return out.cpu()
-
-
-
-    def encode(self, data, mask=None, encoding_window=None, causal=False, sliding_length=None, sliding_padding=0, batch_size=None):
-        ''' Compute representations using the model.
-
-        Args:
-            data (numpy.ndarray): This should have a shape of (n_instance, n_timestamps, n_features). All missing data should be set to NaN.
-            mask (str): The mask used by encoder can be specified with this parameter. This can be set to 'binomial', 'continuous', 'all_true', 'all_false' or 'mask_last'.
-            encoding_window (Union[str, int]): When this param is specified, the computed representation would the max pooling over this window. This can be set to 'full_series', 'multiscale' or an integer specifying the pooling kernel size.
-            causal (bool): When this param is set to True, the future informations would not be encoded into representation of each timestamp.
-            sliding_length (Union[int, NoneType]): The length of sliding window. When this param is specified, a sliding inference would be applied on the time series.
-            sliding_padding (int): This param specifies the contextual data length used for inference every sliding windows.
-            batch_size (Union[int, NoneType]): The batch size used for inference. If not specified, this would be the same batch size as training.
-
-        Returns:
-            repr: The representations for data.
-        '''
-        assert self.net is not None, 'please train or load a net first'
+
+    def encode(
+        self,
+        data,
+        mask=None,
+        encoding_window=None,
+        causal=False,
+        sliding_length=None,
+        sliding_padding=0,
+        batch_size=None,
+    ):
+        assert self.net is not None, "please train or load a net first"
         assert data.ndim == 3
         if batch_size is None:
             batch_size = self.batch_size
@@ -485,10 +599,10 @@ def encode(self, data, mask=None, encoding_window=None, causal=False, sliding_le
 
         org_training = self.net.training
         self.net.eval()
-
+
         dataset = TensorDataset(torch.from_numpy(data).to(torch.float))
         loader = DataLoader(dataset, batch_size=batch_size)
-
+
         with torch.no_grad():
             output = []
             for batch in loader:
@@ -499,21 +613,28 @@ def encode(self, data, mask=None, encoding_window=None, causal=False, sliding_le
                     calc_buffer = []
                     calc_buffer_l = 0
                     for i in range(0, ts_l, sliding_length):
-                        l = i - sliding_padding
-                        r = i + sliding_length + (sliding_padding if not causal else 0)
+                        left = i - sliding_padding
+                        right = (
+                            i
+                            + sliding_length
+                            + (sliding_padding if not causal else 0)
+                        )
                         x_sliding = torch_pad_nan(
-                            x[:, max(l, 0) : min(r, ts_l)],
-                            left=-l if l<0 else 0,
-                            right=r-ts_l if r>ts_l else 0,
-                            dim=1
+                            x[:, max(left, 0) : min(right, ts_l)],
+                            left=-left if left < 0 else 0,
+                            right=right - ts_l if right > ts_l else 0,
+                            dim=1,
                         )
                         if n_samples < batch_size:
                             if calc_buffer_l + n_samples > batch_size:
                                 out = self._eval_with_pooling(
                                     torch.cat(calc_buffer, dim=0),
                                     mask,
-                                    slicing=slice(sliding_padding, sliding_padding+sliding_length),
-                                    encoding_window=encoding_window
+                                    slicing=slice(
+                                        sliding_padding,
+                                        sliding_padding + sliding_length,
+                                    ),
+                                    encoding_window=encoding_window,
                                 )
                                 reprs += torch.split(out, n_samples)
                                 calc_buffer = []
@@ -524,8 +645,11 @@ def encode(self, data, mask=None, encoding_window=None, causal=False, sliding_le
                             out = self._eval_with_pooling(
                                 x_sliding,
                                 mask,
-                                slicing=slice(sliding_padding, sliding_padding+sliding_length),
-                                encoding_window=encoding_window
+                                slicing=slice(
+                                    sliding_padding,
+                                    sliding_padding + sliding_length,
+                                ),
+                                encoding_window=encoding_window,
                             )
                             reprs.append(out)
 
@@ -534,49 +658,32 @@ def encode(self, data, mask=None, encoding_window=None, causal=False, sliding_le
                         out = self._eval_with_pooling(
                             torch.cat(calc_buffer, dim=0),
                             mask,
-                            slicing=slice(sliding_padding, sliding_padding+sliding_length),
-                            encoding_window=encoding_window
+                            slicing=slice(
+                                sliding_padding,
+                                sliding_padding + sliding_length,
+                            ),
+                            encoding_window=encoding_window,
                         )
                         reprs += torch.split(out, n_samples)
                         calc_buffer = []
                         calc_buffer_l = 0
-
+
                     out = torch.cat(reprs, dim=1)
-                    if encoding_window == 'full_series':
+                    if encoding_window == "full_series":
                         out = F.max_pool1d(
                             out.transpose(1, 2).contiguous(),
-                            kernel_size = out.size(1),
+                            kernel_size=out.size(1),
                         ).squeeze(1)
                 else:
-                    out = self._eval_with_pooling(x, mask, encoding_window=encoding_window)
-                    if encoding_window == 'full_series':
+                    out = self._eval_with_pooling(
+                        x, mask, encoding_window=encoding_window
+                    )
+                    if encoding_window == "full_series":
                         out = out.squeeze(1)
-
+
                 output.append(out)
-
+
             output = torch.cat(output, dim=0)
-
+
         self.net.train(org_training)
         return output.numpy()
-
-    def save(self, fn):
-        ''' Save the model to a file.
-
-        Args:
-            fn (str): filename.
-        '''
-        torch.save(self.net.state_dict(), fn)
-
-    def load(self, fn):
-        ''' Load the model from a file.
-
-        Args:
-            fn (str): filename.
-        '''
-        state_dict = torch.load(fn, map_location=self.device)
-        self.net.load_state_dict(state_dict)
-
-
-
-
-
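Note for reviewers: the hunk above only shows the opening lines of `hierarchical_contrastive_loss` before the diff context cuts away. For reference, the full loop in the upstream TS2Vec implementation (github.com/yuezhihan/ts2vec) reads approximately as the sketch below; the exact body is not visible in this patch, so treat it as illustrative rather than authoritative:

    def hierarchical_contrastive_loss(z1, z2, alpha=0.5, temporal_unit=0):
        # Accumulate instance- and temporal-contrastive terms over a pyramid
        # of time scales, halving the time axis by max pooling until length 1.
        loss = torch.tensor(0.0, device=z1.device)
        d = 0
        while z1.size(1) > 1:
            if alpha != 0:
                loss += alpha * instance_contrastive_loss(z1, z2)
            if d >= temporal_unit and 1 - alpha != 0:
                # the temporal term is skipped below the configured temporal_unit
                loss += (1 - alpha) * temporal_contrastive_loss(z1, z2)
            d += 1
            z1 = F.max_pool1d(z1.transpose(1, 2), kernel_size=2).transpose(1, 2)
            z2 = F.max_pool1d(z2.transpose(1, 2), kernel_size=2).transpose(1, 2)
        if z1.size(1) == 1:
            if alpha != 0:
                loss += alpha * instance_contrastive_loss(z1, z2)
            d += 1
        return loss / d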
From ab692bb729a3312b1d08742edbad4a7cc8ebb8d8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ga=C5=A1per=20Petelin?=
Date: Tue, 29 Apr 2025 15:38:16 +0200
Subject: [PATCH 4/9] Fixed Ts2Vec tests

---
 .../contrastive_based/tests/test_ts2vec.py   | 39 +++++++------------
 1 file changed, 15 insertions(+), 24 deletions(-)

diff --git a/aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py b/aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py
index 44098cfa03..45e9c7d91a 100644
--- a/aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py
+++ b/aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py
@@ -1,29 +1,20 @@
+"""TS2vec tests."""
+
 import numpy as np
+import pytest
 
 from aeon.transformations.collection.contrastive_based._ts2vec import TS2Vec
 
-
-def test_shape():
-    expected_features = 200
-    X = np.random.random(size=(10, 1, 100))
-    transformer = TS2Vec(output_dim=expected_features)
-    transformer.fit(X)
-    X_trans = transformer.transform(X)
-    np.testing.assert_equal(X_trans.shape, (len(X), expected_features))
-
-def test_shape2():
-    expected_features = 500
-    X = np.random.random(size=(10, 1, 100))
-    transformer = TS2Vec(output_dim=expected_features)
-    transformer.fit(X)
-    X_trans = transformer.transform(X)
-    np.testing.assert_equal(X_trans.shape, (len(X), expected_features))
-
-def test_shape3():
-    expected_features = 200
-    X = np.random.random(size=(10, 3, 100))
-    transformer = TS2Vec(output_dim=expected_features)
-    transformer.fit(X)
-    X_trans = transformer.transform(X)
-    np.testing.assert_equal(X_trans.shape, (len(X), expected_features))
\ No newline at end of file
+
+@pytest.mark.parametrize("expected_feature_size", [3, 5, 10])
+@pytest.mark.parametrize("n_series", [1, 2, 5])
+@pytest.mark.parametrize("n_channels", [1, 2, 3])
+@pytest.mark.parametrize("series_length", [3, 10, 20])
+def test_ts2vec_output_shapes(
+    expected_feature_size, n_series, n_channels, series_length
+):
+    """Test the output shapes of the TS2Vec transformer."""
+    X = np.random.random(size=(n_series, n_channels, series_length))
+    transformer = TS2Vec(output_dim=expected_feature_size, device="cpu", n_epochs=2)
+    X_t = transformer.fit_transform(X)
+    assert X_t.shape == (n_series, expected_feature_size)
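The parametrised test above pins down the public contract: fit on a collection of shape (n_cases, n_channels, n_timepoints), transform to a tabular (n_cases, output_dim) array. A minimal standalone sketch of the same flow, with sizes chosen arbitrarily:

    import numpy as np

    from aeon.transformations.collection.contrastive_based import TS2Vec

    X = np.random.random(size=(10, 3, 100))  # 10 series, 3 channels, 100 timepoints

    # a short CPU-only run; output_dim controls the embedding width
    transformer = TS2Vec(output_dim=64, device="cpu", n_epochs=2)
    X_t = transformer.fit_transform(X)
    print(X_t.shape)  # (10, 64)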
From bd176afe991037cfce04209d158796a1e8dbfd99 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ga=C5=A1per=20Petelin?=
Date: Mon, 19 May 2025 14:46:57 +0200
Subject: [PATCH 5/9] Updated TS2Vec docs

---
 .../collection/contrastive_based/_ts2vec.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/aeon/transformations/collection/contrastive_based/_ts2vec.py b/aeon/transformations/collection/contrastive_based/_ts2vec.py
index dc5a17b8e0..b2dc6d60d6 100644
--- a/aeon/transformations/collection/contrastive_based/_ts2vec.py
+++ b/aeon/transformations/collection/contrastive_based/_ts2vec.py
@@ -13,6 +13,13 @@ class TS2Vec(BaseCollectionTransformer):
     """TS2Vec Transformer.
 
+    TS2Vec [1]_ is a self-supervised model designed to learn universal representations
+    of time series data. It employs a hierarchical contrastive learning framework
+    that captures both local and global temporal dependencies. This approach
+    enables TS2Vec to generate robust representations for each timestamp
+    and allows for flexible aggregation to obtain representations for arbitrary
+    subsequences.
+
     Parameters
     ----------
     output_dim : int, default=320
@@ -39,6 +46,8 @@ class TS2Vec(BaseCollectionTransformer):
         that sets n_iters to 200 for datasets with size <= 100000, and 600 otherwise.
     verbose : bool, default=False
         Whether to print the training loss after each epoch.
+    after_epoch_callback : callable, default=None
+        A callback function to be called after each epoch.
     device : None or str, default=None
         The device to use for training and inference. If None, it will
         automatically select 'cuda' if available, otherwise 'cpu'.
@@ -92,6 +101,7 @@ def __init__(
         n_iters=None,
         device=None,
         n_jobs=1,
+        after_epoch_callback=None,
         verbose=False,
     ):
         self.output_dim = output_dim
@@ -107,6 +117,7 @@
         self.verbose = verbose
         self.n_epochs = n_epochs
         self.n_iters = n_iters
+        self.after_epoch_callback = after_epoch_callback
         super().__init__()
 
     def _transform(self, X, y=None):
@@ -138,6 +149,7 @@ def _fit(self, X, y=None):
             batch_size=self.batch_size,
             max_train_length=self.max_train_length,
             temporal_unit=self.temporal_unit,
+            after_epoch_callback=self.after_epoch_callback,
         )
         self.loss_ = self._ts2vec.fit(
             X.transpose(0, 2, 1),
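The `after_epoch_callback` parameter added here is forwarded to the inner training loop, which, per the earlier hunks, invokes it as `callback(model, cum_loss)` once per epoch. A hypothetical loss-tracking callback, sketched under the assumption that `n_epochs` is forwarded to the inner `fit` as shown above:

    import numpy as np

    from aeon.transformations.collection.contrastive_based import TS2Vec

    epoch_losses = []

    def record_loss(model, loss):
        # receives the inner _TS2Vec instance and the mean loss of the epoch
        epoch_losses.append(loss)

    X = np.random.random(size=(8, 1, 50))
    transformer = TS2Vec(
        output_dim=32, device="cpu", n_epochs=3, after_epoch_callback=record_loss
    )
    transformer.fit(X)
    # epoch_losses now holds one value per completed training epoch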
From d5715730a06ca1b50c56033bb993a15773ab39a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ga=C5=A1per=20Petelin?=
Date: Mon, 19 May 2025 15:27:48 +0200
Subject: [PATCH 6/9] Fixed test import

---
 .../collection/contrastive_based/tests/test_ts2vec.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py b/aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py
index 45e9c7d91a..7fde89d0aa 100644
--- a/aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py
+++ b/aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py
@@ -3,7 +3,7 @@
 import numpy as np
 import pytest
 
-from aeon.transformations.collection.contrastive_based._ts2vec import TS2Vec
+from aeon.transformations.collection.contrastive_based import TS2Vec
 
 
 @pytest.mark.parametrize("expected_feature_size", [3, 5, 10])

From b483c903bb97c2bd26648809d428a6d2a8168380 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ga=C5=A1per=20Petelin?=
Date: Sun, 25 May 2025 11:50:25 +0200
Subject: [PATCH 7/9] Updated docs

---
 docs/api_reference/transformations.rst | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/docs/api_reference/transformations.rst b/docs/api_reference/transformations.rst
index 138ca4ffcb..c73a104463 100644
--- a/docs/api_reference/transformations.rst
+++ b/docs/api_reference/transformations.rst
@@ -126,6 +126,18 @@ Interval based
     SupervisedIntervals
     QUANTTransformer
 
+
+Contrastive based
+~~~~~~~~~~~~~~~~~
+
+.. currentmodule:: aeon.transformations.collection.contrastive_based
+
+.. autosummary::
+    :toctree: auto_generated/
+    :template: class.rst
+
+    TS2Vec
+
 Shapelet based
 ~~~~~~~~~~~~~~
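For context on what the transformer documented above actually returns: with `encoding_window="full_series"` (the mode `_transform` uses in the earlier patches), the per-timestamp representations are collapsed into one vector per series by a single max pool over the time axis. A small illustration of that pooling step, with arbitrary tensor sizes:

    import torch
    import torch.nn.functional as F

    out = torch.randn(8, 100, 320)  # (batch, n_timestamps, repr_dim) from the encoder

    pooled = F.max_pool1d(
        out.transpose(1, 2),      # -> (batch, repr_dim, n_timestamps)
        kernel_size=out.size(1),  # pool over the whole series at once
    ).transpose(1, 2)             # -> (batch, 1, repr_dim)

    print(pooled.shape)  # torch.Size([8, 1, 320])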
From f16f90791815f670998736236cf4a1363a0016cd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ga=C5=A1per=20Petelin?=
Date: Sun, 25 May 2025 12:16:53 +0200
Subject: [PATCH 8/9] Removed type from test

---
 aeon/testing/estimator_checking/_yield_regression_checks.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/aeon/testing/estimator_checking/_yield_regression_checks.py b/aeon/testing/estimator_checking/_yield_regression_checks.py
index 531721a459..06fe479654 100644
--- a/aeon/testing/estimator_checking/_yield_regression_checks.py
+++ b/aeon/testing/estimator_checking/_yield_regression_checks.py
@@ -169,7 +169,6 @@ def check_regressor_overrides_and_tags(estimator_class):
         "feature",
         "hybrid",
         "shapelet",
-        "contrastive",
     ]
     algorithm_type = estimator_class.get_class_tag("algorithm_type")
     if algorithm_type is not None:

From ff0b885422f023eb27b5254ff576811c6cdba2ce Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ga=C5=A1per=20Petelin?=
Date: Sun, 25 May 2025 21:13:43 +0200
Subject: [PATCH 9/9] Update test_ts2vec.py

---
 .../collection/contrastive_based/tests/test_ts2vec.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py b/aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py
index 7fde89d0aa..f5305f663a 100644
--- a/aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py
+++ b/aeon/transformations/collection/contrastive_based/tests/test_ts2vec.py
@@ -4,8 +4,13 @@
 import pytest
 
 from aeon.transformations.collection.contrastive_based import TS2Vec
+from aeon.utils.validation._dependencies import _check_soft_dependencies
 
 
+@pytest.mark.skipif(
+    not _check_soft_dependencies("torch", severity="none"),
+    reason="skip test if required soft dependency torch not available",
+)
 @pytest.mark.parametrize("expected_feature_size", [3, 5, 10])
 @pytest.mark.parametrize("n_series", [1, 2, 5])
 @pytest.mark.parametrize("n_channels", [1, 2, 3])
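The skipif guard added in this final patch (the hunk is cut off above) follows aeon's standard soft-dependency pattern, so any further torch-dependent test in this module would presumably reuse it. A hypothetical extra test sketched along the same lines:

    import pytest

    from aeon.utils.validation._dependencies import _check_soft_dependencies


    @pytest.mark.skipif(
        not _check_soft_dependencies("torch", severity="none"),
        reason="skip test if required soft dependency torch not available",
    )
    def test_torch_smoke():
        """Hypothetical extra test that only runs when torch is installed."""
        import torch

        assert torch.zeros(1).shape == (1,)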