diff --git a/README.md b/README.md index 3f6eb132f1..c66d9c9a03 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ does not apply: - `segmentation` - `similarity_search` - `visualisation` +- `transformations.collection.self_supervised` | Overview | | |-----------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| diff --git a/aeon/networks/base.py b/aeon/networks/base.py index e517894a49..cd249ff1e6 100644 --- a/aeon/networks/base.py +++ b/aeon/networks/base.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod +from aeon.utils.repr import get_unchanged_and_required_params_as_str from aeon.utils.validation._dependencies import ( _check_python_version, _check_soft_dependencies, @@ -25,6 +26,11 @@ def __init__(self, soft_dependencies="tensorflow", python_version="<3.13"): _check_python_version(python_version) super().__init__() + def __repr__(self): + """Format str output like scikit-learn estimators.""" + changed_params = get_unchanged_and_required_params_as_str(self) + return f"{self.__class__.__name__}({changed_params})" + @abstractmethod def build_network(self, input_shape, **kwargs): """Construct a network and return its input and output layers. 
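As context for the `__repr__` added above, a minimal sketch of the resulting output (assuming tensorflow is installed; the printed value is illustrative and depends on `FCNNetwork`'s defaults):

```python
from aeon.networks import FCNNetwork

# only parameters that are required or that differ from their
# __init__ defaults are rendered, mirroring scikit-learn's repr
net = FCNNetwork(n_layers=2)
print(net)  # e.g. FCNNetwork(n_layers=2)
```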
diff --git a/aeon/networks/tests/test_network_base.py b/aeon/networks/tests/test_network_base.py index 922d3cd514..542583eff5 100644 --- a/aeon/networks/tests/test_network_base.py +++ b/aeon/networks/tests/test_network_base.py @@ -20,7 +20,8 @@ def build_network(self, input_shape, **kwargs): import tensorflow as tf input_layer = tf.keras.layers.Input(input_shape) - output_layer = tf.keras.layers.Dense(units=10)(input_layer) + flatten_layer = tf.keras.layers.Flatten()(input_layer) + output_layer = tf.keras.layers.Dense(units=10)(flatten_layer) return input_layer, output_layer diff --git a/aeon/transformations/collection/self_supervised/__init__.py b/aeon/transformations/collection/self_supervised/__init__.py new file mode 100644 index 0000000000..f1b40c5d49 --- /dev/null +++ b/aeon/transformations/collection/self_supervised/__init__.py @@ -0,0 +1,5 @@ +"""Self Supervised deep learning transformers.""" + +__all__ = ["TRILITE"] + +from aeon.transformations.collection.self_supervised._trilite import TRILITE diff --git a/aeon/transformations/collection/self_supervised/_trilite.py b/aeon/transformations/collection/self_supervised/_trilite.py new file mode 100644 index 0000000000..b0ba47a879 --- /dev/null +++ b/aeon/transformations/collection/self_supervised/_trilite.py @@ -0,0 +1,681 @@ +"""TRILITE SSL transformer.""" + +from __future__ import annotations + +__maintainer__ = ["hadifawaz1999"] +__all__ = ["TRILITE"] + +import gc +import os +import sys +import time +from copy import deepcopy +from typing import TYPE_CHECKING + +import numpy as np +from sklearn.utils import check_random_state + +from aeon.networks import BaseDeepLearningNetwork +from aeon.transformations.collection import BaseCollectionTransformer +from aeon.utils.self_supervised.general import z_normalization + +if TYPE_CHECKING: + from tensorflow.keras.callbacks import Callback + from tensorflow.keras.optimizers import Optimizer + + +class TRILITE(BaseCollectionTransformer): + """TRIplet Loss In TimE (TRILITE). + + TRILITE [1]_ is a self-supervised model that learns a latent + space through the triplet loss mechanism by reducing + the loss between close samples and increasing it between + far samples. TRILITE generates the triplets using two techniques, + mixing up and masking. For each reference series (ref), a positive + representation of ref is generated by mixing it up with two other + randomly chosen time series from the dataset, then masking a part + of it. The weights of the mixing up procedure are randomly chosen + for the two randomly selected series in a way that the ref still + has the highest weight. The same procedure is used to generate + the negative representation, however using another ref. + + Parameters + ---------- + alpha : float, default = 1e-2 + The value that controls the margin of the triplet loss; + the smaller the value, the more difficult the problem + becomes, and the higher the value, the easier the problem + becomes, so a balance should be found. + weight_ref_min : float, default = 0.6 + The weight of the reference series used for the triplet + generation. + percentage_mask_length : float, default = 0.2 + The percentage of the time series length used to calculate + the length of the mask applied during the triplet + generation. Default is 20%. + use_mixing_up : bool, default = True + Whether or not to use mixing up during the triplet + generation phase. + use_masking : bool, default = True + Whether or not to use masking during the triplet + generation phase.
+ z_normalize_pos_neg : bool, default = True + Whether or not to z_normalize (mean 0 and std 1) + pos and neg samples after generating the triplet. + backbone_network : aeon Network, default = None + The backbone network used for the SSL model, + it can be any network from the aeon.networks + module, on condition that its structure is + configured as "encoder", see the _config attribute. + For TRILITE, the default network used is + FCNNetwork. + latent_space_dim : int, default = 128 + The size of the latent space, applied using a + fully connected layer at the end of the network's + output. + latent_space_activation : str, default = "linear" + The activation to control the range of values + of the latent space. + random_state : int, RandomState instance or None, default=None + If `int`, random_state is the seed used by the random number generator; + If `RandomState` instance, random_state is the random number generator; + If `None`, the random number generator is the `RandomState` instance used + by `np.random`. + Seeded random number generation can only be guaranteed on CPU processing, + GPU processing will be non-deterministic. + verbose : boolean, default = False + Whether to output extra information. + optimizer : keras.optimizer, default = tf.keras.optimizers.Adam() + The keras optimizer used for training. + file_path : str, default = "./" + File path to save best model. + save_best_model : bool, default = False + Whether or not to save the best model. If the + ModelCheckpoint callback is used by default, + this condition, if True, will prevent the + automatic deletion of the best saved model from + file, and the user can choose the file name. + save_last_model : bool, default = False + Whether or not to save the last model, i.e. the + last epoch trained, using the base class method + save_last_model_to_file. + save_init_model : bool, default = False + Whether to save the initialization of the model. + best_file_name : str, default = "best_model" + The name of the file of the best model, if + save_best_model is set to False, this parameter + is discarded. + last_file_name : str, default = "last_model" + The name of the file of the last model, if + save_last_model is set to False, this parameter + is discarded. + init_file_name : str, default = "init_model" + The name of the file of the init model, if + save_init_model is set to False, + this parameter is discarded. + callbacks : keras callback or list of callbacks, + default = None + The default list of callbacks is set to + ModelCheckpoint and ReduceLROnPlateau. + batch_size : int, default = 64 + The number of samples per gradient update. + use_mini_batch_size : bool, default = False + Whether or not to use the mini batch size formula. + n_epochs : int, default = 2000 + The number of epochs to train the model. + + Notes + ----- + Adapted from the implementation of Ismail-Fawaz et al. + https://github.com/MSD-IRIMAS/TRILITE + + References + ---------- + .. [1] Ismail-Fawaz, Ali, Maxime Devanne, Jonathan Weber, + and Germain Forestier. "Enhancing time series classification + with self-supervised learning." In International Conference + on Agents and Artificial Intelligence (ICAART), pp. 40-47. + SCITEPRESS-Science and Technology Publications, 2023.
+ + Examples + -------- + >>> from aeon.transformations.collection.self_supervised import TRILITE + >>> from aeon.datasets import load_unit_test + >>> X_train, y_train = load_unit_test(split="train") + >>> ssl = TRILITE(latent_space_dim=2, n_epochs=5) # doctest: +SKIP + >>> ssl.fit(X_train) # doctest: +SKIP + TRILITE(...) + >>> X_train_transformed = ssl.transform(X_train) # doctest: +SKIP + """ + + _tags = { + "X_inner_type": "numpy3D", + "output_data_type": "Tabular", + "capability:multivariate": True, + "algorithm_type": "deeplearning", + "python_dependencies": "tensorflow", + "non_deterministic": True, + "cant_pickle": True, + } + + def __init__( + self, + alpha: float = 1e-2, + weight_ref_min: float = 0.6, + percentage_mask_length: float = 0.2, + use_mixing_up: bool = True, + use_masking: bool = True, + z_normalize_pos_neg: bool = True, + backbone_network: BaseDeepLearningNetwork | None = None, + latent_space_dim: int = 128, + latent_space_activation: str = "linear", + random_state: int | np.random.RandomState | None = None, + verbose: bool = False, + optimizer: Optimizer | None = None, + file_path: str = "./", + save_best_model: bool = False, + save_last_model: bool = False, + save_init_model: bool = False, + best_file_name: str = "best_model", + last_file_name: str = "last_model", + init_file_name: str = "init_model", + callbacks: Callback | list[Callback] | None = None, + batch_size: int = 64, + use_mini_batch_size: bool = False, + n_epochs: int = 2000, + ): + self.alpha = alpha + self.weight_ref_min = weight_ref_min + self.percentage_mask_length = percentage_mask_length + self.use_mixing_up = use_mixing_up + self.use_masking = use_masking + self.z_normalize_pos_neg = z_normalize_pos_neg + self.backbone_network = backbone_network + self.latent_space_dim = latent_space_dim + self.latent_space_activation = latent_space_activation + self.random_state = random_state + self.verbose = verbose + self.optimizer = optimizer + self.file_path = file_path + self.save_best_model = save_best_model + self.save_last_model = save_last_model + self.save_init_model = save_init_model + self.best_file_name = best_file_name + self.last_file_name = last_file_name + self.init_file_name = init_file_name + self.callbacks = callbacks + self.batch_size = batch_size + self.use_mini_batch_size = use_mini_batch_size + self.n_epochs = n_epochs + + super().__init__() + + def _fit(self, X: np.ndarray, y=None): + """Fit the SSL model on X, y is ignored. + + Parameters + ---------- + X : np.ndarray + The training input samples of shape (n_cases, n_channels, n_timepoints) + y : ignored argument for interface compatibility + + Returns + ------- + self : object + """ + import tensorflow as tf + + from aeon.networks import FCNNetwork + + if isinstance(self.backbone_network, BaseDeepLearningNetwork): + self._backbone_network = deepcopy(self.backbone_network) + elif self.backbone_network is None: + self._backbone_network = FCNNetwork() + else: + raise ValueError( + "The parameter backbone_network should be an aeon network."
+ ) + + X = X.transpose(0, 2, 1) + + self.input_shape = X.shape[1:] + self.training_model_ = self.build_model(self.input_shape) + + if self.save_init_model: + self.training_model_.save( + os.path.join(self.file_path, self.init_file_name + ".keras") + ) + + if self.verbose: + self.training_model_.summary() + + if self.use_mini_batch_size: + mini_batch_size = min(self.batch_size, X.shape[0] // 10) + else: + mini_batch_size = self.batch_size + + self.file_name_ = ( + self.best_file_name if self.save_best_model else str(time.time_ns()) + ) + + if self.callbacks is None: + self.callbacks_ = [ + tf.keras.callbacks.ReduceLROnPlateau( + monitor="loss", factor=0.5, patience=50, min_lr=0.0001 + ), + tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join(self.file_path, self.file_name_ + ".keras"), + monitor="loss", + save_best_only=True, + ), + ] + else: + self.callbacks_ = self._get_model_checkpoint_callback( + callbacks=self.callbacks, + file_path=self.file_path, + file_name=self.file_name_, + ) + + fake_y = np.zeros(shape=len(X)) + + train_dataset = tf.data.Dataset.from_tensor_slices((X, fake_y)) + train_dataset = train_dataset.shuffle(buffer_size=1024).batch(mini_batch_size) + + history = {"loss": []} + + for callback in self.callbacks_: + callback.set_model(self.training_model_) + callback.on_train_begin() + + for epoch in range(self.n_epochs): + epoch_loss = 0 + num_batches = 0 + + for step, (x_batch_train, _) in enumerate(train_dataset): + ref_batch_train, pos_batch_train, neg_batch_train = ( + self._triplet_generation(X=x_batch_train) + ) + + with tf.GradientTape() as tape: + ref_pos_neg = self.training_model_( + [ref_batch_train, pos_batch_train, neg_batch_train] + ) + loss_batch = self._triplet_loss_function( + alpha=self.alpha, ref_pos_neg=ref_pos_neg + ) + loss_mean = tf.reduce_mean(loss_batch) + + gradients = tape.gradient( + loss_mean, self.training_model_.trainable_weights + ) + self.optimizer_.apply_gradients( + zip(gradients, self.training_model_.trainable_weights) + ) + + epoch_loss += float(loss_mean) + num_batches += 1 + + for callback in self.callbacks_: + callback.on_batch_end(step, {"loss": float(loss_mean)}) + + epoch_loss /= num_batches + history["loss"].append(epoch_loss) + + if self.verbose: + sys.stdout.write( + "Training loss at epoch %d: %.4f\n" % (epoch, float(epoch_loss)) + ) + + for callback in self.callbacks_: + callback.on_epoch_end(epoch, {"loss": float(epoch_loss)}) + + for callback in self.callbacks_: + callback.on_train_end() + + self.history = history + + try: + self.model_ = tf.keras.models.load_model( + os.path.join(self.file_path, self.file_name_ + ".keras"), compile=False + ) + if not self.save_best_model: + os.remove(os.path.join(self.file_path, self.file_name_ + ".keras")) + except FileNotFoundError: + self.model_ = deepcopy(self.training_model_) + + if self.save_last_model: + self.save_last_model_to_file(file_path=self.file_path) + + gc.collect() + return self + + def _transform(self, X, y=None): + """Transform input time series using TRILITE. 
+ + Parameters + ---------- + X : 3D np.ndarray of shape = (n_cases, n_channels, n_timepoints) + collection of time series to transform + y : ignored argument for interface compatibility + + Returns + ------- + np.ndarray (n_cases, latent_space_dim), transformed features + """ + X = X.transpose(0, 2, 1) + X_ref_pos_neg_transformed = self.model_.predict([X, X, X], self.batch_size) + + X_transformed_ = np.delete(X_ref_pos_neg_transformed, obj=[1, 2], axis=2) + + X_transformed = np.reshape( + X_transformed_, (len(X_transformed_), self.latent_space_dim) + ) + + return X_transformed + + def build_model(self, input_shape): + """Construct a compiled, un-trained, keras model that is ready for training. + + In aeon, time series are stored in numpy arrays of shape (d,m), where d + is the number of dimensions, m is the series length. Keras/tensorflow assume + data is in shape (m,d). This method also assumes (m,d). Transpose should + happen in fit. + + Parameters + ---------- + input_shape : tuple[int, int] + The shape of the data fed into the input layer, should be (m, d). + + Returns + ------- + output : a compiled Keras Model + """ + import numpy as np + import tensorflow as tf + + rng = check_random_state(self.random_state) + self.random_state_ = rng.randint(0, np.iinfo(np.int32).max) + tf.keras.utils.set_random_seed(self.random_state_) + + input_ref_layer = tf.keras.layers.Input(input_shape) + input_pos_layer = tf.keras.layers.Input(input_shape) + input_neg_layer = tf.keras.layers.Input(input_shape) + + input_layer, gap_layer = self._backbone_network.build_network( + input_shape=input_shape + ) + output_layer = tf.keras.layers.Dense( + units=self.latent_space_dim, activation=self.latent_space_activation + )(gap_layer) + + encoder_model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer) + + output_layer_ref = tf.keras.layers.Reshape(target_shape=(-1, 1))( + encoder_model(input_ref_layer) + ) + output_layer_pos = tf.keras.layers.Reshape(target_shape=(-1, 1))( + encoder_model(input_pos_layer) + ) + output_layer_neg = tf.keras.layers.Reshape(target_shape=(-1, 1))( + encoder_model(input_neg_layer) + ) + + encoder_output_layer = tf.keras.layers.Concatenate(axis=-1)( + [output_layer_ref, output_layer_pos, output_layer_neg] + ) + + model = tf.keras.models.Model( + inputs=[input_ref_layer, input_pos_layer, input_neg_layer], + outputs=encoder_output_layer, + ) + + self.optimizer_ = ( + tf.keras.optimizers.Adam() if self.optimizer is None else self.optimizer + ) + + # compile with a placeholder loss; training uses the custom triplet loss + model.compile(loss="mse", optimizer=self.optimizer_) + + return model + + def _triplet_loss_function(self, alpha, ref_pos_neg): + """Compute the triplet loss for triplet-based training.""" + import tensorflow as tf + + ref = ref_pos_neg[:, :, 0] + pos = ref_pos_neg[:, :, 1] + neg = ref_pos_neg[:, :, 2] + + # cast pos and neg to the same dtype as ref + pos = tf.cast(pos, dtype=ref.dtype) + neg = tf.cast(neg, dtype=ref.dtype) + + loss_pos_ref = tf.reduce_sum(tf.square(ref - pos), axis=1) + loss_neg_ref = tf.reduce_sum(tf.square(ref - neg), axis=1) + loss_add_sub = tf.math.subtract(tf.math.add(loss_pos_ref, alpha), loss_neg_ref) + loss = tf.maximum(loss_add_sub, 0) + + return loss + + def _triplet_generation(self, X): + """Generate triplet samples (ref, pos, neg) for triplet loss training.""" + n_channels = int(X.shape[-1]) + length_TS = int(X.shape[1]) + + # define mask length + self.mask_length = int(length_TS * self.percentage_mask_length) + + # define weight for each sample in the mixing up
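+ # w_ref is drawn from [weight_ref_min, 1] so that the ref series keeps + # the dominant weight; the two mixing series share the rest equally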
+ w_ref = np.random.choice( + np.linspace(start=self.weight_ref_min, stop=1, num=1000), size=1 + ) + w_ts = (1 - w_ref) / 2 + + # define ref as a random permutation of X + ref = np.random.permutation(X[:]) + + n = int(ref.shape[0]) + + # define positive and negative sample arrays + _pos = np.zeros(shape=ref.shape) + _neg = np.zeros(shape=ref.shape) + + all_indices = np.arange(start=0, stop=n) + + for i_ref in range(n): + # remove the sample ref from the random choice of pos-neg + all_indices_without_ref = np.delete(arr=all_indices, obj=i_ref) + + # choose a random sample used for the negative generation + index_neg = int(np.random.choice(all_indices_without_ref, size=1)) + + _ref = ref[i_ref].copy() + + # remove the index_neg from choices + all_indices_without_ref_and_not_ref = np.delete( + arr=all_indices, obj=[i_ref, index_neg] + ) + + # choose samples used for the mixing up + index_ts1_pos = int( + np.random.choice(all_indices_without_ref_and_not_ref, size=1) + ) + index_ts2_pos = int( + np.random.choice(all_indices_without_ref_and_not_ref, size=1) + ) + + index_ts1_neg = int( + np.random.choice(all_indices_without_ref_and_not_ref, size=1) + ) + index_ts2_neg = int( + np.random.choice(all_indices_without_ref_and_not_ref, size=1) + ) + + _not_ref = ref[index_neg].copy() + + _ts1_pos = ref[index_ts1_pos].copy() + _ts2_pos = ref[index_ts2_pos].copy() + + _ts1_neg = ref[index_ts1_neg].copy() + _ts2_neg = ref[index_ts2_neg].copy() + + # MixingUp + + if self.use_mixing_up and self.use_masking: + # mix up the selected series with ref to obtain pos + _pos[i_ref] = w_ref * _ref + w_ts * _ts1_pos + w_ts * _ts2_pos + # mix up the selected series with neg ref to obtain neg + _neg[i_ref] = w_ref * _not_ref + w_ts * _ts1_neg + w_ts * _ts2_neg + + # apply masking + _pos[i_ref], _neg[i_ref] = self._apply_masking( + pos=_pos[i_ref], + neg=_neg[i_ref], + n_channels=n_channels, + length_TS=length_TS, + mask_length=self.mask_length, + ) + + elif self.use_mixing_up and not self.use_masking: + # mix up the selected series with ref to obtain pos + _pos[i_ref] = w_ref * _ref + w_ts * _ts1_pos + w_ts * _ts2_pos + # mix up the selected series with neg ref to obtain neg + _neg[i_ref] = w_ref * _not_ref + w_ts * _ts1_neg + w_ts * _ts2_neg + + elif self.use_masking and not self.use_mixing_up: + # apply masking + _pos[i_ref], _neg[i_ref] = self._apply_masking( + pos=_pos[i_ref], + neg=_neg[i_ref], + n_channels=n_channels, + length_TS=length_TS, + mask_length=self.mask_length, + ) + + else: + raise ValueError( + "At least masking or mixing up " + "should be chosen to generate " + "the triplets." + ) + if self.z_normalize_pos_neg: + # z_normalize pos and neg + _pos_normalized = z_normalization(_pos) + _neg_normalized = z_normalization(_neg) + + return ref, _pos_normalized, _neg_normalized + else: + return ref, _pos, _neg + + def _apply_masking(self, pos, neg, n_channels, length_TS, mask_length): + """Apply masking phase on pos and neg.""" + # select a random start for the mask + start_mask = int(np.random.randint(low=0, high=length_TS - mask_length, size=1)) + stop_mask = start_mask + mask_length + + # define noise on replacement on the left side of the mask + noise_pos_left = np.random.random(size=(start_mask, n_channels)) + # normalize noise + noise_pos_left /= 5 + noise_pos_left -= 0.1 + + # define noise on replacement on the right side of the mask + noise_pos_right = np.random.random(size=(length_TS - stop_mask, n_channels)) + # normalize noise + noise_pos_right /= 5 + noise_pos_right -= 0.1 + + # replace left and right side of the mask by normalized noise
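+ # np.random.random gives noise uniform in [0, 1); dividing by 5 and + # subtracting 0.1 rescales it to the range [-0.1, 0.1)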
+ pos[0:start_mask, :] = noise_pos_left + pos[stop_mask:length_TS, :] = noise_pos_right + + # repeat the same procedure for the negative sample + noise_neg_left = np.random.random(size=(start_mask, n_channels)) + noise_neg_left /= 5 + noise_neg_left -= 0.1 + noise_neg_right = np.random.random(size=(length_TS - stop_mask, n_channels)) + noise_neg_right /= 5 + noise_neg_right -= 0.1 + + neg[0:start_mask, :] = noise_neg_left + neg[stop_mask:length_TS, :] = noise_neg_right + + return pos, neg + + def save_last_model_to_file(self, file_path="./"): + """Save the last epoch of the trained deep learning model. + + Parameters + ---------- + file_path : str, default = "./" + The directory where the model will be saved + + Returns + ------- + None + """ + self.model_.save(os.path.join(file_path, self.last_file_name + ".keras")) + + def load_model(self, model_path): + """Load a pre-trained keras model instead of fitting. + + When calling this function, all functionalities can be used + such as predict, predict_proba etc. with the loaded model. + + Parameters + ---------- + model_path : str (path including model name and extension) + The path to the saved model file, including the model + name with a ".keras" extension. + Example: model_path="path/to/file/best_model.keras" + + Returns + ------- + None + """ + import tensorflow as tf + + self.model_ = tf.keras.models.load_model(model_path) + self.is_fitted = True + + def _get_model_checkpoint_callback(self, callbacks, file_path, file_name): + import tensorflow as tf + + model_checkpoint_ = tf.keras.callbacks.ModelCheckpoint( + filepath=os.path.join(file_path, file_name + ".keras"), + monitor="loss", + save_best_only=True, + ) + + if isinstance(callbacks, list): + return callbacks + [model_checkpoint_] + else: + return [callbacks] + [model_checkpoint_] + + @classmethod + def _get_test_params(cls, parameter_set="default"): + """Return testing parameter settings for the transformer. + + Parameters + ---------- + parameter_set : str, default="default" + Name of the set of test parameters to return, for use in tests. If no + special parameters are defined for a value, will return `"default"` set. + + + Returns + ------- + params : dict or list of dict, default = {} + Parameters to create testing instances of the class. + Each dict contains parameters to construct an "interesting" test instance, + i.e., `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
+ """ + from aeon.networks import FCNNetwork + + params = { + "latent_space_dim": 2, + "backbone_network": FCNNetwork(n_layers=1, n_filters=2, kernel_size=2), + "n_epochs": 3, + } + + return [params] diff --git a/aeon/transformations/collection/self_supervised/tests/__init__.py b/aeon/transformations/collection/self_supervised/tests/__init__.py new file mode 100644 index 0000000000..4bd29e9e65 --- /dev/null +++ b/aeon/transformations/collection/self_supervised/tests/__init__.py @@ -0,0 +1 @@ +"""Self-Supervised tests.""" diff --git a/aeon/transformations/collection/self_supervised/tests/test_trilite.py b/aeon/transformations/collection/self_supervised/tests/test_trilite.py new file mode 100644 index 0000000000..03f0e40747 --- /dev/null +++ b/aeon/transformations/collection/self_supervised/tests/test_trilite.py @@ -0,0 +1,251 @@ +"""Test TRILITE Self-supervised transformer.""" + +import tempfile + +import numpy as np +import pytest + +from aeon.networks import LITENetwork +from aeon.networks.tests.test_network_base import DummyDeepNetwork +from aeon.transformations.collection.self_supervised import TRILITE +from aeon.utils.validation._dependencies import _check_soft_dependencies + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("use_mixing_up", [True, False]) +def test_trilite_use_mixing_up(use_mixing_up): + """Test TRILITE with possible mixing up setups.""" + X = np.random.random((100, 2, 5)) + with tempfile.TemporaryDirectory() as tmp: + + ssl = TRILITE( + use_mixing_up=use_mixing_up, + latent_space_dim=2, + backbone_network=DummyDeepNetwork(), + n_epochs=3, + file_path=tmp, + ) + + ssl.fit(X=X) + + X_transformed = ssl.transform(X=X) + + assert len(X_transformed.shape) == 2 + assert int(X_transformed.shape[-1]) == 2 + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("use_masking", [True, False]) +def test_trilite_use_masking(use_masking): + """Test TRILITE with possible masking setups.""" + X = np.random.random((100, 2, 5)) + with tempfile.TemporaryDirectory() as tmp: + + ssl = TRILITE( + use_masking=use_masking, + backbone_network=DummyDeepNetwork(), + latent_space_dim=2, + n_epochs=3, + file_path=tmp, + ) + + ssl.fit(X=X) + + X_transformed = ssl.transform(X=X) + + assert len(X_transformed.shape) == 2 + assert int(X_transformed.shape[-1]) == 2 + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("z_normalize_pos_neg", [True, False]) +def test_trilite_z_normalize_pos_neg(z_normalize_pos_neg): + """Test TRILITE with possible znorm pos and neg setups.""" + X = np.random.random((100, 2, 5)) + with tempfile.TemporaryDirectory() as tmp: + + ssl = TRILITE( + z_normalize_pos_neg=z_normalize_pos_neg, + backbone_network=DummyDeepNetwork(), + latent_space_dim=2, + n_epochs=3, + file_path=tmp, + ) + + ssl.fit(X=X) + + X_transformed = ssl.transform(X=X) + + assert len(X_transformed.shape) == 2 + assert int(X_transformed.shape[-1]) == 2 + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("alpha", [1e-1, 1e-2]) +def test_trilite_alpha(alpha): + """Test TRILITE with possible alpha setups.""" + X = np.random.random((100, 2, 5)) + 
with tempfile.TemporaryDirectory() as tmp: + + ssl = TRILITE( + alpha=alpha, + backbone_network=DummyDeepNetwork(), + latent_space_dim=2, + n_epochs=3, + file_path=tmp, + ) + + ssl.fit(X=X) + + X_transformed = ssl.transform(X=X) + + assert len(X_transformed.shape) == 2 + assert int(X_transformed.shape[-1]) == 2 + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("weight_ref_min", [0.5, 0.6]) +def test_trilite_weight_ref_min(weight_ref_min): + """Test TRILITE with possible weight_ref_min setups.""" + X = np.random.random((100, 2, 5)) + with tempfile.TemporaryDirectory() as tmp: + + ssl = TRILITE( + weight_ref_min=weight_ref_min, + backbone_network=DummyDeepNetwork(), + latent_space_dim=2, + n_epochs=3, + file_path=tmp, + ) + + ssl.fit(X=X) + + X_transformed = ssl.transform(X=X) + + assert len(X_transformed.shape) == 2 + assert int(X_transformed.shape[-1]) == 2 + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("percentage_mask_length", [0.2, 0.3]) +def test_trilite_percentage_mask_length(percentage_mask_length): + """Test TRILITE with possible percentage_mask_length setups.""" + X = np.random.random((100, 2, 5)) + with tempfile.TemporaryDirectory() as tmp: + + ssl = TRILITE( + percentage_mask_length=percentage_mask_length, + backbone_network=DummyDeepNetwork(), + latent_space_dim=2, + n_epochs=3, + file_path=tmp, + ) + + ssl.fit(X=X) + + X_transformed = ssl.transform(X=X) + + assert len(X_transformed.shape) == 2 + assert int(X_transformed.shape[-1]) == 2 + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("latent_space_dim", [2, 3]) +def test_trilite_latent_space_dim(latent_space_dim): + """Test TRILITE with possible latent_space_dim setups.""" + X = np.random.random((100, 2, 5)) + with tempfile.TemporaryDirectory() as tmp: + + ssl = TRILITE( + latent_space_dim=latent_space_dim, + backbone_network=DummyDeepNetwork(), + n_epochs=3, + file_path=tmp, + ) + + ssl.fit(X=X) + + X_transformed = ssl.transform(X=X) + + assert len(X_transformed.shape) == 2 + assert int(X_transformed.shape[-1]) == latent_space_dim + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("latent_space_activation", ["linear", "relu"]) +def test_trilite_latent_space_activation(latent_space_activation): + """Test TRILITE with possible latent_space_activation setups.""" + X = np.random.random((100, 2, 5)) + with tempfile.TemporaryDirectory() as tmp: + + ssl = TRILITE( + latent_space_activation=latent_space_activation, + backbone_network=DummyDeepNetwork(), + latent_space_dim=2, + n_epochs=3, + file_path=tmp, + ) + + ssl.fit(X=X) + + X_transformed = ssl.transform(X=X) + + assert len(X_transformed.shape) == 2 + assert int(X_transformed.shape[-1]) == 2 + + +@pytest.mark.skipif( + not _check_soft_dependencies(["tensorflow"], severity="none"), + reason="Tensorflow soft dependency unavailable.", +) +@pytest.mark.parametrize("backbone_network", [None, DummyDeepNetwork, LITENetwork]) +def test_trilite_backbone_network(backbone_network): + """Test TRILITE with possible backbone_network setups.""" + X = np.random.random((100, 2, 5)) + with 
tempfile.TemporaryDirectory() as tmp: + + if backbone_network is not None: + ssl = TRILITE( + backbone_network=backbone_network(), + latent_space_dim=2, + n_epochs=3, + file_path=tmp, + ) + else: + ssl = TRILITE( + backbone_network=backbone_network, + latent_space_dim=2, + n_epochs=3, + file_path=tmp, + ) + + ssl.fit(X=X) + + X_transformed = ssl.transform(X=X) + + assert len(X_transformed.shape) == 2 + assert int(X_transformed.shape[-1]) == 2 diff --git a/aeon/utils/repr.py b/aeon/utils/repr.py new file mode 100644 index 0000000000..67b3b75142 --- /dev/null +++ b/aeon/utils/repr.py @@ -0,0 +1,55 @@ +"""Utilities for class __repr__ presentation.""" + +import inspect + +from aeon.testing.utils.deep_equals import deep_equals + + +def get_unchanged_and_required_params_as_str(obj): + """ + Get object parameters as a comma delimited string. + + Collects the parameters of an object that are either required + (no default) or different from the __init__ default value. Returns + the parameter names and values as a comma delimited string. + + Parameters + ---------- + obj : object + The object to inspect. + + Returns + ------- + str + A string representation of the object's parameters and values. + """ + cls = obj.__class__ + signature = inspect.signature(cls.__init__) + + params = {} + for name, param in signature.parameters.items(): + if name == "self": + continue + + has_default = param.default is not inspect.Parameter.empty + current_val = getattr(obj, name, None) + + if not has_default: + # No default = always include + params[name] = current_val + else: + # Default exists = include only if changed from the default + if not deep_equals(current_val, param.default): + params[name] = current_val + + if len(params) == 0: + return "" + + param_str = [] + for k, v in params.items(): + if isinstance(v, str): + param_str.append(f"{k}='{v}'") + else: + param_str.append(f"{k}={v}") + + return ", ".join(param_str) diff --git a/aeon/utils/self_supervised/__init__.py b/aeon/utils/self_supervised/__init__.py new file mode 100644 index 0000000000..8de476b255 --- /dev/null +++ b/aeon/utils/self_supervised/__init__.py @@ -0,0 +1 @@ +"""Utils for self_supervised.""" diff --git a/aeon/utils/self_supervised/general.py b/aeon/utils/self_supervised/general.py new file mode 100644 index 0000000000..b3d2da7bfd --- /dev/null +++ b/aeon/utils/self_supervised/general.py @@ -0,0 +1,28 @@ +"""General utils for self_supervised.""" + +__all__ = ["z_normalization"] + +import numpy as np + + +def z_normalization(X, axis=1): + """Z-Normalize collection of time series. + + Parameters + ---------- + X : np.ndarray + The input collection of time series of shape + (n_cases, n_timepoints, n_channels). + axis : int, default = 1 + The axis of time, on which z-normalization + is performed. + + Returns + ------- + np.ndarray + Normalized version of X. + """ + stds = np.std(X, axis=axis, keepdims=True) + # guard against division by zero for constant series + stds[stds == 0.0] = 1.0 + return (X - X.mean(axis=axis, keepdims=True)) / stds diff --git a/docs/api_reference/transformations.rst b/docs/api_reference/transformations.rst index 2a56fd847f..72518fffab 100644 --- a/docs/api_reference/transformations.rst +++ b/docs/api_reference/transformations.rst @@ -126,6 +126,17 @@ Interval based SupervisedIntervals QUANTTransformer +Self Supervised +~~~~~~~~~~~~~~~ + +.. currentmodule:: aeon.transformations.collection.self_supervised + +.. 
autosummary:: + :toctree: auto_generated/ + :template: class.rst + + TRILITE + Shapelet based ~~~~~~~~~~~~~~ diff --git a/docs/index.md b/docs/index.md index 11b558839e..bbc4ea118f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -276,6 +276,7 @@ experimental modules are: - `segmentation` - `similarity_search` - `visualisation` +- `transformations.collection.self_supervised` ```{toctree} :caption: Using aeon
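As a usage illustration of the new `self_supervised` module, a minimal end-to-end sketch (assuming tensorflow is installed; the 1-NN classifier on the learned features is illustrative and not part of this diff):

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

from aeon.transformations.collection.self_supervised import TRILITE

X = np.random.random((100, 2, 50))  # (n_cases, n_channels, n_timepoints)
y = np.random.randint(0, 2, size=100)

# learn a latent space without using the labels, then fit a
# classifier on the transformed features
ssl = TRILITE(latent_space_dim=16, n_epochs=10)
ssl.fit(X)
X_features = ssl.transform(X)  # shape (100, 16)

clf = KNeighborsClassifier(n_neighbors=1).fit(X_features, y)
```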