Skip to content

[ENH] Implement K-Shape clusterer #2676

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
250 changes: 229 additions & 21 deletions aeon/clustering/_k_shape.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,79 @@
"""Time series kshapes."""

from typing import Optional, Union
from typing import Callable, Optional, Union

import numpy as np
from numba import njit, objmode, prange
from numpy.random import RandomState
from sklearn.utils import check_random_state

from aeon.clustering.base import BaseClusterer
from aeon.transformations.collection import Normalizer

# from aeon.distances._distance import sbd_pairwise_distance


@njit(fastmath=True)
def normalized_cc(s1, s2):
    """Compute the normalised cross-correlation sequence between two series.

    The cross-correlation is evaluated for every lag through an FFT, summed
    over the last axis and divided by the product of the two series' norms.

    Parameters
    ----------
    s1 : np.ndarray
        2D array. ``s1.shape[0]`` is used as the number of time points;
        presumably the layout is ``(n_timepoints, n_channels)`` - the callers
        in this file transpose their inputs before calling.
    s2 : np.ndarray
        2D array whose second dimension must match ``s1``.

    Returns
    -------
    np.ndarray
        1D array holding one normalised cross-correlation value per lag
        (``2 * n_timepoints - 1`` values when ``n_timepoints > 1``).
    """
    assert s1.shape[1] == s2.shape[1]
    n_timepoints = s1.shape[0]
    # FFT length: a power of two strictly larger than 2 * n_timepoints - 1,
    # i.e. large enough to hold every lag of the linear cross-correlation.
    n_bits = 1 + int(np.log2(2 * n_timepoints - 1))
    fft_sz = 2**n_bits

    norm1 = np.linalg.norm(s1)
    norm2 = np.linalg.norm(s2)

    denom = norm1 * norm2
    # With a (near-)zero norm the division below would be unstable; an
    # infinite denominator makes every returned value 0 instead.
    if denom < 1e-9:  # To avoid NaNs
        denom = np.inf

    # The FFT calls are executed in object mode (regular NumPy); the result is
    # handed back to the compiled code as a 2D float64 array.
    with objmode(cc="float64[:, :]"):
        cc = np.real(
            np.fft.ifft(
                np.fft.fft(s1, fft_sz, axis=0)
                * np.conj(np.fft.fft(s2, fft_sz, axis=0)),
                axis=0,
            )
        )

    # Reorder the circular result: the last n_timepoints - 1 rows go first,
    # followed by the first n_timepoints rows (row 0 of the FFT output ends up
    # at index n_timepoints - 1).
    cc = np.vstack((cc[-(n_timepoints - 1) :], cc[:n_timepoints]))
    # Collapse the last axis and normalise.
    norm_cc = np.real(cc).sum(axis=-1) / denom
    return norm_cc


@njit(parallel=True, fastmath=True)
def cdist_normalized_cc(dataset1, dataset2):
    """Compute the pairwise maximum normalised cross-correlation.

    Parameters
    ----------
    dataset1 : np.ndarray
        3D array unpacked as ``(n_ts1, n_timepoints, n_channels)``.
    dataset2 : np.ndarray
        3D array whose first dimension is the number of series and whose last
        dimension must equal ``n_channels`` of ``dataset1``.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_ts1, n_ts2)``; entry ``[i, j]`` is the maximum
        over all lags of ``normalized_cc(dataset1[i], dataset2[j])``.
    """
    n_ts1 = dataset1.shape[0]
    n_channels = dataset1.shape[2]
    n_ts2 = dataset2.shape[0]
    assert n_channels == dataset2.shape[2]
    dists = np.zeros((n_ts1, n_ts2))

    # The per-series norms are computed inside normalized_cc, so no separate
    # pre-computation pass is needed here.
    for i in prange(n_ts1):
        for j in range(n_ts2):
            dists[i, j] = normalized_cc(dataset1[i], dataset2[j]).max()
    return dists


class EmptyClusterError(Exception):
    """Signal that at least one cluster has no time series assigned to it."""


class TimeSeriesKShape(BaseClusterer):
"""Kshape algorithm: wrapper of the ``tslearn`` implementation.
"""Kshape algorithm: inspired by ``tslearn`` implementation.

Implementation References:
1. https://github.com/tslearn-team/tslearn/blob/9937946/tslearn/clustering/kshape.py#L21-L291 # noqa: E501
2. https://github.com/TheDatumOrg/kshape-python

Parameters
----------
Expand Down Expand Up @@ -100,6 +164,144 @@ def __init__(

super().__init__()

def _incorrect_params_print(self):
return (
f"The value provided for init: {self.init} is "
f"invalid. The following are a list of valid init algorithms "
f"strings: random. You can also pass a"
f"np.ndarray of size (n_clusters, n_channels, n_timepoints)"
)

def _check_params(self, X: np.ndarray) -> None:
    """Validate the constructor parameters against the training data.

    Sets ``self._random_state`` and ``self._init`` (either the random
    initialiser method or a copy of the user supplied centres).

    Parameters
    ----------
    X : np.ndarray
        Training collection; only ``X.shape[0]`` (number of cases) is used.

    Raises
    ------
    ValueError
        If ``init`` is neither the string ``"random"`` nor a ``np.ndarray``
        with ``n_clusters`` entries, or if ``n_clusters`` exceeds the number
        of cases in ``X``.
    """
    self._random_state = check_random_state(self.random_state)

    if isinstance(self.init, str):
        if self.init != "random":
            # The message was previously built and thrown away, leaving
            # ``self._init`` unset and failing later with an obscure error.
            raise ValueError(self._incorrect_params_print())
        self._init = self._random_center_initializer
    elif isinstance(self.init, np.ndarray) and len(self.init) == self.n_clusters:
        # Copy so fitting never mutates the caller's array.
        self._init = self.init.copy()
    else:
        raise ValueError(self._incorrect_params_print())

    if self.n_clusters > X.shape[0]:
        raise ValueError(
            f"n_clusters ({self.n_clusters}) cannot be larger than "
            f"n_cases ({X.shape[0]})"
        )

def _random_center_initializer(self, X: np.ndarray) -> np.ndarray:
return X[self._random_state.choice(X.shape[0], self.n_clusters)]

def _check_no_empty_cluster(self, labels, n_clusters):
for k in range(n_clusters):
if np.sum(labels == k) == 0:
raise EmptyClusterError

def _sbd_pairwise(self, X, Y):
    """Return the pairwise shape-based distance between two collections.

    Both inputs have their last two axes swapped before being passed to
    ``cdist_normalized_cc``; the distance is one minus that result.
    """
    X_swapped = np.transpose(X, (0, 2, 1))
    Y_swapped = np.transpose(Y, (0, 2, 1))
    max_ncc = cdist_normalized_cc(X_swapped, Y_swapped)
    return 1.0 - max_ncc

def _sbd_dist(self, X, Y):
    """Return one minus the normalised cross-correlation of two series.

    Both 2D inputs are transposed before being passed to ``normalized_cc``,
    so the result holds one value per lag.
    """
    X_swapped = np.transpose(X, (1, 0))
    Y_swapped = np.transpose(Y, (1, 0))
    ncc = normalized_cc(X_swapped, Y_swapped)
    return 1.0 - ncc

def _align_data_to_reference(self, partition_centroid, X_partition):
    """Shift every series of a partition to best match its centroid.

    Parameters
    ----------
    partition_centroid : np.ndarray
        Reference series the partition is aligned to.
    X_partition : np.ndarray
        3D array unpacked as ``(n_cases, n_channels, n_timepoints)``.

    Returns
    -------
    np.ndarray
        Array of the same shape as ``X_partition`` holding the shifted
        series, zero padded where the shift leaves no data.
    """
    n_cases, n_channels, n_timepoints = X_partition.shape
    aligned_X_to_centroid = np.zeros((n_cases, n_channels, n_timepoints))
    for i in range(n_cases):
        # _sbd_dist returns 1 - NCC per lag, i.e. a distance: the best
        # alignment is the lag with the SMALLEST value. Taking argmax here
        # would select the worst matching lag.
        lag_dists = self._sbd_dist(partition_centroid, X_partition[i])
        idx = np.argmin(lag_dists)
        # NOTE(review): normalized_cc appears to put the zero-lag term at
        # index n_timepoints - 1, so this offset maps zero lag to shift -1;
        # kept as in the original - confirm against the reference code.
        shift = idx - n_timepoints
        if shift >= 0:
            aligned_X_to_centroid[i, :, shift:] = X_partition[
                i, :, : n_timepoints - shift
            ]
        else:
            aligned_X_to_centroid[i, :, : n_timepoints + shift] = X_partition[
                i, :, -shift:
            ]

    return aligned_X_to_centroid

def _shape_extraction(self, X, k, cluster_centers, labels):
    """Compute the updated centroid of cluster ``k``.

    The members of cluster ``k`` are aligned to the current centroid, the
    leading eigenvector of the centred Gram matrix of their first channel is
    taken as the new shape, its sign is chosen to be closest to the aligned
    members, and the result is repeated for every channel.

    Parameters
    ----------
    X : np.ndarray
        Full collection; ``X.shape[2]`` is used as the series length.
    k : int
        Index of the cluster being refined.
    cluster_centers : np.ndarray
        Current centres, indexed by cluster.
    labels : np.ndarray
        Cluster index of every case in ``X``.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_channels, n_timepoints)``.
    """
    n_timepoints = X.shape[2]
    members = self._align_data_to_reference(cluster_centers[k], X[labels == k])

    # Only the first channel contributes to the Gram matrix.
    first_channel = members[:, 0, :]
    gram = first_channel.T @ first_channel
    centering = (
        np.eye(n_timepoints) - np.ones((n_timepoints, n_timepoints)) / n_timepoints
    )
    centred_gram = centering.T @ gram @ centering

    # eigh returns eigenvalues in ascending order: the last column is the
    # eigenvector of the largest eigenvalue.
    _, eigvecs = np.linalg.eigh(centred_gram)
    centroid = eigvecs[:, -1].reshape((n_timepoints, 1))

    # Both the eigenvector and its negation are candidates; keep whichever
    # is closer to the aligned members.
    candidate = centroid.reshape((1, 1, n_timepoints))
    dist_plus_mu = np.sum(np.linalg.norm(members - candidate, axis=(1, 2)))
    dist_minus_mu = np.sum(np.linalg.norm(members + candidate, axis=(1, 2)))
    if dist_minus_mu < dist_plus_mu:
        centroid *= -1

    return np.tile(centroid.T, (members.shape[1], 1))

def _assign(self, X, cluster_centers):
    """Assign every case to its closest centre.

    Parameters
    ----------
    X : np.ndarray
        Collection to assign.
    cluster_centers : np.ndarray
        Current centres.

    Returns
    -------
    labels : np.ndarray
        Index of the closest centre for every case.
    inertia : float
        Sum over all cases of the distance to the closest centre.

    Raises
    ------
    EmptyClusterError
        If any of the ``self.n_clusters`` clusters receives no case.
    """
    pairwise = self._sbd_pairwise(X, cluster_centers)
    closest = np.argmin(pairwise, axis=1)
    total_dist = np.min(pairwise, axis=1).sum()

    self._check_no_empty_cluster(closest, self.n_clusters)

    return closest, total_dist

def _fit_one_init(self, X):
    """Run one k-Shape optimisation from a single initialisation.

    Alternates centroid refinement (shape extraction followed by
    normalisation) and assignment until the inertia stops improving by at
    least ``self.tol``, gets worse, or ``self.max_iter`` is reached.

    Parameters
    ----------
    X : np.ndarray
        Training collection.

    Returns
    -------
    labels : np.ndarray or None
        Labels of the last accepted iteration.
    cluster_centers : np.ndarray
        Centres of the last accepted iteration.
    inertia : float
        Inertia of the last accepted iteration.
    n_iter : int
        Index of the last executed iteration plus one.

    Raises
    ------
    EmptyClusterError
        Propagated from ``_assign`` when a cluster becomes empty.
    """
    if callable(self._init):
        cluster_centers = self._init(X)
    else:
        cluster_centers = self._init.copy()

    self.cluster_centers_ = cluster_centers

    cur_labels, _ = self._assign(X, cluster_centers)
    prev_inertia = np.inf
    prev_labels = None
    it = 0
    for it in range(self.max_iter):  # noqa: B007
        # Snapshot the centres: the refinement below writes into
        # ``cluster_centers`` in place, so a plain reference would be
        # overwritten and the rollback would not restore the old centres.
        prev_centers = cluster_centers.copy()

        # Refinement step
        for k in range(self.n_clusters):
            cluster_centers[k] = self._shape_extraction(
                X, k, cluster_centers, cur_labels
            )
        cluster_centers = Normalizer().fit_transform(cluster_centers)

        # Assignment step
        cur_labels, cur_inertia = self._assign(X, cluster_centers)

        if self.verbose:
            print("%.3f" % cur_inertia, end=" --> ")  # noqa: T001, T201

        # Stop when the improvement is below tol or the inertia got worse,
        # and fall back to the centres of the previous iteration.
        if np.abs(prev_inertia - cur_inertia) < self.tol or (
            prev_inertia - cur_inertia < 0
        ):
            cluster_centers = prev_centers
            cur_labels, cur_inertia = self._assign(X, cluster_centers)
            break

        prev_inertia = cur_inertia
        prev_labels = cur_labels
    if self.verbose:
        print("")  # noqa: T001, T201

    return prev_labels, cluster_centers, prev_inertia, it + 1

def _fit(self, X, y=None):
"""Fit time series clusterer to training data.

Expand All @@ -115,25 +317,31 @@ def _fit(self, X, y=None):
self:
Fitted estimator.
"""
from tslearn.clustering import KShape

self._tslearn_k_shapes = KShape(
n_clusters=self.n_clusters,
max_iter=self.max_iter,
tol=self.tol,
random_state=self.random_state,
n_init=self.n_init,
verbose=self.verbose,
init=self.init,
)
self._check_params(X)

best_centroids = None
best_inertia = np.inf
best_labels = None
best_iters = self.max_iter

_X = X.swapaxes(1, 2)
for _ in range(self.n_init):
try:
labels, centers, inertia, n_iters = self._fit_one_init(X)
if inertia < best_inertia:
best_centroids = centers
best_labels = labels
best_iters = n_iters
best_inertia = inertia
except EmptyClusterError:
if self.verbose:
print("Resumed because of empty cluster") # noqa: T001, T201

self._tslearn_k_shapes.fit(_X)
self._cluster_centers = self._tslearn_k_shapes.cluster_centers_
self.labels_ = self._tslearn_k_shapes.predict(_X)
self.inertia_ = self._tslearn_k_shapes.inertia_
self.n_iter_ = self._tslearn_k_shapes.n_iter_
if best_centroids is not None:
self.cluster_centers_ = best_centroids
self.inertia_ = best_inertia
self.labels_ = best_labels
self.n_iter_ = best_iters
return self

def _predict(self, X, y=None) -> np.ndarray:
"""Predict the closest cluster each sample in X belongs to.
Expand All @@ -150,8 +358,8 @@ def _predict(self, X, y=None) -> np.ndarray:
np.ndarray (1d array of shape (n_cases,))
Index of the cluster each time series in X belongs to.
"""
_X = X.swapaxes(1, 2)
return self._tslearn_k_shapes.predict(_X)
dists = self._sbd_pairwise(X, self.cluster_centers_)
return dists.argmin(axis=1)

@classmethod
def _get_test_params(cls, parameter_set="default"):
Expand Down