From 56e6b8f065ae3ffb6f0f5bcf8cf718a12d93a698 Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Fri, 21 Mar 2025 22:19:30 +0530 Subject: [PATCH 01/16] K shape implementation --- aeon/clustering/_k_shape.py | 271 +++++++++++++++++++++++++++++++----- 1 file changed, 240 insertions(+), 31 deletions(-) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index aa8d8a3b64..b1314b7b0a 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -1,11 +1,21 @@ """Time series kshapes.""" -from typing import Optional, Union +from typing import Callable, Optional, Union import numpy as np from numpy.random import RandomState +from sklearn.utils import check_random_state +from tslearn.metrics import cdist_normalized_cc +from tslearn.metrics.cycc import normalized_cc from aeon.clustering.base import BaseClusterer +from aeon.transformations.collection import Normalizer + + +class EmptyClusterError(Exception): + """Error raised when an empty cluster is encountered.""" + + pass class TimeSeriesKShape(BaseClusterer): @@ -100,40 +110,210 @@ def __init__( super().__init__() - def _fit(self, X, y=None): - """Fit time series clusterer to training data. + def _check_params(self, X: np.ndarray) -> None: + self._random_state = check_random_state(self.random_state) - Parameters - ---------- - X: np.ndarray, of shape (n_cases, n_channels, n_timepoints) or - (n_cases, n_timepoints) - A collection of time series instances. - y: ignored, exists for API consistency reasons. + if isinstance(self.init, str): + if self.init == "random": + self._init = self._random_center_initializer + else: + if isinstance(self.init, np.ndarray) and len(self.init) == self.n_clusters: + self._init = self.init.copy() + else: + raise ValueError( + f"The value provided for init: {self.init} is " + f"invalid. The following are a list of valid init algorithms " + f"strings: random, kmedoids++, first. You can also pass a" + f"np.ndarray of size (n_clusters, n_channels, n_timepoints)" + ) - Returns - ------- - self: - Fitted estimator. + if self.n_clusters > X.shape[0]: + raise ValueError( + f"n_clusters ({self.n_clusters}) cannot be larger than " + f"n_cases ({X.shape[0]})" + ) + + def _random_center_initializer(self, X: np.ndarray) -> np.ndarray: + return X[self._random_state.choice(X.shape[0], self.n_clusters)] + + def _check_no_empty_cluster(self, labels, n_clusters): + """Check that all clusters have at least one sample assigned.""" + for k in range(n_clusters): + if np.sum(labels == k) == 0: + raise EmptyClusterError + + def _compute_inertia(self, distances, assignments, squared=True): + """Derive inertia from pre-computed distances and assignments. + + Examples + -------- + >>> dists = numpy.array([[1., 2., 0.5], [0., 3., 1.]]) + >>> assign = numpy.array([2, 0]) + >>> _compute_inertia(dists, assign) + 0.125 """ - from tslearn.clustering import KShape - - self._tslearn_k_shapes = KShape( - n_clusters=self.n_clusters, - max_iter=self.max_iter, - tol=self.tol, - random_state=self.random_state, - n_init=self.n_init, - verbose=self.verbose, - init=self.init, + n_ts = distances.shape[0] + if squared: + return np.sum(distances[np.arange(n_ts), assignments] ** 2) / n_ts + else: + return np.sum(distances[np.arange(n_ts), assignments]) / n_ts + + def _assign(self, X): + X_temp = np.transpose(X, (0, 2, 1)) + cluster_temp = np.transpose(self.cluster_centers_, (0, 2, 1)) + dists = 1.0 - cdist_normalized_cc( + X_temp, cluster_temp, self.norms_, self.norms_centroids_, False ) + self.labels_ = dists.argmin(axis=1) + self._check_no_empty_cluster(self.labels_, self.n_clusters) + self.inertia_ = self._compute_inertia(dists, self.labels_) + + def y_shifted_sbd_vec(self, ref_ts, dataset, norm_ref, norms_dataset): + n_ts = dataset.shape[0] + d = dataset.shape[1] + sz = dataset.shape[2] + assert sz == ref_ts.shape[1] and d == ref_ts.shape[0] + dataset_shifted = np.zeros((n_ts, d, sz)) + + if norm_ref < 0: + norm_ref = np.linalg.norm(ref_ts) + if (norms_dataset < 0.0).any(): + for i_ts in range(n_ts): + norms_dataset[i_ts] = np.linalg.norm(dataset[i_ts, ...]) + + for i in range(n_ts): + # TODO: remove dependency on normalized_cc + ref_ts_temp = ref_ts.T + dataset_temp = np.transpose(dataset, (0, 2, 1)) + cc = normalized_cc( + ref_ts_temp, dataset_temp[i], norm1=norm_ref, norm2=norms_dataset[i] + ) + idx = np.argmax(cc) + shift = idx - sz + if shift > 0: + dataset_shifted[i, :, shift:] = dataset[i, :, : sz - shift] + elif shift < 0: + dataset_shifted[i, :, : sz + shift] = dataset[i, :, -shift:] + else: + dataset_shifted[i] = dataset[i] + + return dataset_shifted + + def _shape_extraction(self, X, k): + # X is of dim (n_ts, d, sz) + sz = X.shape[2] + d = X.shape[1] + Xp = self.y_shifted_sbd_vec( + self.cluster_centers_[k], + X[self.labels_ == k], + -1, + self.norms_[self.labels_ == k], + ) + # Xp is of dim (n_ts, d, sz) + S = np.dot(Xp[:, 0, :].T, Xp[:, 0, :]) + Q = np.eye(sz) - np.ones((sz, sz)) / sz + M = np.dot(Q.T, np.dot(S, Q)) + + _, vec = np.linalg.eigh(M) + mu_k = vec[:, -1].reshape((sz, 1)) + + mu_k_broadcast = mu_k.reshape((1, 1, sz)) + dist_plus_mu = np.sum(np.linalg.norm(Xp - mu_k_broadcast, axis=(1, 2))) + dist_minus_mu = np.sum(np.linalg.norm(Xp + mu_k_broadcast, axis=(1, 2))) + + if dist_minus_mu < dist_plus_mu: + mu_k *= -1 - _X = X.swapaxes(1, 2) + d = Xp.shape[1] + mu_k = np.tile(mu_k.T, (d, 1)) + return mu_k - self._tslearn_k_shapes.fit(_X) - self._cluster_centers = self._tslearn_k_shapes.cluster_centers_ - self.labels_ = self._tslearn_k_shapes.predict(_X) - self.inertia_ = self._tslearn_k_shapes.inertia_ - self.n_iter_ = self._tslearn_k_shapes.n_iter_ + def _update_centroids(self, X): + # X is (n, d, sz) + for k in range(self.n_clusters): + self.cluster_centers_[k] = self._shape_extraction(X, k) + + normaliser = Normalizer() + self.cluster_centers_ = normaliser.fit_transform(self.cluster_centers_) + self.norms_centroids_ = np.linalg.norm(self.cluster_centers_, axis=(1, 2)) + + def _fit_one_init(self, X): + if isinstance(self._init, Callable): + self.cluster_centers_ = self._init(X) + else: + self.cluster_centers_ = self._init.copy() + + self.norms_centroids_ = np.linalg.norm(self.cluster_centers_, axis=(1, 2)) + self._assign(X) + old_inertia = np.inf + + it = 0 + for it in range(self.max_iter): # noqa: B007 + old_cluster_centers = self.cluster_centers_.copy() + self._update_centroids(X) + self._assign(X) + if self.verbose: + print("%.3f" % self.inertia_, end=" --> ") # noqa: T001, T201 + + if np.abs(old_inertia - self.inertia_) < self.tol or ( + old_inertia - self.inertia_ < 0 + ): + self.cluster_centers_ = old_cluster_centers + self._assign(X) + break + + old_inertia = self.inertia_ + if self.verbose: + print("") # noqa: T001, T201 + + self._iter = it + 1 + + return self + + def _fit(self, X, y=None): + # X = check_array(X, allow_nd=True) add aeon version + self._check_params(X) + + max_attempts = max(self.n_init, 10) + + self.inertia_ = np.inf + + self.norms_ = 0.0 + self.norms_centroids_ = 0.0 + + self._X_fit = X + self.norms_ = np.linalg.norm(X, axis=(1, 2)) + + best_correct_centroids = None + min_inertia = np.inf + n_successful = 0 + n_attempts = 0 + while n_successful < self.n_init and n_attempts < max_attempts: + try: + if self.verbose and self.n_init > 1: + print("Init %d" % (n_successful + 1)) # noqa: T001, T201 + n_attempts += 1 + self._fit_one_init(X) + if self.inertia_ < min_inertia: + best_correct_centroids = self.cluster_centers_.copy() + min_inertia = self.inertia_ + self.n_iter_ = self._iter + n_successful += 1 + except EmptyClusterError: + if self.verbose: + print("Resumed because of empty cluster") # noqa: T001, T201 + self.norms_centroids_ = np.linalg.norm(self.cluster_centers_, axis=(1, 2)) + self._post_fit(X, best_correct_centroids, min_inertia) + return self + + def _post_fit(self, X_fitted, centroids, inertia): + if np.isfinite(inertia) and (centroids is not None): + self.cluster_centers_ = centroids + self._assign(X_fitted) + self._X_fit = X_fitted + self.inertia_ = inertia + else: + self._X_fit = None def _predict(self, X, y=None) -> np.ndarray: """Predict the closest cluster each sample in X belongs to. @@ -150,8 +330,37 @@ def _predict(self, X, y=None) -> np.ndarray: np.ndarray (1d array of shape (n_cases,)) Index of the cluster each time series in X belongs to. """ - _X = X.swapaxes(1, 2) - return self._tslearn_k_shapes.predict(_X) + # TODO remove dependence on cdist_normalized_cc + normaliser = Normalizer() + X_ = X.copy() + X_ = normaliser.fit_transform(X_) + X_temp = np.transpose(X_, (0, 2, 1)) + cluster_temp = np.transpose(self.cluster_centers_, (0, 2, 1)) + n1 = np.linalg.norm(X_temp, axis=(1, 2)) + n2 = np.linalg.norm(cluster_temp, axis=(1, 2)) + dists = 1.0 - cdist_normalized_cc(X_temp, cluster_temp, n1, n2, False) + return dists.argmin(axis=1) + + def fit_predict(self, X, y=None): + """Fit X using k-Shape clustering then predict the closest clusters. + + It is more efficient to use this method than to sequentially call fit + and predict. + + Parameters + ---------- + X : array-like of shape=(n_ts, sz, d) + Time series dataset to predict. + + y + Ignored + + Returns + ------- + labels : array of shape=(n_ts, ) + Index of the cluster each sample belongs to. + """ + return self._fit(X, y).labels_ @classmethod def _get_test_params(cls, parameter_set="default"): From 3fe34bc0e1cbe7bac0773a80526d93337256de4e Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Fri, 21 Mar 2025 22:20:08 +0530 Subject: [PATCH 02/16] K shape implementation --- aeon/clustering/_k_shape.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index b1314b7b0a..583a2464b1 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -143,15 +143,7 @@ def _check_no_empty_cluster(self, labels, n_clusters): raise EmptyClusterError def _compute_inertia(self, distances, assignments, squared=True): - """Derive inertia from pre-computed distances and assignments. - - Examples - -------- - >>> dists = numpy.array([[1., 2., 0.5], [0., 3., 1.]]) - >>> assign = numpy.array([2, 0]) - >>> _compute_inertia(dists, assign) - 0.125 - """ + """Derive inertia from pre-computed distances and assignments.""" n_ts = distances.shape[0] if squared: return np.sum(distances[np.arange(n_ts), assignments] ** 2) / n_ts From 840461847837f9219d4bed37a0ef6936683bcbc9 Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Fri, 21 Mar 2025 22:34:52 +0530 Subject: [PATCH 03/16] Minor changes --- aeon/clustering/_k_shape.py | 93 ++++++++++++++++++++++++++++++++++++- 1 file changed, 91 insertions(+), 2 deletions(-) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index 583a2464b1..b13c777909 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -3,15 +3,104 @@ from typing import Callable, Optional, Union import numpy as np +from numba import njit, objmode, prange from numpy.random import RandomState from sklearn.utils import check_random_state -from tslearn.metrics import cdist_normalized_cc -from tslearn.metrics.cycc import normalized_cc from aeon.clustering.base import BaseClusterer from aeon.transformations.collection import Normalizer +@njit(fastmath=True) +def normalized_cc(s1, s2, norm1=-1.0, norm2=-1.0): + """Normalize cc. + + Parameters + ---------- + s1 : array-like, shape=(sz, d), dtype=float64 + A time series. + s2 : array-like, shape=(sz, d), dtype=float64 + Another time series. + norm1 : float64, default=-1.0 + norm2 : float64, default=-1.0 + + Returns + ------- + norm_cc : array-like, shape=(2 * sz - 1), dtype=float64 + """ + assert s1.shape[1] == s2.shape[1] + sz = s1.shape[0] + n_bits = 1 + int(np.log2(2 * sz - 1)) + fft_sz = 2**n_bits + + if norm1 < 0.0: + norm1 = np.linalg.norm(s1) + if norm2 < 0.0: + norm2 = np.linalg.norm(s2) + + denom = norm1 * norm2 + if denom < 1e-9: # To avoid NaNs + denom = np.inf + + with objmode(cc="float64[:, :]"): + cc = np.real( + np.fft.ifft( + np.fft.fft(s1, fft_sz, axis=0) + * np.conj(np.fft.fft(s2, fft_sz, axis=0)), + axis=0, + ) + ) + + cc = np.vstack((cc[-(sz - 1) :], cc[:sz])) + norm_cc = np.real(cc).sum(axis=-1) / denom + return norm_cc + + +@njit(parallel=True, fastmath=True) +def cdist_normalized_cc(dataset1, dataset2, norms1, norms2, self_similarity): + """Compute the distance matrix between two time series dataset. + + Parameters + ---------- + dataset1 : array-like, shape=(n_ts1, sz, d), dtype=float64 + A dataset of time series. + dataset2 : array-like, shape=(n_ts2, sz, d), dtype=float64 + Another dataset of time series. + norms1 : array-like, shape=(n_ts1,), dtype=float64 + norms2 : array-like, shape=(n_ts2,), dtype=float64 + self_similarity : bool + + Returns + ------- + dists : array-like, shape=(n_ts1, n_ts2), dtype=float64 + """ + n_ts1, sz, d = dataset1.shape + n_ts2 = dataset2.shape[0] + assert d == dataset2.shape[2] + dists = np.zeros((n_ts1, n_ts2)) + + if (norms1 < 0.0).any(): + for i_ts1 in prange(n_ts1): + norms1[i_ts1] = np.linalg.norm(dataset1[i_ts1, ...]) + if (norms2 < 0.0).any(): + for i_ts2 in prange(n_ts2): + norms2[i_ts2] = np.linalg.norm(dataset2[i_ts2, ...]) + if self_similarity: + for i in prange(1, n_ts1): + for j in range(i): + dists[i, j] = normalized_cc( + dataset1[i], dataset2[j], norm1=norms1[i], norm2=norms2[j] + ).max() + dists += dists.T + else: + for i in prange(n_ts1): + for j in range(n_ts2): + dists[i, j] = normalized_cc( + dataset1[i], dataset2[j], norm1=norms1[i], norm2=norms2[j] + ).max() + return dists + + class EmptyClusterError(Exception): """Error raised when an empty cluster is encountered.""" From 2a525bab3914b11295184abda7de6a4d7fb95d5e Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Sat, 22 Mar 2025 19:48:54 +0530 Subject: [PATCH 04/16] kshape minor fixes --- aeon/clustering/_k_shape.py | 40 ++++++------------------------------- 1 file changed, 6 insertions(+), 34 deletions(-) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index b13c777909..f58b90173e 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -10,6 +10,8 @@ from aeon.clustering.base import BaseClusterer from aeon.transformations.collection import Normalizer +# from aeon.distances._distance import sbd_pairwise_distance + @njit(fastmath=True) def normalized_cc(s1, s2, norm1=-1.0, norm2=-1.0): @@ -245,6 +247,7 @@ def _assign(self, X): dists = 1.0 - cdist_normalized_cc( X_temp, cluster_temp, self.norms_, self.norms_centroids_, False ) + # dists = sbd_pairwise_distance(X, self.cluster_centers_) self.labels_ = dists.argmin(axis=1) self._check_no_empty_cluster(self.labels_, self.n_clusters) self.inertia_ = self._compute_inertia(dists, self.labels_) @@ -397,50 +400,19 @@ def _post_fit(self, X_fitted, centroids, inertia): self._X_fit = None def _predict(self, X, y=None) -> np.ndarray: - """Predict the closest cluster each sample in X belongs to. - - Parameters - ---------- - X: np.ndarray, of shape (n_cases, n_channels, n_timepoints) or - (n_cases, n_timepoints) - A collection of time series instances. - y: ignored, exists for API consistency reasons. - - Returns - ------- - np.ndarray (1d array of shape (n_cases,)) - Index of the cluster each time series in X belongs to. - """ # TODO remove dependence on cdist_normalized_cc - normaliser = Normalizer() + # normaliser = Normalizer() X_ = X.copy() - X_ = normaliser.fit_transform(X_) + # X_ = normaliser.fit_transform(X_) X_temp = np.transpose(X_, (0, 2, 1)) cluster_temp = np.transpose(self.cluster_centers_, (0, 2, 1)) n1 = np.linalg.norm(X_temp, axis=(1, 2)) n2 = np.linalg.norm(cluster_temp, axis=(1, 2)) dists = 1.0 - cdist_normalized_cc(X_temp, cluster_temp, n1, n2, False) + # dists = sbd_pairwise_distance(X_, self.cluster_centers_, standardize=False) return dists.argmin(axis=1) def fit_predict(self, X, y=None): - """Fit X using k-Shape clustering then predict the closest clusters. - - It is more efficient to use this method than to sequentially call fit - and predict. - - Parameters - ---------- - X : array-like of shape=(n_ts, sz, d) - Time series dataset to predict. - - y - Ignored - - Returns - ------- - labels : array of shape=(n_ts, ) - Index of the cluster each sample belongs to. - """ return self._fit(X, y).labels_ @classmethod From 0fec4d0f23c1a029d28325440ae63da061dc921b Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Sun, 23 Mar 2025 01:15:52 +0530 Subject: [PATCH 05/16] minor changes --- aeon/clustering/_k_shape.py | 187 ++++++++++++------------------------ 1 file changed, 62 insertions(+), 125 deletions(-) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index f58b90173e..f2810d4069 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -14,31 +14,14 @@ @njit(fastmath=True) -def normalized_cc(s1, s2, norm1=-1.0, norm2=-1.0): - """Normalize cc. - - Parameters - ---------- - s1 : array-like, shape=(sz, d), dtype=float64 - A time series. - s2 : array-like, shape=(sz, d), dtype=float64 - Another time series. - norm1 : float64, default=-1.0 - norm2 : float64, default=-1.0 - - Returns - ------- - norm_cc : array-like, shape=(2 * sz - 1), dtype=float64 - """ +def normalized_cc(s1, s2): assert s1.shape[1] == s2.shape[1] - sz = s1.shape[0] - n_bits = 1 + int(np.log2(2 * sz - 1)) + n_timepoints = s1.shape[0] + n_bits = 1 + int(np.log2(2 * n_timepoints - 1)) fft_sz = 2**n_bits - if norm1 < 0.0: - norm1 = np.linalg.norm(s1) - if norm2 < 0.0: - norm2 = np.linalg.norm(s2) + norm1 = np.linalg.norm(s1) + norm2 = np.linalg.norm(s2) denom = norm1 * norm2 if denom < 1e-9: # To avoid NaNs @@ -53,53 +36,29 @@ def normalized_cc(s1, s2, norm1=-1.0, norm2=-1.0): ) ) - cc = np.vstack((cc[-(sz - 1) :], cc[:sz])) + cc = np.vstack((cc[-(n_timepoints - 1) :], cc[:n_timepoints])) norm_cc = np.real(cc).sum(axis=-1) / denom return norm_cc @njit(parallel=True, fastmath=True) -def cdist_normalized_cc(dataset1, dataset2, norms1, norms2, self_similarity): - """Compute the distance matrix between two time series dataset. - - Parameters - ---------- - dataset1 : array-like, shape=(n_ts1, sz, d), dtype=float64 - A dataset of time series. - dataset2 : array-like, shape=(n_ts2, sz, d), dtype=float64 - Another dataset of time series. - norms1 : array-like, shape=(n_ts1,), dtype=float64 - norms2 : array-like, shape=(n_ts2,), dtype=float64 - self_similarity : bool - - Returns - ------- - dists : array-like, shape=(n_ts1, n_ts2), dtype=float64 - """ - n_ts1, sz, d = dataset1.shape +def cdist_normalized_cc(dataset1, dataset2): + n_ts1, n_timepoints, n_channels = dataset1.shape n_ts2 = dataset2.shape[0] - assert d == dataset2.shape[2] + assert n_channels == dataset2.shape[2] dists = np.zeros((n_ts1, n_ts2)) - if (norms1 < 0.0).any(): - for i_ts1 in prange(n_ts1): - norms1[i_ts1] = np.linalg.norm(dataset1[i_ts1, ...]) - if (norms2 < 0.0).any(): - for i_ts2 in prange(n_ts2): - norms2[i_ts2] = np.linalg.norm(dataset2[i_ts2, ...]) - if self_similarity: - for i in prange(1, n_ts1): - for j in range(i): - dists[i, j] = normalized_cc( - dataset1[i], dataset2[j], norm1=norms1[i], norm2=norms2[j] - ).max() - dists += dists.T - else: - for i in prange(n_ts1): - for j in range(n_ts2): - dists[i, j] = normalized_cc( - dataset1[i], dataset2[j], norm1=norms1[i], norm2=norms2[j] - ).max() + norms1 = np.zeros(n_ts1) + norms2 = np.zeros(n_ts2) + for i_ts1 in prange(n_ts1): + norms1[i_ts1] = np.linalg.norm(dataset1[i_ts1, ...]) + + for i_ts2 in prange(n_ts2): + norms2[i_ts2] = np.linalg.norm(dataset2[i_ts2, ...]) + + for i in prange(n_ts1): + for j in range(n_ts2): + dists[i, j] = normalized_cc(dataset1[i], dataset2[j]).max() return dists @@ -235,85 +194,72 @@ def _check_no_empty_cluster(self, labels, n_clusters): def _compute_inertia(self, distances, assignments, squared=True): """Derive inertia from pre-computed distances and assignments.""" - n_ts = distances.shape[0] + n_cases = distances.shape[0] if squared: - return np.sum(distances[np.arange(n_ts), assignments] ** 2) / n_ts + return np.sum(distances[np.arange(n_cases), assignments] ** 2) / n_cases else: - return np.sum(distances[np.arange(n_ts), assignments]) / n_ts + return np.sum(distances[np.arange(n_cases), assignments]) / n_cases - def _assign(self, X): - X_temp = np.transpose(X, (0, 2, 1)) - cluster_temp = np.transpose(self.cluster_centers_, (0, 2, 1)) - dists = 1.0 - cdist_normalized_cc( - X_temp, cluster_temp, self.norms_, self.norms_centroids_, False + def _sbd_pairwise(self, X, Y): + # TODO remove dependence on cdist_normalized_cc + return 1.0 - cdist_normalized_cc( + np.transpose(X, (0, 2, 1)), + np.transpose(Y, (0, 2, 1)), ) - # dists = sbd_pairwise_distance(X, self.cluster_centers_) + + def _assign(self, X): + dists = self._sbd_pairwise(X, self.cluster_centers_) self.labels_ = dists.argmin(axis=1) self._check_no_empty_cluster(self.labels_, self.n_clusters) self.inertia_ = self._compute_inertia(dists, self.labels_) - def y_shifted_sbd_vec(self, ref_ts, dataset, norm_ref, norms_dataset): - n_ts = dataset.shape[0] - d = dataset.shape[1] - sz = dataset.shape[2] - assert sz == ref_ts.shape[1] and d == ref_ts.shape[0] - dataset_shifted = np.zeros((n_ts, d, sz)) + def _sbd_dist(self, X, Y): + return 1.0 - normalized_cc(np.transpose(X, (1, 0)), np.transpose(Y, (1, 0))) - if norm_ref < 0: - norm_ref = np.linalg.norm(ref_ts) - if (norms_dataset < 0.0).any(): - for i_ts in range(n_ts): - norms_dataset[i_ts] = np.linalg.norm(dataset[i_ts, ...]) - - for i in range(n_ts): + def _align_data_to_reference(self, partition_centroid, X_partition): + n_cases, n_channels, n_timepoints = X_partition.shape + aligned_X_to_centroid = np.zeros((n_cases, n_channels, n_timepoints)) + for i in range(n_cases): # TODO: remove dependency on normalized_cc - ref_ts_temp = ref_ts.T - dataset_temp = np.transpose(dataset, (0, 2, 1)) - cc = normalized_cc( - ref_ts_temp, dataset_temp[i], norm1=norm_ref, norm2=norms_dataset[i] - ) + cc = self._sbd_dist(partition_centroid, X_partition[i]) idx = np.argmax(cc) - shift = idx - sz - if shift > 0: - dataset_shifted[i, :, shift:] = dataset[i, :, : sz - shift] + shift = idx - n_timepoints + if shift >= 0: + aligned_X_to_centroid[i, :, shift:] = X_partition[ + i, :, : n_timepoints - shift + ] elif shift < 0: - dataset_shifted[i, :, : sz + shift] = dataset[i, :, -shift:] - else: - dataset_shifted[i] = dataset[i] + aligned_X_to_centroid[i, :, : n_timepoints + shift] = X_partition[ + i, :, -shift: + ] - return dataset_shifted + return aligned_X_to_centroid def _shape_extraction(self, X, k): - # X is of dim (n_ts, d, sz) - sz = X.shape[2] - d = X.shape[1] - Xp = self.y_shifted_sbd_vec( - self.cluster_centers_[k], - X[self.labels_ == k], - -1, - self.norms_[self.labels_ == k], + n_timepoints = X.shape[2] + n_channels = X.shape[1] + _X = self._align_data_to_reference( + self.cluster_centers_[k], X[self.labels_ == k] ) - # Xp is of dim (n_ts, d, sz) - S = np.dot(Xp[:, 0, :].T, Xp[:, 0, :]) - Q = np.eye(sz) - np.ones((sz, sz)) / sz + S = np.dot(_X[:, 0, :].T, _X[:, 0, :]) + Q = np.eye(n_timepoints) - np.ones((n_timepoints, n_timepoints)) / n_timepoints M = np.dot(Q.T, np.dot(S, Q)) _, vec = np.linalg.eigh(M) - mu_k = vec[:, -1].reshape((sz, 1)) + centroid = vec[:, -1].reshape((n_timepoints, 1)) - mu_k_broadcast = mu_k.reshape((1, 1, sz)) - dist_plus_mu = np.sum(np.linalg.norm(Xp - mu_k_broadcast, axis=(1, 2))) - dist_minus_mu = np.sum(np.linalg.norm(Xp + mu_k_broadcast, axis=(1, 2))) + mu_k_broadcast = centroid.reshape((1, 1, n_timepoints)) + dist_plus_mu = np.sum(np.linalg.norm(_X - mu_k_broadcast, axis=(1, 2))) + dist_minus_mu = np.sum(np.linalg.norm(_X + mu_k_broadcast, axis=(1, 2))) if dist_minus_mu < dist_plus_mu: - mu_k *= -1 + centroid *= -1 - d = Xp.shape[1] - mu_k = np.tile(mu_k.T, (d, 1)) - return mu_k + n_channels = _X.shape[1] + centroid = np.tile(centroid.T, (n_channels, 1)) + return centroid def _update_centroids(self, X): - # X is (n, d, sz) for k in range(self.n_clusters): self.cluster_centers_[k] = self._shape_extraction(X, k) @@ -400,16 +346,7 @@ def _post_fit(self, X_fitted, centroids, inertia): self._X_fit = None def _predict(self, X, y=None) -> np.ndarray: - # TODO remove dependence on cdist_normalized_cc - # normaliser = Normalizer() - X_ = X.copy() - # X_ = normaliser.fit_transform(X_) - X_temp = np.transpose(X_, (0, 2, 1)) - cluster_temp = np.transpose(self.cluster_centers_, (0, 2, 1)) - n1 = np.linalg.norm(X_temp, axis=(1, 2)) - n2 = np.linalg.norm(cluster_temp, axis=(1, 2)) - dists = 1.0 - cdist_normalized_cc(X_temp, cluster_temp, n1, n2, False) - # dists = sbd_pairwise_distance(X_, self.cluster_centers_, standardize=False) + dists = self._sbd_pairwise(X, self.cluster_centers_) return dists.argmin(axis=1) def fit_predict(self, X, y=None): From 108c63a40175ea0558d03d3a55bf02223a306c9d Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Sun, 23 Mar 2025 01:47:52 +0530 Subject: [PATCH 06/16] Minor changes --- aeon/clustering/_k_shape.py | 62 ++++++++++++------------------------- 1 file changed, 20 insertions(+), 42 deletions(-) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index f2810d4069..4bba022f3b 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -160,22 +160,27 @@ def __init__( super().__init__() + def _incorrect_params_print(self): + return ( + f"The value provided for init: {self.init} is " + f"invalid. The following are a list of valid init algorithms " + f"strings: random. You can also pass a" + f"np.ndarray of size (n_clusters, n_channels, n_timepoints)" + ) + def _check_params(self, X: np.ndarray) -> None: self._random_state = check_random_state(self.random_state) if isinstance(self.init, str): if self.init == "random": self._init = self._random_center_initializer + else: + self._incorrect_params_print() else: if isinstance(self.init, np.ndarray) and len(self.init) == self.n_clusters: self._init = self.init.copy() else: - raise ValueError( - f"The value provided for init: {self.init} is " - f"invalid. The following are a list of valid init algorithms " - f"strings: random, kmedoids++, first. You can also pass a" - f"np.ndarray of size (n_clusters, n_channels, n_timepoints)" - ) + self._incorrect_params_print() if self.n_clusters > X.shape[0]: raise ValueError( @@ -192,13 +197,10 @@ def _check_no_empty_cluster(self, labels, n_clusters): if np.sum(labels == k) == 0: raise EmptyClusterError - def _compute_inertia(self, distances, assignments, squared=True): - """Derive inertia from pre-computed distances and assignments.""" + def _compute_inertia(self, distances, labels): + """Find inertia based on distances and labels.""" n_cases = distances.shape[0] - if squared: - return np.sum(distances[np.arange(n_cases), assignments] ** 2) / n_cases - else: - return np.sum(distances[np.arange(n_cases), assignments]) / n_cases + return np.sum(distances[np.arange(n_cases), labels] ** 2) / n_cases def _sbd_pairwise(self, X, Y): # TODO remove dependence on cdist_normalized_cc @@ -265,7 +267,6 @@ def _update_centroids(self, X): normaliser = Normalizer() self.cluster_centers_ = normaliser.fit_transform(self.cluster_centers_) - self.norms_centroids_ = np.linalg.norm(self.cluster_centers_, axis=(1, 2)) def _fit_one_init(self, X): if isinstance(self._init, Callable): @@ -273,7 +274,6 @@ def _fit_one_init(self, X): else: self.cluster_centers_ = self._init.copy() - self.norms_centroids_ = np.linalg.norm(self.cluster_centers_, axis=(1, 2)) self._assign(X) old_inertia = np.inf @@ -301,49 +301,27 @@ def _fit_one_init(self, X): return self def _fit(self, X, y=None): - # X = check_array(X, allow_nd=True) add aeon version self._check_params(X) - max_attempts = max(self.n_init, 10) - self.inertia_ = np.inf - - self.norms_ = 0.0 - self.norms_centroids_ = 0.0 - - self._X_fit = X - self.norms_ = np.linalg.norm(X, axis=(1, 2)) - best_correct_centroids = None min_inertia = np.inf - n_successful = 0 - n_attempts = 0 - while n_successful < self.n_init and n_attempts < max_attempts: + + for _ in range(self.n_init): try: - if self.verbose and self.n_init > 1: - print("Init %d" % (n_successful + 1)) # noqa: T001, T201 - n_attempts += 1 self._fit_one_init(X) if self.inertia_ < min_inertia: best_correct_centroids = self.cluster_centers_.copy() min_inertia = self.inertia_ self.n_iter_ = self._iter - n_successful += 1 except EmptyClusterError: if self.verbose: print("Resumed because of empty cluster") # noqa: T001, T201 - self.norms_centroids_ = np.linalg.norm(self.cluster_centers_, axis=(1, 2)) - self._post_fit(X, best_correct_centroids, min_inertia) - return self - def _post_fit(self, X_fitted, centroids, inertia): - if np.isfinite(inertia) and (centroids is not None): - self.cluster_centers_ = centroids - self._assign(X_fitted) - self._X_fit = X_fitted - self.inertia_ = inertia - else: - self._X_fit = None + self.cluster_centers_ = best_correct_centroids + self._assign(X) + self.inertia_ = min_inertia + return self def _predict(self, X, y=None) -> np.ndarray: dists = self._sbd_pairwise(X, self.cluster_centers_) From f84e071a897692cac95b5a2ee22567f5394d312d Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Sun, 23 Mar 2025 02:50:29 +0530 Subject: [PATCH 07/16] Minor changes --- aeon/clustering/_k_shape.py | 93 ++++++++++++++++++------------------- 1 file changed, 45 insertions(+), 48 deletions(-) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index 4bba022f3b..ace6be202e 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -197,11 +197,6 @@ def _check_no_empty_cluster(self, labels, n_clusters): if np.sum(labels == k) == 0: raise EmptyClusterError - def _compute_inertia(self, distances, labels): - """Find inertia based on distances and labels.""" - n_cases = distances.shape[0] - return np.sum(distances[np.arange(n_cases), labels] ** 2) / n_cases - def _sbd_pairwise(self, X, Y): # TODO remove dependence on cdist_normalized_cc return 1.0 - cdist_normalized_cc( @@ -209,12 +204,6 @@ def _sbd_pairwise(self, X, Y): np.transpose(Y, (0, 2, 1)), ) - def _assign(self, X): - dists = self._sbd_pairwise(X, self.cluster_centers_) - self.labels_ = dists.argmin(axis=1) - self._check_no_empty_cluster(self.labels_, self.n_clusters) - self.inertia_ = self._compute_inertia(dists, self.labels_) - def _sbd_dist(self, X, Y): return 1.0 - normalized_cc(np.transpose(X, (1, 0)), np.transpose(Y, (1, 0))) @@ -237,15 +226,13 @@ def _align_data_to_reference(self, partition_centroid, X_partition): return aligned_X_to_centroid - def _shape_extraction(self, X, k): + def _shape_extraction(self, X, k, cluster_centers, labels): n_timepoints = X.shape[2] n_channels = X.shape[1] - _X = self._align_data_to_reference( - self.cluster_centers_[k], X[self.labels_ == k] - ) - S = np.dot(_X[:, 0, :].T, _X[:, 0, :]) + _X = self._align_data_to_reference(cluster_centers[k], X[labels == k]) + S = _X[:, 0, :].T @ _X[:, 0, :] Q = np.eye(n_timepoints) - np.ones((n_timepoints, n_timepoints)) / n_timepoints - M = np.dot(Q.T, np.dot(S, Q)) + M = Q.T @ S @ Q _, vec = np.linalg.eigh(M) centroid = vec[:, -1].reshape((n_timepoints, 1)) @@ -261,66 +248,76 @@ def _shape_extraction(self, X, k): centroid = np.tile(centroid.T, (n_channels, 1)) return centroid - def _update_centroids(self, X): + def _update_centroids(self, X, cluster_centers, labels): for k in range(self.n_clusters): - self.cluster_centers_[k] = self._shape_extraction(X, k) + cluster_centers[k] = self._shape_extraction(X, k, cluster_centers, labels) normaliser = Normalizer() - self.cluster_centers_ = normaliser.fit_transform(self.cluster_centers_) + return normaliser.fit_transform(cluster_centers) + + def _assign(self, X, cluster_centers): + dists = self._sbd_pairwise(X, cluster_centers) + labels = dists.argmin(axis=1) + inertia = dists.min(axis=0).sum() + self._check_no_empty_cluster(labels, self.n_clusters) + return labels, inertia def _fit_one_init(self, X): if isinstance(self._init, Callable): - self.cluster_centers_ = self._init(X) + cluster_centers = self._init(X) else: - self.cluster_centers_ = self._init.copy() - - self._assign(X) - old_inertia = np.inf + cluster_centers = self._init.copy() + cur_labels, _ = self._assign(X, cluster_centers) + prev_inertia = np.inf + prev_labels = None it = 0 for it in range(self.max_iter): # noqa: B007 - old_cluster_centers = self.cluster_centers_.copy() - self._update_centroids(X) - self._assign(X) + prev_centers = cluster_centers + cluster_centers = self._update_centroids(X, prev_centers, cur_labels) + cur_labels, cur_inertia = self._assign(X, cluster_centers) + if self.verbose: - print("%.3f" % self.inertia_, end=" --> ") # noqa: T001, T201 + print("%.3f" % cur_inertia, end=" --> ") # noqa: T001, T201 - if np.abs(old_inertia - self.inertia_) < self.tol or ( - old_inertia - self.inertia_ < 0 + if np.abs(prev_inertia - cur_inertia) < self.tol or ( + prev_inertia - cur_inertia < 0 ): - self.cluster_centers_ = old_cluster_centers - self._assign(X) + cluster_centers = prev_centers + cur_labels, cur_inertia = self._assign(X, cluster_centers) break - old_inertia = self.inertia_ + prev_inertia = cur_inertia + prev_labels = cur_labels if self.verbose: print("") # noqa: T001, T201 - self._iter = it + 1 - - return self + return prev_labels, cluster_centers, prev_inertia, it + 1 def _fit(self, X, y=None): self._check_params(X) - self.inertia_ = np.inf - best_correct_centroids = None - min_inertia = np.inf + best_centroids = None + best_inertia = np.inf + best_labels = None + best_iters = self.max_iter for _ in range(self.n_init): try: - self._fit_one_init(X) - if self.inertia_ < min_inertia: - best_correct_centroids = self.cluster_centers_.copy() - min_inertia = self.inertia_ - self.n_iter_ = self._iter + labels, centers, inertia, n_iters = self._fit_one_init(X) + if inertia < best_inertia: + best_centroids = centers + best_labels = labels + best_iters = n_iters + best_inertia = inertia except EmptyClusterError: if self.verbose: print("Resumed because of empty cluster") # noqa: T001, T201 - self.cluster_centers_ = best_correct_centroids - self._assign(X) - self.inertia_ = min_inertia + self.cluster_centers_ = best_centroids + self.inertia_ = best_inertia + self.labels_ = best_labels + self.n_iter_ = best_iters return self def _predict(self, X, y=None) -> np.ndarray: From 797e814dc255f73258adfd428695a195e4af0742 Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Sun, 23 Mar 2025 02:55:21 +0530 Subject: [PATCH 08/16] Minor changes --- aeon/clustering/_k_shape.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index ace6be202e..087187a323 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -248,13 +248,6 @@ def _shape_extraction(self, X, k, cluster_centers, labels): centroid = np.tile(centroid.T, (n_channels, 1)) return centroid - def _update_centroids(self, X, cluster_centers, labels): - for k in range(self.n_clusters): - cluster_centers[k] = self._shape_extraction(X, k, cluster_centers, labels) - - normaliser = Normalizer() - return normaliser.fit_transform(cluster_centers) - def _assign(self, X, cluster_centers): dists = self._sbd_pairwise(X, cluster_centers) labels = dists.argmin(axis=1) @@ -274,7 +267,15 @@ def _fit_one_init(self, X): it = 0 for it in range(self.max_iter): # noqa: B007 prev_centers = cluster_centers - cluster_centers = self._update_centroids(X, prev_centers, cur_labels) + + # Refinement step + for k in range(self.n_clusters): + cluster_centers[k] = self._shape_extraction( + X, k, cluster_centers, cur_labels + ) + cluster_centers = Normalizer().fit_transform(cluster_centers) + + # Assignment step cur_labels, cur_inertia = self._assign(X, cluster_centers) if self.verbose: From 2b09d2134296ee11048d4938db3e6ba495efcbba Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Sun, 23 Mar 2025 03:46:53 +0530 Subject: [PATCH 09/16] Minor changes --- aeon/clustering/_k_shape.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index 087187a323..f6cf4145b5 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -192,7 +192,6 @@ def _random_center_initializer(self, X: np.ndarray) -> np.ndarray: return X[self._random_state.choice(X.shape[0], self.n_clusters)] def _check_no_empty_cluster(self, labels, n_clusters): - """Check that all clusters have at least one sample assigned.""" for k in range(n_clusters): if np.sum(labels == k) == 0: raise EmptyClusterError @@ -252,7 +251,11 @@ def _assign(self, X, cluster_centers): dists = self._sbd_pairwise(X, cluster_centers) labels = dists.argmin(axis=1) inertia = dists.min(axis=0).sum() - self._check_no_empty_cluster(labels, self.n_clusters) + + for i in range(self.n_clusters): + if np.sum(labels == i) == 0: + raise EmptyClusterError + return labels, inertia def _fit_one_init(self, X): From 41faa3c15a3117c56704d11bb6869bc4fa7e41af Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Sun, 23 Mar 2025 03:51:36 +0530 Subject: [PATCH 10/16] Minor changes --- aeon/clustering/_k_shape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index f6cf4145b5..8e1b6cb098 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -69,7 +69,7 @@ class EmptyClusterError(Exception): class TimeSeriesKShape(BaseClusterer): - """Kshape algorithm: wrapper of the ``tslearn`` implementation. + """Kshape algorithm: inspired by ``tslearn`` implementation. Parameters ---------- From f41376161adf60080661ecc567e3a5d82a2e88b0 Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Sun, 23 Mar 2025 03:54:35 +0530 Subject: [PATCH 11/16] minor changes --- aeon/clustering/_k_shape.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index 8e1b6cb098..a405b3ab00 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -71,6 +71,10 @@ class EmptyClusterError(Exception): class TimeSeriesKShape(BaseClusterer): """Kshape algorithm: inspired by ``tslearn`` implementation. + Implementation References: + 1. https://github.com/tslearn-team/tslearn/blob/9937946/tslearn/clustering/kshape.py#L21-L291 # noqa: E501 + 2. https://github.com/TheDatumOrg/kshape-python + Parameters ---------- n_clusters: int, default=8 From 50316375fdaeada5ad5fd433265d22ad265f1136 Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Sun, 23 Mar 2025 03:58:49 +0530 Subject: [PATCH 12/16] minor changes --- aeon/clustering/_k_shape.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index a405b3ab00..8ce6f0ed7d 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -303,6 +303,20 @@ def _fit_one_init(self, X): return prev_labels, cluster_centers, prev_inertia, it + 1 def _fit(self, X, y=None): + """Fit time series clusterer to training data. + + Parameters + ---------- + X: np.ndarray, of shape (n_cases, n_channels, n_timepoints) or + (n_cases, n_timepoints) + A collection of time series instances. + y: ignored, exists for API consistency reasons. + + Returns + ------- + self: + Fitted estimator. + """ self._check_params(X) best_centroids = None @@ -329,12 +343,23 @@ def _fit(self, X, y=None): return self def _predict(self, X, y=None) -> np.ndarray: + """Predict the closest cluster each sample in X belongs to. + + Parameters + ---------- + X: np.ndarray, of shape (n_cases, n_channels, n_timepoints) or + (n_cases, n_timepoints) + A collection of time series instances. + y: ignored, exists for API consistency reasons. + + Returns + ------- + np.ndarray (1d array of shape (n_cases,)) + Index of the cluster each time series in X belongs to. + """ dists = self._sbd_pairwise(X, self.cluster_centers_) return dists.argmin(axis=1) - def fit_predict(self, X, y=None): - return self._fit(X, y).labels_ - @classmethod def _get_test_params(cls, parameter_set="default"): """Return testing parameter settings for the estimator. From 972d474ca771520e47eb6487fc3552dc78103e85 Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Sun, 23 Mar 2025 04:29:43 +0530 Subject: [PATCH 13/16] update inertia calculation --- aeon/clustering/_k_shape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index 8ce6f0ed7d..1018bfb396 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -254,7 +254,7 @@ def _shape_extraction(self, X, k, cluster_centers, labels): def _assign(self, X, cluster_centers): dists = self._sbd_pairwise(X, cluster_centers) labels = dists.argmin(axis=1) - inertia = dists.min(axis=0).sum() + inertia = dists.min(axis=1).sum() for i in range(self.n_clusters): if np.sum(labels == i) == 0: From 282df783d0d84c2f340498b3359807b80df11b8f Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Sun, 23 Mar 2025 04:47:12 +0530 Subject: [PATCH 14/16] Minor changes --- aeon/clustering/_k_shape.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index 1018bfb396..751c7ab646 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -201,7 +201,6 @@ def _check_no_empty_cluster(self, labels, n_clusters): raise EmptyClusterError def _sbd_pairwise(self, X, Y): - # TODO remove dependence on cdist_normalized_cc return 1.0 - cdist_normalized_cc( np.transpose(X, (0, 2, 1)), np.transpose(Y, (0, 2, 1)), @@ -214,7 +213,6 @@ def _align_data_to_reference(self, partition_centroid, X_partition): n_cases, n_channels, n_timepoints = X_partition.shape aligned_X_to_centroid = np.zeros((n_cases, n_channels, n_timepoints)) for i in range(n_cases): - # TODO: remove dependency on normalized_cc cc = self._sbd_dist(partition_centroid, X_partition[i]) idx = np.argmax(cc) shift = idx - n_timepoints From 69eb71957ac7f3bd5379c92a0ee5845b7d216cb6 Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Mon, 24 Mar 2025 03:01:23 +0530 Subject: [PATCH 15/16] Fixes failing tests --- aeon/clustering/_k_shape.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index 751c7ab646..3936088de9 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -266,6 +266,8 @@ def _fit_one_init(self, X): else: cluster_centers = self._init.copy() + self.cluster_centers_ = cluster_centers + cur_labels, _ = self._assign(X, cluster_centers) prev_inertia = np.inf prev_labels = None From 3d31d5feb953dc7a52774b6a4d14858dfd4c3aa1 Mon Sep 17 00:00:00 2001 From: tanishy7777 <23110328@iitgn.ac.in> Date: Mon, 24 Mar 2025 03:15:31 +0530 Subject: [PATCH 16/16] Fixes failing tests --- aeon/clustering/_k_shape.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aeon/clustering/_k_shape.py b/aeon/clustering/_k_shape.py index 3936088de9..3b6bfeb886 100644 --- a/aeon/clustering/_k_shape.py +++ b/aeon/clustering/_k_shape.py @@ -336,7 +336,8 @@ def _fit(self, X, y=None): if self.verbose: print("Resumed because of empty cluster") # noqa: T001, T201 - self.cluster_centers_ = best_centroids + if best_centroids is not None: + self.cluster_centers_ = best_centroids self.inertia_ = best_inertia self.labels_ = best_labels self.n_iter_ = best_iters