From ce118c0a607dc12fec6b50540764ab908dfaee97 Mon Sep 17 00:00:00 2001 From: CodeLionX Date: Tue, 20 May 2025 14:08:09 +0200 Subject: [PATCH 1/9] feat: add kdtw distance implementation (currently still a similarity and tests are broken) --- aeon/distances/__init__.py | 10 + aeon/distances/_distance.py | 24 + aeon/distances/kernel/__init__.py | 15 + aeon/distances/kernel/_kdtw.py | 417 ++++++++++++++++++ aeon/distances/kernel/tests/__init__.py | 1 + .../kernel/tests/test_distance_correctness.py | 70 +++ .../tests/test_numba_distance_parameters.py | 4 + .../expected_distance_results.py | 11 + 8 files changed, 552 insertions(+) create mode 100644 aeon/distances/kernel/__init__.py create mode 100644 aeon/distances/kernel/_kdtw.py create mode 100644 aeon/distances/kernel/tests/__init__.py create mode 100644 aeon/distances/kernel/tests/test_distance_correctness.py diff --git a/aeon/distances/__init__.py b/aeon/distances/__init__.py index d6ff3f776a..e218cc3fb7 100644 --- a/aeon/distances/__init__.py +++ b/aeon/distances/__init__.py @@ -82,6 +82,10 @@ "soft_dtw_pairwise_distance", "soft_dtw_alignment_path", "soft_dtw_cost_matrix", + "kdtw_distance", + "kdtw_alignment_path", + "kdtw_cost_matrix", + "kdtw_pairwise_distance", ] from aeon.distances._distance import ( @@ -157,6 +161,12 @@ wdtw_distance, wdtw_pairwise_distance, ) +from aeon.distances.kernel import ( + kdtw_alignment_path, + kdtw_cost_matrix, + kdtw_distance, + kdtw_pairwise_distance, +) from aeon.distances.mindist._dft_sfa import mindist_dft_sfa_distance from aeon.distances.mindist._paa_sax import mindist_paa_sax_distance from aeon.distances.mindist._sax import mindist_sax_distance diff --git a/aeon/distances/_distance.py b/aeon/distances/_distance.py index 33f9141440..c227ab577c 100644 --- a/aeon/distances/_distance.py +++ b/aeon/distances/_distance.py @@ -66,6 +66,12 @@ wdtw_distance, wdtw_pairwise_distance, ) +from aeon.distances.kernel import ( + kdtw_alignment_path, + kdtw_cost_matrix, + kdtw_distance, + kdtw_pairwise_distance, +) from aeon.distances.mindist import ( mindist_dft_sfa_distance, mindist_dft_sfa_pairwise_distance, @@ -109,6 +115,8 @@ class DistanceKwargs(TypedDict, total=False): m: int max_shift: Optional[int] gamma: float + sigma: float + noramlize: bool DistanceFunction = Callable[[np.ndarray, np.ndarray, Any], float] @@ -469,6 +477,7 @@ def get_distance_function(method: Union[str, DistanceFunction]) -> DistanceFunct 'sbd' distances.sbd_distance 'shift_scale' distances.shift_scale_invariant_distance 'soft_dtw' distances.soft_dtw_distance + 'kdtw' distances.kdtw_distance =============== ======================================== Parameters @@ -528,6 +537,7 @@ def get_pairwise_distance_function( 'sbd' distances.sbd_pairwise_distance 'shift_scale' distances.shift_scale_invariant_pairwise_distance 'soft_dtw' distances.soft_dtw_pairwise_distance + 'kdtw' distances.kdtw_pairwise_distance =============== ======================================== Parameters @@ -582,6 +592,7 @@ def get_alignment_path_function(method: str) -> AlignmentPathFunction: 'twe' distances.twe_alignment_path 'lcss' distances.lcss_alignment_path 'soft_dtw' distances.soft_dtw_alignment_path + 'kdtw' distances.kdtw_alignment_path =============== ======================================== Parameters @@ -631,6 +642,7 @@ def get_cost_matrix_function(method: str) -> CostMatrixFunction: 'twe' distances.twe_cost_matrix 'lcss' distances.lcss_cost_matrix 'soft_dtw' distances.soft_dtw_cost_matrix + 'kdtw' distances.kdtw_cost_matrix =============== ======================================== Parameters @@ -685,6 +697,7 @@ class DistanceType(Enum): POINTWISE = "pointwise" ELASTIC = "elastic" + KERNEL = "kernel" CROSS_CORRELATION = "cross-correlation" MIN_DISTANCE = "min-dist" MATRIX_PROFILE = "matrix-profile" @@ -909,6 +922,16 @@ class DistanceType(Enum): "symmetric": True, "unequal_support": True, }, + { + "name": "kdtw", + "distance": kdtw_distance, + "pairwise_distance": kdtw_pairwise_distance, + "cost_matrix": kdtw_cost_matrix, + "alignment_path": kdtw_alignment_path, + "type": DistanceType.KERNEL, + "symmetric": True, + "unequal_support": True, + }, ] DISTANCES_DICT = {d["name"]: d for d in DISTANCES} @@ -922,6 +945,7 @@ class DistanceType(Enum): ] ELASTIC_DISTANCES = [d["name"] for d in DISTANCES if d["type"] == DistanceType.ELASTIC] +KERNEL_DISTANCES = [d["name"] for d in DISTANCES if d["type"] == DistanceType.KERNEL] POINTWISE_DISTANCES = [ d["name"] for d in DISTANCES if d["type"] == DistanceType.POINTWISE ] diff --git a/aeon/distances/kernel/__init__.py b/aeon/distances/kernel/__init__.py new file mode 100644 index 0000000000..717ba943c0 --- /dev/null +++ b/aeon/distances/kernel/__init__.py @@ -0,0 +1,15 @@ +"""Kernel distances.""" + +__all__ = [ + "kdtw_distance", + "kdtw_alignment_path", + "kdtw_cost_matrix", + "kdtw_pairwise_distance", +] + +from aeon.distances.kernel._kdtw import ( + kdtw_alignment_path, + kdtw_cost_matrix, + kdtw_distance, + kdtw_pairwise_distance, +) diff --git a/aeon/distances/kernel/_kdtw.py b/aeon/distances/kernel/_kdtw.py new file mode 100644 index 0000000000..adccc9f9a7 --- /dev/null +++ b/aeon/distances/kernel/_kdtw.py @@ -0,0 +1,417 @@ +"""Dynamic time warping kernel (KDTW) distance between two time series.""" + +__maintainer__ = ["SebastianSchmidl"] + +from typing import Optional, Union + +import numpy as np +from numba import njit +from numba.typed import List as NumbaList + +from aeon.distances.elastic._alignment_paths import compute_min_return_path +from aeon.distances.pointwise import squared_distance +from aeon.utils.conversion._convert_collection import _convert_collection_to_numba_list +from aeon.utils.validation.collection import _is_numpy_list_multivariate + + +@njit(cache=True, fastmath=True) +def kdtw_distance( + x: np.ndarray, + y: np.ndarray, + gamma: float = 0.125, + epsilon: float = 1e-20, + normalize: bool = True, +) -> float: + r"""Compute the DTW kernel (KDTW) between two time series as a distance. + + KDTW is a similarity measure constructed from DTW and was introduced in [1]_. It has + the property that it is invariant to shifts in the time series. The kernel is + positive definite. This implementation provides a normalized distance [2]_ and + takes the default values from [2]_. Details can be found online: + https://people.irisa.fr/Pierre-Francois.Marteau/REDK/KDTW/KDTW.html + + Intuition of constructing a DTW kernel from DTW: + Instead of keeping only one of the best alignment paths, the new kernel will try to + sum up the costs of all the existing sub-sequence alignment paths with some + weighting factor that will favor good alignments while penalizing bad alignments. + + The current implementation performs no bounding on the dynamic programming matrix + (cost matrix) (uses no corridors). + + Parameters + ---------- + x : np.ndarray + First time series, either univariate, shape ``(n_timepoints,)``, or + multivariate, shape ``(n_channels, n_timepoints)``. + y : np.ndarray + Second time series, either univariate, shape ``(n_timepoints,)``, or + multivariate, shape ``(n_channels, n_timepoints)``. + gamma : float, default=0.125 + bandwidth parameter which weights the local contributions, i.e. the distances + between locally aligned positions. + epsilon : float, default=1e-20 + Small value to avoid zero. The default is 1e-20. + normalize : bool, default=True + Whether to normalize hte distance to make it shift invariant. The default is + True. + + Returns + ------- + float + KDTW distance between x and y. + + Raises + ------ + ValueError + If x and y are not 1D or 2D arrays. + + References + ---------- + .. [1] Pierre-François Marteau and Sylvie Gibet: On recursive edit distance kernels + with application to time series classification. IEEE Transactions on Neural + Networks and Learning Systems 26(6), 2014, pages 1121 - 1133. + + .. [2] Paparrizos, John, Chunwei Liu, Aaron J. Elmore, and Michael J. Franklin: + Debunking Four Long-Standing Misconceptions of Time-Series Distance Measures. In + Proceedings of the International Conference on Management of Data (SIGMOD), + 1887-1905, 2020. https://doi.org/10.1145/3318464.3389760. + + Examples + -------- + >>> import numpy as np + >>> from aeon.distances import kdtw_distance + >>> x = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + >>> y = np.array([[11, 12, 13, 14, 15, 16, 17, 18, 19, 20]]) + >>> dist = kdtw_distance(x, y) + 4.613613743977605e-203 + """ + if x.ndim == 1 and y.ndim == 1: + _x = x.reshape((1, x.shape[0])) + _y = y.reshape((1, y.shape[0])) + return _kdtw_distance(_x, _y, gamma, epsilon, normalize) + if x.ndim == 2 and y.ndim == 2: + return _kdtw_distance(x, y, gamma, epsilon, normalize) + raise ValueError("x and y must be 1D or 2D") + + +@njit(cache=True, fastmath=True) +def kdtw_cost_matrix( + x: np.ndarray, + y: np.ndarray, + gamma: float = 0.125, + epsilon: float = 1e-20, +) -> np.ndarray: + """Compute the cost matrix for KDTW between two time series. + + Parameters + ---------- + x : np.ndarray + First time series, either univariate, shape ``(n_timepoints,)``, or + multivariate, shape ``(n_channels, n_timepoints)``. + y : np.ndarray + Second time series, either univariate, shape ``(n_timepoints,)``, or + multivariate, shape ``(n_channels, n_timepoints)``. + gamma : float, default=0.125 + bandwidth parameter which weights the local contributions, i.e. the distances + between locally aligned positions. + epsilon : float, default=1e-20 + Small value to avoid zero. The default is 1e-20. + + Returns + ------- + np.ndarray (n_timepoints_x, n_timepoints_y) + KDTW cost matrix between x and y. + + Raises + ------ + ValueError + If x and y are not 1D or 2D arrays. + + Examples + -------- + >>> import numpy as np + >>> from aeon.distances import kdtw_cost_matrix + >>> x = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + >>> y = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + >>> kdtw_cost_matrix(x, y) + array([[ 0., 2., 4., 6., 8., 10., 12., 14., 16., 18.], + [ 2., 0., 2., 4., 6., 8., 10., 12., 14., 16.], + [ 4., 2., 0., 2., 4., 6., 8., 10., 12., 14.], + [ 6., 4., 2., 0., 2., 4., 6., 8., 10., 12.], + [ 8., 6., 4., 2., 0., 2., 4., 6., 8., 10.], + [10., 8., 6., 4., 2., 0., 2., 4., 6., 8.], + [12., 10., 8., 6., 4., 2., 0., 2., 4., 6.], + [14., 12., 10., 8., 6., 4., 2., 0., 2., 4.], + [16., 14., 12., 10., 8., 6., 4., 2., 0., 2.], + [18., 16., 14., 12., 10., 8., 6., 4., 2., 0.]]) + """ + if x.ndim == 1 and y.ndim == 1: + _x = x.reshape((1, x.shape[0])) + _y = y.reshape((1, y.shape[0])) + return _kdtw_cost_matrix(_x, _y, gamma, epsilon) + if x.ndim == 2 and y.ndim == 2: + return _kdtw_cost_matrix(x, y, gamma, epsilon) + raise ValueError("x and y must be 1D or 2D") + + +@njit(cache=True, fastmath=True) +def _kdtw_distance( + x: np.ndarray, + y: np.ndarray, + gamma: float, + epsilon: float, + normalize: bool, +) -> float: + # Do not subtract one because the cost matrix is one dimension larger: + n = x.shape[-1] + m = y.shape[-1] + if normalize: + self_x = _kdtw_cost_matrix(x, x, gamma, epsilon)[n, n] + self_y = _kdtw_cost_matrix(y, y, gamma, epsilon)[m, m] + norm_factor = np.sqrt(self_x * self_y) + else: + norm_factor = 1.0 + return _kdtw_cost_matrix(x, y, gamma, epsilon)[n, m] / norm_factor + + +@njit(cache=True, fastmath=True) +def _local_kernel( + x: np.ndarray, y: np.ndarray, gamma: float, epsilon: float +) -> np.ndarray: + factor = 1.0 / 3.0 + # Incoming shape is (n_channels, n_timepoints) + # We want to calculate the multivariate squared distance between the two time series + # considering each point in the time series as a separate instance, thus we need to + # reshape to (m_cases, m_channels, 1), where m_cases = n_timepoints and + # m_channels = n_channels. + x = x.T.reshape(-1, x.shape[0], 1) + y = y.T.reshape(-1, y.shape[0], 1) + n_cases = len(x) + m_cases = len(y) + distances = np.zeros((n_cases, m_cases)) + + for i in range(n_cases): + for j in range(m_cases): + # expects each input to have shape (n_channels, n_timepoints = 1) + distances[i, j] = squared_distance(x[i], y[j]) + + return factor * (np.exp(-distances / gamma) + epsilon) + + +@njit(cache=True, fastmath=True) +def _kdtw_cost_matrix( + x: np.ndarray, y: np.ndarray, gamma: float, epsilon: float +) -> np.ndarray: + # deals with multivariate time series, afterward, we just work with the distances + # and do not need to deal with the channels anymore + local_kernel = _local_kernel(x, y, gamma, epsilon) + + # For the initial values of the cost matrix, we add 1 + n = np.shape(x)[-1] + 1 + m = np.shape(y)[-1] + 1 + + cost_matrix = np.zeros((n, m)) + cumulative_dp_diag = np.zeros((n, m)) + diagonal_weights = np.zeros(max(n, m)) + + # Initialize the diagonal weights + min_timepoints = min(n, m) + diagonal_weights[0] = 1.0 + for i in range(1, min_timepoints): + diagonal_weights[i] = local_kernel[i - 1, i - 1] + + # Initialize the cost matrix and cumulative dp diagonal + cost_matrix[0, 0] = 1 + cumulative_dp_diag[0, 0] = 1 + + # - left column + for i in range(1, n): + cost_matrix[i, 0] = cost_matrix[i - 1, 0] * local_kernel[i - 1, 0] + cumulative_dp_diag[i, 0] = cumulative_dp_diag[i - 1, 0] * diagonal_weights[i] + + # - top row + for j in range(1, m): + cost_matrix[0, j] = cost_matrix[0, j - 1] * local_kernel[0, j - 1] + cumulative_dp_diag[0, j] = cumulative_dp_diag[0, j - 1] * diagonal_weights[j] + + # Perform the main dynamic programming loop + for i in range(1, n): + for j in range(1, m): + local_cost = local_kernel[i - 1, j - 1] + cost_matrix[i, j] = ( + cost_matrix[i - 1, j] + + cost_matrix[i, j - 1] + + cost_matrix[i - 1, j - 1] + ) * local_cost + cumulative_dp_diag[i, j] = ( + cumulative_dp_diag[i - 1, j] * diagonal_weights[i] + + cumulative_dp_diag[i, j - 1] * diagonal_weights[j] + ) + if i == j: + cumulative_dp_diag[i, j] += ( + cumulative_dp_diag[i - 1, j - 1] * local_cost + ) + + # Add the cumulative dp diagonal to the cost matrix + cost_matrix = cost_matrix + cumulative_dp_diag + return cost_matrix + + +def kdtw_pairwise_distance( + X: Union[np.ndarray, list[np.ndarray]], + y: Optional[Union[np.ndarray, list[np.ndarray]]] = None, + gamma: float = 0.125, + epsilon: float = 1e-20, + normalize: bool = True, +) -> np.ndarray: + """Compute the KDTW pairwise distance between a set of time series. + + Parameters + ---------- + X : np.ndarray + A collection of time series instances of shape ``(n_instances, n_timepoints)`` + or ``(n_instances, n_channels, n_timepoints)``. + y : np.ndarray or None, default=None + A single series or a collection of time series of shape ``(m_timepoints,)`` or + ``(m_instances, m_timepoints)`` or ``(m_instances, m_channels, m_timepoints)``. + If None, then the KDTW pairwise distance between the instances of X is + calculated. + gamma : float, default=0.125 + bandwidth parameter which weights the local contributions, i.e. the distances + between locally aligned positions. + epsilon : float, default=1e-20 + Small value to avoid zero. The default is 1e-20. + normalize : bool, default=True + Whether to normalize the distance to make it shift invariant. The default is + True. + + Returns + ------- + np.ndarray (n_instances, n_instances) + KDTW pairwise matrix between the instances of X. + + Raises + ------ + ValueError + If X is not 2D or 3D array when only passing X. + If X and y are not 1D, 2D or 3D arrays when passing both X and y. + + Examples + -------- + >>> import numpy as np + >>> from aeon.distances import kdtw_pairwise_distance + >>> # Distance between each time series in a collection of time series + >>> X = np.array([[[1, 2, 3]],[[4, 5, 6]], [[7, 8, 9]]]) + >>> kdtw_pairwise_distance(X) + array([[ 0., 8., 12.], + [ 8., 0., 8.], + [12., 8., 0.]]) + + >>> # Distance between two collections of time series + >>> X = np.array([[[1, 2, 3]],[[4, 5, 6]], [[7, 8, 9]]]) + >>> y = np.array([[[11, 12, 13]],[[14, 15, 16]], [[17, 18, 19]]]) + >>> kdtw_pairwise_distance(X, y) + array([[16., 19., 22.], + [13., 16., 19.], + [10., 13., 16.]]) + + >>> X = np.array([[[1, 2, 3]],[[4, 5, 6]], [[7, 8, 9]]]) + >>> y_univariate = np.array([[11, 12, 13],[14, 15, 16], [17, 18, 19]]) + >>> kdtw_pairwise_distance(X, y_univariate) + array([[16.], + [13.], + [10.]]) + + """ + multivariate_conversion = _is_numpy_list_multivariate(X, y) + _X, _ = _convert_collection_to_numba_list(X, "X", multivariate_conversion) + if y is None: + # To self + return _kdtw_pairwise_distance(_X, gamma, epsilon, normalize) + + _y, _ = _convert_collection_to_numba_list(y, "y", multivariate_conversion) + return _kdtw_from_multiple_to_multiple_distance(_X, _y, gamma, epsilon, normalize) + + +@njit(cache=True, fastmath=True) +def _kdtw_pairwise_distance( + X: NumbaList[np.ndarray], gamma: float, epsilon: float, normalize: bool +) -> np.ndarray: + n_instances = len(X) + distances = np.zeros((n_instances, n_instances)) + + for i in range(n_instances): + for j in range(i + 1, n_instances): + distances[i, j] = _kdtw_distance(X[i], X[j], gamma, epsilon, normalize) + distances[j, i] = distances[i, j] + + return distances + + +@njit(cache=True, fastmath=True) +def _kdtw_from_multiple_to_multiple_distance( + x: NumbaList[np.ndarray], + y: NumbaList[np.ndarray], + gamma: float, + epsilon: float, + normalize: bool, +) -> np.ndarray: + n_instances = len(x) + m_instances = len(y) + distances = np.zeros((n_instances, m_instances)) + + for i in range(n_instances): + for j in range(m_instances): + distances[i, j] = _kdtw_distance(x[i], y[j], gamma, epsilon, normalize) + return distances + + +@njit(cache=True) +def kdtw_alignment_path( + x: np.ndarray, + y: np.ndarray, + gamma: float = 0.125, + epsilon: float = 1e-20, + normalize: bool = True, +) -> tuple[list[tuple[int, int]], float]: + """Compute the kdtw alignment path between two time series. + + Parameters + ---------- + x : np.ndarray + First time series, shape ``(n_channels, n_timepoints)`` or ``(n_timepoints,)``. + y : np.ndarray + Second time series, shape ``(m_channels, m_timepoints)`` or ``(m_timepoints,)``. + gamma : float, default=0.125 + bandwidth parameter which weights the local contributions, i.e. the distances + between locally aligned positions. + epsilon : float, default=1e-20 + Small value to avoid zero. The default is 1e-20. + + Returns + ------- + List[Tuple[int, int]] + The alignment path between the two time series where each element is a tuple + of the index in x and the index in y that have the best alignment according + to the cost matrix. + float + The unnormalized kdtw distance betweeen the two time series. + + Raises + ------ + ValueError + If x and y are not 1D or 2D arrays. + + Examples + -------- + >>> import numpy as np + >>> from aeon.distances import kdtw_alignment_path + >>> x = np.array([[1, 2, 3, 6]]) + >>> y = np.array([[1, 2, 3, 4]]) + >>> kdtw_alignment_path(x, y) + ([(0, 0), (1, 1), (2, 2), (3, 3)], 2.0) + """ + x_size = x.shape[-1] + y_size = y.shape[-1] + cost_matrix = kdtw_cost_matrix(x, y, gamma, epsilon) + return compute_min_return_path(cost_matrix), cost_matrix[x_size, y_size] diff --git a/aeon/distances/kernel/tests/__init__.py b/aeon/distances/kernel/tests/__init__.py new file mode 100644 index 0000000000..6ccdc73d52 --- /dev/null +++ b/aeon/distances/kernel/tests/__init__.py @@ -0,0 +1 @@ +"""Kernel distance tests.""" diff --git a/aeon/distances/kernel/tests/test_distance_correctness.py b/aeon/distances/kernel/tests/test_distance_correctness.py new file mode 100644 index 0000000000..e4d4be01e7 --- /dev/null +++ b/aeon/distances/kernel/tests/test_distance_correctness.py @@ -0,0 +1,70 @@ +"""Test the distance calculations are correct. + +Compare the distance calculations on the 1D and 2D (d,m) format input against the +results generated with tsml, in distances.tests.TestDistances. +""" + +from numpy.testing import assert_almost_equal + +from aeon.datasets import load_basic_motions, load_unit_test +from aeon.distances import kdtw_distance + +distances = [ + "kdtw", +] + +distance_parameters = { + "kdtw": [0.0, 0.1, 1.0], # gamma +} +unit_test_distances = { + "kdtw": [0.0, 0.0, 0.0], + "kdtw_norm": [0.0, 0.0, 0.0], +} +basic_motions_distances = { + "kdtw": [0.0, 0.0, 0.0], + "kdtw_norm": [0.0, 0.0, 0.0], +} + + +def test_multivariate_correctness(): + """Test distance correctness on BasicMotions: multivariate, equal length.""" + trainX, _ = load_basic_motions(return_type="numpy3D") + case1 = trainX[0] + case2 = trainX[1] + + for j in range(0, 3): + d = kdtw_distance( + case1, case2, gamma=distance_parameters["kdtw"][j], normalize=False + ) + assert_almost_equal(d, basic_motions_distances["kdtw"][j], 4) + d = kdtw_distance( + case1, case2, gamma=distance_parameters["kdtw"][j], normalize=True + ) + assert_almost_equal(d, basic_motions_distances["kdtw_norm"][j], 4) + + +def test_univariate_correctness(): + """Test correctness on UnitTest: univariate, equal length.""" + trainX, _ = load_unit_test(return_type="numpy3D") + trainX2, _ = load_unit_test(return_type="numpy2D") + # Test 2D and 3D instances from UnitTest + cases1 = [trainX[0], trainX2[0]] + cases2 = [trainX[2], trainX2[2]] + + for j in range(0, 3): + d = kdtw_distance( + cases1[0], cases2[0], gamma=distance_parameters["kdtw"][j], normalize=False + ) + d2 = kdtw_distance( + cases1[1], cases2[1], gamma=distance_parameters["kdtw"][j], normalize=False + ) + assert_almost_equal(d, unit_test_distances["kdtw"][j], 4) + assert d == d2 + d = kdtw_distance( + cases1[0], cases2[0], gamma=distance_parameters["kdtw"][j], normalize=True + ) + d2 = kdtw_distance( + cases1[1], cases2[1], gamma=distance_parameters["kdtw"][j], normalize=True + ) + assert_almost_equal(d, unit_test_distances["kdtw_norm"][j], 4) + assert d == d2 diff --git a/aeon/distances/tests/test_numba_distance_parameters.py b/aeon/distances/tests/test_numba_distance_parameters.py index 2293a239bd..0bf8c06167 100644 --- a/aeon/distances/tests/test_numba_distance_parameters.py +++ b/aeon/distances/tests/test_numba_distance_parameters.py @@ -132,6 +132,10 @@ def _test_distance_params( "shape_dtw": BASIC_BOUNDING_PARAMS + [{"reach": 4}], "shift_scale": [{"max_shift": 1}, {"max_shift": None}], "soft_dtw": BASIC_BOUNDING_PARAMS + [{"gamma": 0.2}], + "kdtw": [ + {"gamma": 0.125, "epsilon": 1e-3}, + {"gamma": 0.125, "epsilon": 1e-3, "normalize": True}, + ], } diff --git a/aeon/testing/expected_results/expected_distance_results.py b/aeon/testing/expected_results/expected_distance_results.py index 7126c5c624..8355d9779e 100644 --- a/aeon/testing/expected_results/expected_distance_results.py +++ b/aeon/testing/expected_results/expected_distance_results.py @@ -125,6 +125,13 @@ None, 0.0, ], + "kdtw": [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], } @@ -190,4 +197,8 @@ [0.8103639073457298, 5.535457073146429], [0.6519267432870345, 5.491208968546096], ], + "kdtw": [ + [0.0, 0.0], + [0.0, 0.0], + ], } From 16753525671ced8098fe2ad44d2bae48fb9e279f Mon Sep 17 00:00:00 2001 From: CodeLionX Date: Tue, 20 May 2025 18:30:39 +0200 Subject: [PATCH 2/9] feat: improve implementation by including normalization options --- aeon/distances/elastic/_alignment_paths.py | 24 ++- aeon/distances/kernel/_kdtw.py | 183 ++++++++++++++------- 2 files changed, 139 insertions(+), 68 deletions(-) diff --git a/aeon/distances/elastic/_alignment_paths.py b/aeon/distances/elastic/_alignment_paths.py index f70a374cb2..7c4b6822a3 100644 --- a/aeon/distances/elastic/_alignment_paths.py +++ b/aeon/distances/elastic/_alignment_paths.py @@ -7,13 +7,17 @@ @njit(cache=True, fastmath=True) -def compute_min_return_path(cost_matrix: np.ndarray) -> list[tuple]: +def compute_min_return_path( + cost_matrix: np.ndarray, larger_is_better: bool = False +) -> list[tuple]: """Compute the minimum return path through a cost matrix. Parameters ---------- cost_matrix : np.ndarray, of shape (n_timepoints_x, n_timepoints_y) Cost matrix. + larger_is_better : bool, default=False + If True, the path will be computed for the maximum cost instead of the minimum. Returns ------- @@ -32,15 +36,17 @@ def compute_min_return_path(cost_matrix: np.ndarray) -> list[tuple]: elif j == 0: i -= 1 else: - min_index = np.argmin( - np.array( - [ - cost_matrix[i - 1, j - 1], - cost_matrix[i - 1, j], - cost_matrix[i, j - 1], - ] - ) + costs = np.array( + [ + cost_matrix[i - 1, j - 1], + cost_matrix[i - 1, j], + cost_matrix[i, j - 1], + ] ) + if larger_is_better: + min_index = np.argmax(costs) + else: + min_index = np.argmin(costs) if min_index == 0: i, j = i - 1, j - 1 diff --git a/aeon/distances/kernel/_kdtw.py b/aeon/distances/kernel/_kdtw.py index adccc9f9a7..25c05ea647 100644 --- a/aeon/distances/kernel/_kdtw.py +++ b/aeon/distances/kernel/_kdtw.py @@ -1,6 +1,12 @@ """Dynamic time warping kernel (KDTW) distance between two time series.""" __maintainer__ = ["SebastianSchmidl"] +__all__ = [ + "kdtw_distance", + "kdtw_cost_matrix", + "kdtw_pairwise_distance", + "kdtw_alignment_path", +] from typing import Optional, Union @@ -20,7 +26,8 @@ def kdtw_distance( y: np.ndarray, gamma: float = 0.125, epsilon: float = 1e-20, - normalize: bool = True, + normalize_input: bool = True, + normalize_dist: bool = False, ) -> float: r"""Compute the DTW kernel (KDTW) between two time series as a distance. @@ -51,9 +58,12 @@ def kdtw_distance( between locally aligned positions. epsilon : float, default=1e-20 Small value to avoid zero. The default is 1e-20. - normalize : bool, default=True - Whether to normalize hte distance to make it shift invariant. The default is - True. + normalize_input : bool, default=True + Whether to normalize the time series' channels to zero mean, unit variance, and + unit length before computing the distance. Highly recommended! + normalize_dist : bool, default=False + Whether to normalize the distance by the product of the self distances of x and + y to avoid scaling effects and put the distance value between 0 and 1. Returns ------- @@ -82,15 +92,17 @@ def kdtw_distance( >>> from aeon.distances import kdtw_distance >>> x = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) >>> y = np.array([[11, 12, 13, 14, 15, 16, 17, 18, 19, 20]]) - >>> dist = kdtw_distance(x, y) - 4.613613743977605e-203 + >>> kdtw_distance(x, y) + 0.8051764348248271 + >>> kdtw_distance(x, y, normalize_dist=True) + 0.0 """ if x.ndim == 1 and y.ndim == 1: _x = x.reshape((1, x.shape[0])) _y = y.reshape((1, y.shape[0])) - return _kdtw_distance(_x, _y, gamma, epsilon, normalize) + return _kdtw_distance(_x, _y, gamma, epsilon, normalize_input, normalize_dist) if x.ndim == 2 and y.ndim == 2: - return _kdtw_distance(x, y, gamma, epsilon, normalize) + return _kdtw_distance(x, y, gamma, epsilon, normalize_input, normalize_dist) raise ValueError("x and y must be 1D or 2D") @@ -100,6 +112,7 @@ def kdtw_cost_matrix( y: np.ndarray, gamma: float = 0.125, epsilon: float = 1e-20, + normalize_input: bool = True, ) -> np.ndarray: """Compute the cost matrix for KDTW between two time series. @@ -116,6 +129,9 @@ def kdtw_cost_matrix( between locally aligned positions. epsilon : float, default=1e-20 Small value to avoid zero. The default is 1e-20. + normalize_input : bool, default=True + Whether to normalize the time series' channels to zero mean, unit variance, and + unit length before computing the distance. Returns ------- @@ -131,26 +147,28 @@ def kdtw_cost_matrix( -------- >>> import numpy as np >>> from aeon.distances import kdtw_cost_matrix - >>> x = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) - >>> y = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + >>> x = np.array([[1, 2, 3, 4, 5]]) + >>> y = np.array([[1, 2, 3, 4, 5]]) >>> kdtw_cost_matrix(x, y) - array([[ 0., 2., 4., 6., 8., 10., 12., 14., 16., 18.], - [ 2., 0., 2., 4., 6., 8., 10., 12., 14., 16.], - [ 4., 2., 0., 2., 4., 6., 8., 10., 12., 14.], - [ 6., 4., 2., 0., 2., 4., 6., 8., 10., 12.], - [ 8., 6., 4., 2., 0., 2., 4., 6., 8., 10.], - [10., 8., 6., 4., 2., 0., 2., 4., 6., 8.], - [12., 10., 8., 6., 4., 2., 0., 2., 4., 6.], - [14., 12., 10., 8., 6., 4., 2., 0., 2., 4.], - [16., 14., 12., 10., 8., 6., 4., 2., 0., 2.], - [18., 16., 14., 12., 10., 8., 6., 4., 2., 0.]]) + array([[2. , 0.66666667, 0.22222222, 0.07407407, 0.02469136, + 0.00823045], + [0.66666667, 1.11111111, 0.55555556, 0.24691358, 0.10288066, + 0.04115226], + [0.22222222, 0.55555556, 0.74074074, 0.44032922, 0.2345679 , + 0.11522634], + [0.07407407, 0.24691358, 0.44032922, 0.54046639, 0.35848194, + 0.21688767], + [0.02469136, 0.10288066, 0.2345679 , 0.35848194, 0.41914342, + 0.30239293], + [0.00823045, 0.04115226, 0.11522634, 0.21688767, 0.30239293, + 0.34130976]]) """ if x.ndim == 1 and y.ndim == 1: _x = x.reshape((1, x.shape[0])) _y = y.reshape((1, y.shape[0])) - return _kdtw_cost_matrix(_x, _y, gamma, epsilon) + return _kdtw_cost_matrix(_x, _y, gamma, epsilon, normalize_input) if x.ndim == 2 and y.ndim == 2: - return _kdtw_cost_matrix(x, y, gamma, epsilon) + return _kdtw_cost_matrix(x, y, gamma, epsilon, normalize_input) raise ValueError("x and y must be 1D or 2D") @@ -160,51 +178,76 @@ def _kdtw_distance( y: np.ndarray, gamma: float, epsilon: float, - normalize: bool, + normalize_input: bool, + normalize_dist: bool, ) -> float: # Do not subtract one because the cost matrix is one dimension larger: n = x.shape[-1] m = y.shape[-1] - if normalize: - self_x = _kdtw_cost_matrix(x, x, gamma, epsilon)[n, n] - self_y = _kdtw_cost_matrix(y, y, gamma, epsilon)[m, m] + if normalize_dist: + self_x = _kdtw_cost_matrix(x, x, gamma, epsilon, normalize_input)[n, n] + self_y = _kdtw_cost_matrix(y, y, gamma, epsilon, normalize_input)[m, m] norm_factor = np.sqrt(self_x * self_y) else: norm_factor = 1.0 - return _kdtw_cost_matrix(x, y, gamma, epsilon)[n, m] / norm_factor + return ( + 1.0 + - _kdtw_cost_matrix(x, y, gamma, epsilon, normalize_input)[n, m] / norm_factor + ) @njit(cache=True, fastmath=True) def _local_kernel( - x: np.ndarray, y: np.ndarray, gamma: float, epsilon: float + x: np.ndarray, y: np.ndarray, gamma: float, epsilon: float, normalize_input: bool ) -> np.ndarray: + if normalize_input: + # First apply z-score normalization, then unit length normalization on each + # channel + eps = np.finfo(np.float64).eps + _x = np.empty_like(x) + _y = np.empty_like(y) + + # Numba mean and std do not support axis parameters and + # np.linalg.norm is not supported in numba. + for i in range(x.shape[0]): + _x[i] = (x[i] - np.mean(x[i])) / (np.std(x[i]) + eps) + norm = np.sqrt(np.sum(x[i] ** 2)) + _x[i] = _x[i] / (norm + eps) + for i in range(y.shape[0]): + _y[i] = (y[i] - np.mean(y[i])) / (np.std(y[i]) + eps) + norm = np.sqrt(np.sum(y[i] ** 2)) + _y[i] = _y[i] / (norm + eps) + else: + _x = x + _y = y + factor = 1.0 / 3.0 # Incoming shape is (n_channels, n_timepoints) # We want to calculate the multivariate squared distance between the two time series # considering each point in the time series as a separate instance, thus we need to # reshape to (m_cases, m_channels, 1), where m_cases = n_timepoints and # m_channels = n_channels. - x = x.T.reshape(-1, x.shape[0], 1) - y = y.T.reshape(-1, y.shape[0], 1) - n_cases = len(x) - m_cases = len(y) + _x = _x.T.reshape(-1, _x.shape[0], 1) + _y = _y.T.reshape(-1, _y.shape[0], 1) + n_cases = _x.shape[0] + m_cases = _y.shape[0] distances = np.zeros((n_cases, m_cases)) for i in range(n_cases): for j in range(m_cases): # expects each input to have shape (n_channels, n_timepoints = 1) - distances[i, j] = squared_distance(x[i], y[j]) + distances[i, j] = squared_distance(_x[i], _y[j]) return factor * (np.exp(-distances / gamma) + epsilon) @njit(cache=True, fastmath=True) def _kdtw_cost_matrix( - x: np.ndarray, y: np.ndarray, gamma: float, epsilon: float + x: np.ndarray, y: np.ndarray, gamma: float, epsilon: float, normalize_input: bool ) -> np.ndarray: # deals with multivariate time series, afterward, we just work with the distances # and do not need to deal with the channels anymore - local_kernel = _local_kernel(x, y, gamma, epsilon) + local_kernel = _local_kernel(x, y, gamma, epsilon, normalize_input) # For the initial values of the cost matrix, we add 1 n = np.shape(x)[-1] + 1 @@ -262,7 +305,8 @@ def kdtw_pairwise_distance( y: Optional[Union[np.ndarray, list[np.ndarray]]] = None, gamma: float = 0.125, epsilon: float = 1e-20, - normalize: bool = True, + normalize_input: bool = True, + normalize_dist: bool = False, ) -> np.ndarray: """Compute the KDTW pairwise distance between a set of time series. @@ -281,9 +325,12 @@ def kdtw_pairwise_distance( between locally aligned positions. epsilon : float, default=1e-20 Small value to avoid zero. The default is 1e-20. - normalize : bool, default=True - Whether to normalize the distance to make it shift invariant. The default is - True. + normalize_input : bool, default=True + Whether to normalize the time series' channels to zero mean, unit variance, and + unit length before computing the distance. + normalize_dist : bool, default=False + Whether to normalize the distance by the product of the self distances of x and + y to avoid scaling effects and put the distance between 0 and 1. Returns ------- @@ -303,46 +350,55 @@ def kdtw_pairwise_distance( >>> # Distance between each time series in a collection of time series >>> X = np.array([[[1, 2, 3]],[[4, 5, 6]], [[7, 8, 9]]]) >>> kdtw_pairwise_distance(X) - array([[ 0., 8., 12.], - [ 8., 0., 8.], - [12., 8., 0.]]) + array([[0. , 0.45953361, 0.45953361], + [0.45953361, 0. , 0.45953361], + [0.45953361, 0.45953361, 0. ]]) >>> # Distance between two collections of time series >>> X = np.array([[[1, 2, 3]],[[4, 5, 6]], [[7, 8, 9]]]) >>> y = np.array([[[11, 12, 13]],[[14, 15, 16]], [[17, 18, 19]]]) >>> kdtw_pairwise_distance(X, y) - array([[16., 19., 22.], - [13., 16., 19.], - [10., 13., 16.]]) + array([[0.45953361, 0.45953361, 0.45953361], + [0.45953361, 0.45953361, 0.45953361], + [0.45953361, 0.45953361, 0.45953361]]) >>> X = np.array([[[1, 2, 3]],[[4, 5, 6]], [[7, 8, 9]]]) >>> y_univariate = np.array([[11, 12, 13],[14, 15, 16], [17, 18, 19]]) >>> kdtw_pairwise_distance(X, y_univariate) - array([[16.], - [13.], - [10.]]) - + array([[0.45953361, 0.45953361, 0.45953361], + [0.45953361, 0.45953361, 0.45953361], + [0.45953361, 0.45953361, 0.45953361]]) """ multivariate_conversion = _is_numpy_list_multivariate(X, y) _X, _ = _convert_collection_to_numba_list(X, "X", multivariate_conversion) if y is None: # To self - return _kdtw_pairwise_distance(_X, gamma, epsilon, normalize) + return _kdtw_pairwise_distance( + _X, gamma, epsilon, normalize_input, normalize_dist + ) _y, _ = _convert_collection_to_numba_list(y, "y", multivariate_conversion) - return _kdtw_from_multiple_to_multiple_distance(_X, _y, gamma, epsilon, normalize) + return _kdtw_from_multiple_to_multiple_distance( + _X, _y, gamma, epsilon, normalize_input, normalize_dist + ) @njit(cache=True, fastmath=True) def _kdtw_pairwise_distance( - X: NumbaList[np.ndarray], gamma: float, epsilon: float, normalize: bool + X: NumbaList[np.ndarray], + gamma: float, + epsilon: float, + normalize_input: bool, + normalize_dist: bool, ) -> np.ndarray: n_instances = len(X) distances = np.zeros((n_instances, n_instances)) for i in range(n_instances): for j in range(i + 1, n_instances): - distances[i, j] = _kdtw_distance(X[i], X[j], gamma, epsilon, normalize) + distances[i, j] = _kdtw_distance( + X[i], X[j], gamma, epsilon, normalize_input, normalize_dist + ) distances[j, i] = distances[i, j] return distances @@ -354,7 +410,8 @@ def _kdtw_from_multiple_to_multiple_distance( y: NumbaList[np.ndarray], gamma: float, epsilon: float, - normalize: bool, + normalize_input: bool, + normalize_dist: bool, ) -> np.ndarray: n_instances = len(x) m_instances = len(y) @@ -362,7 +419,9 @@ def _kdtw_from_multiple_to_multiple_distance( for i in range(n_instances): for j in range(m_instances): - distances[i, j] = _kdtw_distance(x[i], y[j], gamma, epsilon, normalize) + distances[i, j] = _kdtw_distance( + x[i], y[j], gamma, epsilon, normalize_input, normalize_dist + ) return distances @@ -372,7 +431,7 @@ def kdtw_alignment_path( y: np.ndarray, gamma: float = 0.125, epsilon: float = 1e-20, - normalize: bool = True, + normalize_input: bool = True, ) -> tuple[list[tuple[int, int]], float]: """Compute the kdtw alignment path between two time series. @@ -387,6 +446,9 @@ def kdtw_alignment_path( between locally aligned positions. epsilon : float, default=1e-20 Small value to avoid zero. The default is 1e-20. + normalize_input : bool, default=True + Whether to normalize the time series' channels to zero mean, unit variance, and + unit length before computing the distance. Returns ------- @@ -409,9 +471,12 @@ def kdtw_alignment_path( >>> x = np.array([[1, 2, 3, 6]]) >>> y = np.array([[1, 2, 3, 4]]) >>> kdtw_alignment_path(x, y) - ([(0, 0), (1, 1), (2, 2), (3, 3)], 2.0) + ([(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)], 0.4191434232586494) """ x_size = x.shape[-1] y_size = y.shape[-1] - cost_matrix = kdtw_cost_matrix(x, y, gamma, epsilon) - return compute_min_return_path(cost_matrix), cost_matrix[x_size, y_size] + cost_matrix = kdtw_cost_matrix(x, y, gamma, epsilon, normalize_input) + return ( + compute_min_return_path(cost_matrix, larger_is_better=True), + cost_matrix[x_size, y_size], + ) From 7bbcebc230db3fe07c8fa8b9674acfeb522890a9 Mon Sep 17 00:00:00 2001 From: CodeLionX Date: Tue, 20 May 2025 18:30:57 +0200 Subject: [PATCH 3/9] feat: add kdtw distance to the documentation --- docs/api_reference/distances.rst | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/api_reference/distances.rst b/docs/api_reference/distances.rst index 05db50f45f..d6ede03c90 100644 --- a/docs/api_reference/distances.rst +++ b/docs/api_reference/distances.rst @@ -222,6 +222,20 @@ Weighted Dynamic Time Warping (DTW) wdtw_cost_matrix wdtw_alignment_path +Dynamic Time Warping kernel distance (KDTW) +------------------------------------------- + +.. currentmodule:: aeon.distances + +.. autosummary:: + :toctree: auto_generated/ + :template: function.rst + + kdtw_distance + kdtw_pairwise_distance + kdtw_cost_matrix + kdtw_alignment_path + General methods with distance argument -------------------------------------- From 54b092bb0342153e8c6d954bc3f874785963c9ee Mon Sep 17 00:00:00 2001 From: CodeLionX Date: Tue, 20 May 2025 19:09:17 +0200 Subject: [PATCH 4/9] feat: work on tests --- aeon/distances/_distance.py | 3 +- aeon/distances/kernel/_kdtw.py | 2 +- .../kernel/tests/test_distance_correctness.py | 51 +++++++++---------- .../tests/test_numba_distance_parameters.py | 2 +- 4 files changed, 29 insertions(+), 29 deletions(-) diff --git a/aeon/distances/_distance.py b/aeon/distances/_distance.py index c227ab577c..cbde1a6f05 100644 --- a/aeon/distances/_distance.py +++ b/aeon/distances/_distance.py @@ -116,7 +116,8 @@ class DistanceKwargs(TypedDict, total=False): max_shift: Optional[int] gamma: float sigma: float - noramlize: bool + normalize_input: bool + normalize_dist: bool DistanceFunction = Callable[[np.ndarray, np.ndarray, Any], float] diff --git a/aeon/distances/kernel/_kdtw.py b/aeon/distances/kernel/_kdtw.py index 25c05ea647..0e5ea71c3b 100644 --- a/aeon/distances/kernel/_kdtw.py +++ b/aeon/distances/kernel/_kdtw.py @@ -55,7 +55,7 @@ def kdtw_distance( multivariate, shape ``(n_channels, n_timepoints)``. gamma : float, default=0.125 bandwidth parameter which weights the local contributions, i.e. the distances - between locally aligned positions. + between locally aligned positions. Must fulfill 0 < gamma < 1! epsilon : float, default=1e-20 Small value to avoid zero. The default is 1e-20. normalize_input : bool, default=True diff --git a/aeon/distances/kernel/tests/test_distance_correctness.py b/aeon/distances/kernel/tests/test_distance_correctness.py index e4d4be01e7..99424ff812 100644 --- a/aeon/distances/kernel/tests/test_distance_correctness.py +++ b/aeon/distances/kernel/tests/test_distance_correctness.py @@ -1,10 +1,11 @@ """Test the distance calculations are correct. Compare the distance calculations on the 1D and 2D (d,m) format input against the -results generated with tsml, in distances.tests.TestDistances. +results generated with the Octave implementation from +https://people.irisa.fr/Pierre-Francois.Marteau/REDK/KDTW/kdtw.m (adapted). """ -from numpy.testing import assert_almost_equal +from numpy.testing import assert_array_almost_equal_nulp from aeon.datasets import load_basic_motions, load_unit_test from aeon.distances import kdtw_distance @@ -14,15 +15,13 @@ ] distance_parameters = { - "kdtw": [0.0, 0.1, 1.0], # gamma -} -unit_test_distances = { - "kdtw": [0.0, 0.0, 0.0], - "kdtw_norm": [0.0, 0.0, 0.0], + "kdtw": [1e-5, 1e-2, 1e-1], # gamma } basic_motions_distances = { - "kdtw": [0.0, 0.0, 0.0], - "kdtw_norm": [0.0, 0.0, 0.0], + "kdtw": [1.0 - 2.7435e-123, 1.0 - 4.9180e-72, 1.0 - 6.5392e-53], +} +unit_test_distances = { + "kdtw": [1.0 - 6.6667e-21, 1.0 - 6.6667e-21, 1.0 - 6.6667e-21], } @@ -34,13 +33,13 @@ def test_multivariate_correctness(): for j in range(0, 3): d = kdtw_distance( - case1, case2, gamma=distance_parameters["kdtw"][j], normalize=False - ) - assert_almost_equal(d, basic_motions_distances["kdtw"][j], 4) - d = kdtw_distance( - case1, case2, gamma=distance_parameters["kdtw"][j], normalize=True + case1, + case2, + gamma=distance_parameters["kdtw"][j], + normalize_input=False, + normalize_dist=False, ) - assert_almost_equal(d, basic_motions_distances["kdtw_norm"][j], 4) + assert_array_almost_equal_nulp(d, basic_motions_distances["kdtw"][j]) def test_univariate_correctness(): @@ -53,18 +52,18 @@ def test_univariate_correctness(): for j in range(0, 3): d = kdtw_distance( - cases1[0], cases2[0], gamma=distance_parameters["kdtw"][j], normalize=False - ) - d2 = kdtw_distance( - cases1[1], cases2[1], gamma=distance_parameters["kdtw"][j], normalize=False - ) - assert_almost_equal(d, unit_test_distances["kdtw"][j], 4) - assert d == d2 - d = kdtw_distance( - cases1[0], cases2[0], gamma=distance_parameters["kdtw"][j], normalize=True + cases1[0], + cases2[0], + gamma=distance_parameters["kdtw"][j], + normalize_input=False, + normalize_dist=False, ) d2 = kdtw_distance( - cases1[1], cases2[1], gamma=distance_parameters["kdtw"][j], normalize=True + cases1[1], + cases2[1], + gamma=distance_parameters["kdtw"][j], + normalize_input=False, + normalize_dist=False, ) - assert_almost_equal(d, unit_test_distances["kdtw_norm"][j], 4) + assert_array_almost_equal_nulp(d, unit_test_distances["kdtw"][j]) assert d == d2 diff --git a/aeon/distances/tests/test_numba_distance_parameters.py b/aeon/distances/tests/test_numba_distance_parameters.py index 0bf8c06167..f963b27890 100644 --- a/aeon/distances/tests/test_numba_distance_parameters.py +++ b/aeon/distances/tests/test_numba_distance_parameters.py @@ -134,7 +134,7 @@ def _test_distance_params( "soft_dtw": BASIC_BOUNDING_PARAMS + [{"gamma": 0.2}], "kdtw": [ {"gamma": 0.125, "epsilon": 1e-3}, - {"gamma": 0.125, "epsilon": 1e-3, "normalize": True}, + {"gamma": 0.125, "epsilon": 1e-3, "normalize_dist": True}, ], } From dccef27a495cdb082fbf4a5c8d4bad0dd459057d Mon Sep 17 00:00:00 2001 From: CodeLionX Date: Wed, 21 May 2025 10:13:35 +0200 Subject: [PATCH 5/9] fix: normalize just once --- aeon/distances/kernel/_kdtw.py | 103 +++++++++++++++++++-------------- 1 file changed, 61 insertions(+), 42 deletions(-) diff --git a/aeon/distances/kernel/_kdtw.py b/aeon/distances/kernel/_kdtw.py index 0e5ea71c3b..af7ab6c9bb 100644 --- a/aeon/distances/kernel/_kdtw.py +++ b/aeon/distances/kernel/_kdtw.py @@ -20,6 +20,35 @@ from aeon.utils.validation.collection import _is_numpy_list_multivariate +@njit(cache=True, fastmath=True) +def _normalize_time_series(x: np.ndarray) -> np.ndarray: + """Normalize the time series to zero mean, unit variance, and unit length. + + First apply z-score normalization, then unit length normalization on each + channel. + + Parameters + ---------- + x : np.ndarray + Time series of shape ``(n_channels, n_timepoints)``. + + Returns + ------- + np.ndarray + Normalized time series of shape ``(n_channels, n_timepoints)``. + """ + eps = np.finfo(np.float64).eps + _x = np.empty_like(x) + + # Numba mean and std do not support axis parameters and + # np.linalg.norm is not supported in numba. + for i in range(x.shape[0]): + _x[i] = (x[i] - np.mean(x[i])) / (np.std(x[i]) + eps) + norm = np.sqrt(np.sum(x[i] ** 2)) + _x[i] = _x[i] / (norm + eps) + return _x + + @njit(cache=True, fastmath=True) def kdtw_distance( x: np.ndarray, @@ -163,13 +192,20 @@ def kdtw_cost_matrix( [0.00823045, 0.04115226, 0.11522634, 0.21688767, 0.30239293, 0.34130976]]) """ - if x.ndim == 1 and y.ndim == 1: + _x = x + _y = y + if x.ndim == 1: _x = x.reshape((1, x.shape[0])) + if y.ndim == 1: _y = y.reshape((1, y.shape[0])) - return _kdtw_cost_matrix(_x, _y, gamma, epsilon, normalize_input) - if x.ndim == 2 and y.ndim == 2: - return _kdtw_cost_matrix(x, y, gamma, epsilon, normalize_input) - raise ValueError("x and y must be 1D or 2D") + if x.ndim != 2 or y.ndim != 2: + raise ValueError("x and y must be 1D or 2D") + + if normalize_input: + _x = _normalize_time_series(x) + _y = _normalize_time_series(y) + + return _kdtw_cost_matrix(_x, _y, gamma, epsilon) @njit(cache=True, fastmath=True) @@ -181,73 +217,56 @@ def _kdtw_distance( normalize_input: bool, normalize_dist: bool, ) -> float: + if normalize_input: + _x = _normalize_time_series(x) + _y = _normalize_time_series(y) + else: + _x = x + _y = y + # Do not subtract one because the cost matrix is one dimension larger: - n = x.shape[-1] - m = y.shape[-1] + n = _x.shape[-1] + m = _y.shape[-1] if normalize_dist: - self_x = _kdtw_cost_matrix(x, x, gamma, epsilon, normalize_input)[n, n] - self_y = _kdtw_cost_matrix(y, y, gamma, epsilon, normalize_input)[m, m] + self_x = _kdtw_cost_matrix(_x, _x, gamma, epsilon)[n, n] + self_y = _kdtw_cost_matrix(_y, _y, gamma, epsilon)[m, m] norm_factor = np.sqrt(self_x * self_y) else: norm_factor = 1.0 - return ( - 1.0 - - _kdtw_cost_matrix(x, y, gamma, epsilon, normalize_input)[n, m] / norm_factor - ) + return 1.0 - _kdtw_cost_matrix(_x, _y, gamma, epsilon)[n, m] / norm_factor @njit(cache=True, fastmath=True) def _local_kernel( - x: np.ndarray, y: np.ndarray, gamma: float, epsilon: float, normalize_input: bool + x: np.ndarray, y: np.ndarray, gamma: float, epsilon: float ) -> np.ndarray: - if normalize_input: - # First apply z-score normalization, then unit length normalization on each - # channel - eps = np.finfo(np.float64).eps - _x = np.empty_like(x) - _y = np.empty_like(y) - - # Numba mean and std do not support axis parameters and - # np.linalg.norm is not supported in numba. - for i in range(x.shape[0]): - _x[i] = (x[i] - np.mean(x[i])) / (np.std(x[i]) + eps) - norm = np.sqrt(np.sum(x[i] ** 2)) - _x[i] = _x[i] / (norm + eps) - for i in range(y.shape[0]): - _y[i] = (y[i] - np.mean(y[i])) / (np.std(y[i]) + eps) - norm = np.sqrt(np.sum(y[i] ** 2)) - _y[i] = _y[i] / (norm + eps) - else: - _x = x - _y = y - factor = 1.0 / 3.0 # Incoming shape is (n_channels, n_timepoints) # We want to calculate the multivariate squared distance between the two time series # considering each point in the time series as a separate instance, thus we need to # reshape to (m_cases, m_channels, 1), where m_cases = n_timepoints and # m_channels = n_channels. - _x = _x.T.reshape(-1, _x.shape[0], 1) - _y = _y.T.reshape(-1, _y.shape[0], 1) - n_cases = _x.shape[0] - m_cases = _y.shape[0] + x = x.T.reshape(-1, x.shape[0], 1) + y = y.T.reshape(-1, y.shape[0], 1) + n_cases = x.shape[0] + m_cases = y.shape[0] distances = np.zeros((n_cases, m_cases)) for i in range(n_cases): for j in range(m_cases): # expects each input to have shape (n_channels, n_timepoints = 1) - distances[i, j] = squared_distance(_x[i], _y[j]) + distances[i, j] = squared_distance(x[i], y[j]) return factor * (np.exp(-distances / gamma) + epsilon) @njit(cache=True, fastmath=True) def _kdtw_cost_matrix( - x: np.ndarray, y: np.ndarray, gamma: float, epsilon: float, normalize_input: bool + x: np.ndarray, y: np.ndarray, gamma: float, epsilon: float ) -> np.ndarray: # deals with multivariate time series, afterward, we just work with the distances # and do not need to deal with the channels anymore - local_kernel = _local_kernel(x, y, gamma, epsilon, normalize_input) + local_kernel = _local_kernel(x, y, gamma, epsilon) # For the initial values of the cost matrix, we add 1 n = np.shape(x)[-1] + 1 From 6da516c913199c466332f4e96d6ee21664c11789 Mon Sep 17 00:00:00 2001 From: CodeLionX Date: Wed, 21 May 2025 13:57:51 +0200 Subject: [PATCH 6/9] chore: just z-score is recommended; reduce cost matrix to proper size --- aeon/distances/kernel/_kdtw.py | 70 +++++++++++++++------------------- 1 file changed, 30 insertions(+), 40 deletions(-) diff --git a/aeon/distances/kernel/_kdtw.py b/aeon/distances/kernel/_kdtw.py index af7ab6c9bb..902a741226 100644 --- a/aeon/distances/kernel/_kdtw.py +++ b/aeon/distances/kernel/_kdtw.py @@ -19,13 +19,12 @@ from aeon.utils.conversion._convert_collection import _convert_collection_to_numba_list from aeon.utils.validation.collection import _is_numpy_list_multivariate +_eps = np.finfo(np.float64).eps + @njit(cache=True, fastmath=True) def _normalize_time_series(x: np.ndarray) -> np.ndarray: - """Normalize the time series to zero mean, unit variance, and unit length. - - First apply z-score normalization, then unit length normalization on each - channel. + """Normalize the time series to zero mean and unit variance. Parameters ---------- @@ -37,15 +36,11 @@ def _normalize_time_series(x: np.ndarray) -> np.ndarray: np.ndarray Normalized time series of shape ``(n_channels, n_timepoints)``. """ - eps = np.finfo(np.float64).eps _x = np.empty_like(x) - # Numba mean and std do not support axis parameters and - # np.linalg.norm is not supported in numba. + # Numba mean and std do not support axis parameters for i in range(x.shape[0]): - _x[i] = (x[i] - np.mean(x[i])) / (np.std(x[i]) + eps) - norm = np.sqrt(np.sum(x[i] ** 2)) - _x[i] = _x[i] / (norm + eps) + _x[i] = (x[i] - np.mean(x[i])) / (np.std(x[i]) + _eps) return _x @@ -88,8 +83,8 @@ def kdtw_distance( epsilon : float, default=1e-20 Small value to avoid zero. The default is 1e-20. normalize_input : bool, default=True - Whether to normalize the time series' channels to zero mean, unit variance, and - unit length before computing the distance. Highly recommended! + Whether to normalize the time series' channels to zero mean and unit variance + before computing the distance. Highly recommended! normalize_dist : bool, default=False Whether to normalize the distance by the product of the self distances of x and y to avoid scaling effects and put the distance value between 0 and 1. @@ -159,8 +154,8 @@ def kdtw_cost_matrix( epsilon : float, default=1e-20 Small value to avoid zero. The default is 1e-20. normalize_input : bool, default=True - Whether to normalize the time series' channels to zero mean, unit variance, and - unit length before computing the distance. + Whether to normalize the time series' channels to zero mean and unit variance + before computing the distance. Returns ------- @@ -179,18 +174,11 @@ def kdtw_cost_matrix( >>> x = np.array([[1, 2, 3, 4, 5]]) >>> y = np.array([[1, 2, 3, 4, 5]]) >>> kdtw_cost_matrix(x, y) - array([[2. , 0.66666667, 0.22222222, 0.07407407, 0.02469136, - 0.00823045], - [0.66666667, 1.11111111, 0.55555556, 0.24691358, 0.10288066, - 0.04115226], - [0.22222222, 0.55555556, 0.74074074, 0.44032922, 0.2345679 , - 0.11522634], - [0.07407407, 0.24691358, 0.44032922, 0.54046639, 0.35848194, - 0.21688767], - [0.02469136, 0.10288066, 0.2345679 , 0.35848194, 0.41914342, - 0.30239293], - [0.00823045, 0.04115226, 0.11522634, 0.21688767, 0.30239293, - 0.34130976]]) + array([[1.11111111, 0.55555556, 0.24691358, 0.10288066, 0.04115226], + [0.55555556, 0.74074074, 0.44032922, 0.2345679 , 0.11522634], + [0.24691358, 0.44032922, 0.54046639, 0.35848194, 0.21688767], + [0.10288066, 0.2345679 , 0.35848194, 0.41914342, 0.30239293], + [0.04115226, 0.11522634, 0.21688767, 0.30239293, 0.34130976]]) """ _x = x _y = y @@ -198,12 +186,12 @@ def kdtw_cost_matrix( _x = x.reshape((1, x.shape[0])) if y.ndim == 1: _y = y.reshape((1, y.shape[0])) - if x.ndim != 2 or y.ndim != 2: + if _x.ndim != 2 or _y.ndim != 2: raise ValueError("x and y must be 1D or 2D") if normalize_input: - _x = _normalize_time_series(x) - _y = _normalize_time_series(y) + _x = _normalize_time_series(_x) + _y = _normalize_time_series(_y) return _kdtw_cost_matrix(_x, _y, gamma, epsilon) @@ -224,13 +212,14 @@ def _kdtw_distance( _x = x _y = y - # Do not subtract one because the cost matrix is one dimension larger: - n = _x.shape[-1] - m = _y.shape[-1] + n = _x.shape[-1] - 1 + m = _y.shape[-1] - 1 if normalize_dist: self_x = _kdtw_cost_matrix(_x, _x, gamma, epsilon)[n, n] self_y = _kdtw_cost_matrix(_y, _y, gamma, epsilon)[m, m] norm_factor = np.sqrt(self_x * self_y) + if norm_factor < _eps: + norm_factor = 1.0 else: norm_factor = 1.0 return 1.0 - _kdtw_cost_matrix(_x, _y, gamma, epsilon)[n, m] / norm_factor @@ -240,6 +229,7 @@ def _kdtw_distance( def _local_kernel( x: np.ndarray, y: np.ndarray, gamma: float, epsilon: float ) -> np.ndarray: + # 1 / c in the paper; beta on the website factor = 1.0 / 3.0 # Incoming shape is (n_channels, n_timepoints) # We want to calculate the multivariate squared distance between the two time series @@ -265,7 +255,7 @@ def _kdtw_cost_matrix( x: np.ndarray, y: np.ndarray, gamma: float, epsilon: float ) -> np.ndarray: # deals with multivariate time series, afterward, we just work with the distances - # and do not need to deal with the channels anymore + # and do not handle the channels anymore local_kernel = _local_kernel(x, y, gamma, epsilon) # For the initial values of the cost matrix, we add 1 @@ -316,7 +306,7 @@ def _kdtw_cost_matrix( # Add the cumulative dp diagonal to the cost matrix cost_matrix = cost_matrix + cumulative_dp_diag - return cost_matrix + return cost_matrix[1:, 1:] def kdtw_pairwise_distance( @@ -345,8 +335,8 @@ def kdtw_pairwise_distance( epsilon : float, default=1e-20 Small value to avoid zero. The default is 1e-20. normalize_input : bool, default=True - Whether to normalize the time series' channels to zero mean, unit variance, and - unit length before computing the distance. + Whether to normalize the time series' channels to zero mean and unit variance + before computing the distance. normalize_dist : bool, default=False Whether to normalize the distance by the product of the self distances of x and y to avoid scaling effects and put the distance between 0 and 1. @@ -466,8 +456,8 @@ def kdtw_alignment_path( epsilon : float, default=1e-20 Small value to avoid zero. The default is 1e-20. normalize_input : bool, default=True - Whether to normalize the time series' channels to zero mean, unit variance, and - unit length before computing the distance. + Whether to normalize the time series' channels to zero mean and unit variance + before computing the distance. Returns ------- @@ -492,8 +482,8 @@ def kdtw_alignment_path( >>> kdtw_alignment_path(x, y) ([(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)], 0.4191434232586494) """ - x_size = x.shape[-1] - y_size = y.shape[-1] + x_size = x.shape[-1] - 1 + y_size = y.shape[-1] - 1 cost_matrix = kdtw_cost_matrix(x, y, gamma, epsilon, normalize_input) return ( compute_min_return_path(cost_matrix, larger_is_better=True), From a675143bd2418b8704d044cd42227eaf5b272324 Mon Sep 17 00:00:00 2001 From: CodeLionX Date: Fri, 23 May 2025 16:14:54 +0200 Subject: [PATCH 7/9] chore: better dist normalization handling --- aeon/distances/kernel/_kdtw.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/aeon/distances/kernel/_kdtw.py b/aeon/distances/kernel/_kdtw.py index 902a741226..3320c79695 100644 --- a/aeon/distances/kernel/_kdtw.py +++ b/aeon/distances/kernel/_kdtw.py @@ -214,15 +214,14 @@ def _kdtw_distance( n = _x.shape[-1] - 1 m = _y.shape[-1] - 1 + current_cost = _kdtw_cost_matrix(_x, _y, gamma, epsilon)[n, m] if normalize_dist: self_x = _kdtw_cost_matrix(_x, _x, gamma, epsilon)[n, n] self_y = _kdtw_cost_matrix(_y, _y, gamma, epsilon)[m, m] norm_factor = np.sqrt(self_x * self_y) - if norm_factor < _eps: - norm_factor = 1.0 - else: - norm_factor = 1.0 - return 1.0 - _kdtw_cost_matrix(_x, _y, gamma, epsilon)[n, m] / norm_factor + if norm_factor != 0.0: + current_cost /= norm_factor + return 1.0 - current_cost @njit(cache=True, fastmath=True) From 8e351d36aa8a6337466598a1f7a0eb1759b9b16b Mon Sep 17 00:00:00 2001 From: CodeLionX Date: Wed, 4 Jun 2025 11:00:12 +0200 Subject: [PATCH 8/9] feat: normalize kdtw by default and fix tests --- .../elastic/tests/test_cost_matrix.py | 20 +++++++ aeon/distances/kernel/_kdtw.py | 54 ++++++++++++++----- .../kernel/tests/test_distance_correctness.py | 23 ++++---- .../tests/test_numba_distance_parameters.py | 2 +- .../expected_distance_results.py | 14 ++--- 5 files changed, 81 insertions(+), 32 deletions(-) diff --git a/aeon/distances/elastic/tests/test_cost_matrix.py b/aeon/distances/elastic/tests/test_cost_matrix.py index 36c75d0c02..3c91faac43 100644 --- a/aeon/distances/elastic/tests/test_cost_matrix.py +++ b/aeon/distances/elastic/tests/test_cost_matrix.py @@ -66,6 +66,26 @@ def _validate_cost_matrix_result( assert_almost_equal(curr_distance, distance_result) elif name == "soft_dtw": assert_almost_equal(abs(cost_matrix_result[-1, -1]), distance_result) + elif name == "kdtw": + # distance is normalized by default, so we need to do this here as well: + from aeon.distances.kernel._kdtw import ( + _kdtw_cost_to_distance, + _normalize_time_series, + ) + + _x = x + _y = y + if x.ndim == 1: + _x = x.reshape((1, x.shape[0])) + if y.ndim == 1: + _y = y.reshape((1, y.shape[0])) + + _x = _normalize_time_series(_x) + _y = _normalize_time_series(_y) + d = _kdtw_cost_to_distance( + cost_matrix_result, _x, _y, gamma=0.125, epsilon=1e-20, normalize_dist=True + ) + assert_almost_equal(d, distance_result) else: assert_almost_equal(cost_matrix_result[-1, -1], distance_result) diff --git a/aeon/distances/kernel/_kdtw.py b/aeon/distances/kernel/_kdtw.py index 3320c79695..b6441b88d3 100644 --- a/aeon/distances/kernel/_kdtw.py +++ b/aeon/distances/kernel/_kdtw.py @@ -51,7 +51,7 @@ def kdtw_distance( gamma: float = 0.125, epsilon: float = 1e-20, normalize_input: bool = True, - normalize_dist: bool = False, + normalize_dist: bool = True, ) -> float: r"""Compute the DTW kernel (KDTW) between two time series as a distance. @@ -85,7 +85,7 @@ def kdtw_distance( normalize_input : bool, default=True Whether to normalize the time series' channels to zero mean and unit variance before computing the distance. Highly recommended! - normalize_dist : bool, default=False + normalize_dist : bool, default=True Whether to normalize the distance by the product of the self distances of x and y to avoid scaling effects and put the distance value between 0 and 1. @@ -212,12 +212,25 @@ def _kdtw_distance( _x = x _y = y - n = _x.shape[-1] - 1 - m = _y.shape[-1] - 1 - current_cost = _kdtw_cost_matrix(_x, _y, gamma, epsilon)[n, m] + cost_matrix = _kdtw_cost_matrix(_x, _y, gamma, epsilon) + return _kdtw_cost_to_distance(cost_matrix, _x, _y, gamma, epsilon, normalize_dist) + + +@njit(cache=True, fastmath=True) +def _kdtw_cost_to_distance( + cost_matrix: np.ndarray, + x: np.ndarray, + y: np.ndarray, + gamma: float, + epsilon: float, + normalize_dist: bool, +) -> float: + n = x.shape[-1] - 1 + m = y.shape[-1] - 1 + current_cost = cost_matrix[n, m] if normalize_dist: - self_x = _kdtw_cost_matrix(_x, _x, gamma, epsilon)[n, n] - self_y = _kdtw_cost_matrix(_y, _y, gamma, epsilon)[m, m] + self_x = _kdtw_cost_matrix(x, x, gamma, epsilon)[n, n] + self_y = _kdtw_cost_matrix(y, y, gamma, epsilon)[m, m] norm_factor = np.sqrt(self_x * self_y) if norm_factor != 0.0: current_cost /= norm_factor @@ -314,7 +327,7 @@ def kdtw_pairwise_distance( gamma: float = 0.125, epsilon: float = 1e-20, normalize_input: bool = True, - normalize_dist: bool = False, + normalize_dist: bool = True, ) -> np.ndarray: """Compute the KDTW pairwise distance between a set of time series. @@ -336,7 +349,7 @@ def kdtw_pairwise_distance( normalize_input : bool, default=True Whether to normalize the time series' channels to zero mean and unit variance before computing the distance. - normalize_dist : bool, default=False + normalize_dist : bool, default=True Whether to normalize the distance by the product of the self distances of x and y to avoid scaling effects and put the distance between 0 and 1. @@ -440,6 +453,7 @@ def kdtw_alignment_path( gamma: float = 0.125, epsilon: float = 1e-20, normalize_input: bool = True, + normalize_dist: bool = True, ) -> tuple[list[tuple[int, int]], float]: """Compute the kdtw alignment path between two time series. @@ -457,6 +471,9 @@ def kdtw_alignment_path( normalize_input : bool, default=True Whether to normalize the time series' channels to zero mean and unit variance before computing the distance. + normalize_dist : bool, default=True + Whether to normalize the distance by the product of the self distances of x and + y to avoid scaling effects and put the distance between 0 and 1. Returns ------- @@ -481,10 +498,21 @@ def kdtw_alignment_path( >>> kdtw_alignment_path(x, y) ([(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)], 0.4191434232586494) """ - x_size = x.shape[-1] - 1 - y_size = y.shape[-1] - 1 - cost_matrix = kdtw_cost_matrix(x, y, gamma, epsilon, normalize_input) + _x = x + _y = y + if x.ndim == 1: + _x = x.reshape((1, x.shape[0])) + if y.ndim == 1: + _y = y.reshape((1, y.shape[0])) + if _x.ndim != 2 or _y.ndim != 2: + raise ValueError("x and y must be 1D or 2D") + + if normalize_input: + _x = _normalize_time_series(_x) + _y = _normalize_time_series(_y) + + cost_matrix = _kdtw_cost_matrix(_x, _y, gamma, epsilon) return ( compute_min_return_path(cost_matrix, larger_is_better=True), - cost_matrix[x_size, y_size], + _kdtw_cost_to_distance(cost_matrix, _x, _y, gamma, epsilon, normalize_dist), ) diff --git a/aeon/distances/kernel/tests/test_distance_correctness.py b/aeon/distances/kernel/tests/test_distance_correctness.py index 99424ff812..efd2caea28 100644 --- a/aeon/distances/kernel/tests/test_distance_correctness.py +++ b/aeon/distances/kernel/tests/test_distance_correctness.py @@ -1,8 +1,9 @@ -"""Test the distance calculations are correct. +"""Test the correctness of the KDTW distance calculations. Compare the distance calculations on the 1D and 2D (d,m) format input against the -results generated with the Octave implementation from -https://people.irisa.fr/Pierre-Francois.Marteau/REDK/KDTW/kdtw.m (adapted). +results generated with the Matlab/Octave implementation from +https://people.irisa.fr/Pierre-Francois.Marteau/REDK/KDTW/kdtw.m (adapted to support +multivariate time series). """ from numpy.testing import assert_array_almost_equal_nulp @@ -18,10 +19,10 @@ "kdtw": [1e-5, 1e-2, 1e-1], # gamma } basic_motions_distances = { - "kdtw": [1.0 - 2.7435e-123, 1.0 - 4.9180e-72, 1.0 - 6.5392e-53], + "kdtw": [1.0, 1.0, 1.0], } unit_test_distances = { - "kdtw": [1.0 - 6.6667e-21, 1.0 - 6.6667e-21, 1.0 - 6.6667e-21], + "kdtw": [1.0, 1.0, 0.9984255139339736], } @@ -36,8 +37,8 @@ def test_multivariate_correctness(): case1, case2, gamma=distance_parameters["kdtw"][j], - normalize_input=False, - normalize_dist=False, + normalize_input=True, + normalize_dist=True, ) assert_array_almost_equal_nulp(d, basic_motions_distances["kdtw"][j]) @@ -55,15 +56,15 @@ def test_univariate_correctness(): cases1[0], cases2[0], gamma=distance_parameters["kdtw"][j], - normalize_input=False, - normalize_dist=False, + normalize_input=True, + normalize_dist=True, ) d2 = kdtw_distance( cases1[1], cases2[1], gamma=distance_parameters["kdtw"][j], - normalize_input=False, - normalize_dist=False, + normalize_input=True, + normalize_dist=True, ) assert_array_almost_equal_nulp(d, unit_test_distances["kdtw"][j]) assert d == d2 diff --git a/aeon/distances/tests/test_numba_distance_parameters.py b/aeon/distances/tests/test_numba_distance_parameters.py index f963b27890..ad734b12de 100644 --- a/aeon/distances/tests/test_numba_distance_parameters.py +++ b/aeon/distances/tests/test_numba_distance_parameters.py @@ -134,7 +134,7 @@ def _test_distance_params( "soft_dtw": BASIC_BOUNDING_PARAMS + [{"gamma": 0.2}], "kdtw": [ {"gamma": 0.125, "epsilon": 1e-3}, - {"gamma": 0.125, "epsilon": 1e-3, "normalize_dist": True}, + {"gamma": 0.125, "epsilon": 1e-3, "normalize_dist": False}, ], } diff --git a/aeon/testing/expected_results/expected_distance_results.py b/aeon/testing/expected_results/expected_distance_results.py index 8355d9779e..bc11369904 100644 --- a/aeon/testing/expected_results/expected_distance_results.py +++ b/aeon/testing/expected_results/expected_distance_results.py @@ -126,11 +126,11 @@ 0.0, ], "kdtw": [ - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, ], } @@ -198,7 +198,7 @@ [0.6519267432870345, 5.491208968546096], ], "kdtw": [ - [0.0, 0.0], - [0.0, 0.0], + [1.0, 1.0], + [1.0, 1.0], ], } From 6d1b797f81e2088473a00cceb678a0e040864260 Mon Sep 17 00:00:00 2001 From: CodeLionX Date: Wed, 4 Jun 2025 13:54:51 +0200 Subject: [PATCH 9/9] fix: doctests --- aeon/distances/kernel/_kdtw.py | 42 +++++++++++++++++----------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/aeon/distances/kernel/_kdtw.py b/aeon/distances/kernel/_kdtw.py index b6441b88d3..ef4268af30 100644 --- a/aeon/distances/kernel/_kdtw.py +++ b/aeon/distances/kernel/_kdtw.py @@ -116,9 +116,9 @@ def kdtw_distance( >>> from aeon.distances import kdtw_distance >>> x = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) >>> y = np.array([[11, 12, 13, 14, 15, 16, 17, 18, 19, 20]]) + >>> kdtw_distance(x, y, normalize_dist=False) + 0.9860651016277014 >>> kdtw_distance(x, y) - 0.8051764348248271 - >>> kdtw_distance(x, y, normalize_dist=True) 0.0 """ if x.ndim == 1 and y.ndim == 1: @@ -174,11 +174,11 @@ def kdtw_cost_matrix( >>> x = np.array([[1, 2, 3, 4, 5]]) >>> y = np.array([[1, 2, 3, 4, 5]]) >>> kdtw_cost_matrix(x, y) - array([[1.11111111, 0.55555556, 0.24691358, 0.10288066, 0.04115226], - [0.55555556, 0.74074074, 0.44032922, 0.2345679 , 0.11522634], - [0.24691358, 0.44032922, 0.54046639, 0.35848194, 0.21688767], - [0.10288066, 0.2345679 , 0.35848194, 0.41914342, 0.30239293], - [0.04115226, 0.11522634, 0.21688767, 0.30239293, 0.34130976]]) + array([[1.11111111, 0.22232162, 0.08641977, 0.03292181, 0.01234568], + [0.22232162, 0.51858479, 0.20170132, 0.07820771, 0.02332192], + [0.08641977, 0.20170132, 0.30732914, 0.14910683, 0.03689383], + [0.03292181, 0.07820771, 0.14910683, 0.2018476 , 0.05442779], + [0.01234568, 0.02332192, 0.03689383, 0.05442779, 0.10356772]]) """ _x = x _y = y @@ -370,25 +370,25 @@ def kdtw_pairwise_distance( >>> from aeon.distances import kdtw_pairwise_distance >>> # Distance between each time series in a collection of time series >>> X = np.array([[[1, 2, 3]],[[4, 5, 6]], [[7, 8, 9]]]) - >>> kdtw_pairwise_distance(X) - array([[0. , 0.45953361, 0.45953361], - [0.45953361, 0. , 0.45953361], - [0.45953361, 0.45953361, 0. ]]) + >>> kdtw_pairwise_distance(X, normalize_dist=False) + array([[0. , 0.73384612, 0.73384612], + [0.73384612, 0. , 0.73384612], + [0.73384612, 0.73384612, 0. ]] >>> # Distance between two collections of time series >>> X = np.array([[[1, 2, 3]],[[4, 5, 6]], [[7, 8, 9]]]) - >>> y = np.array([[[11, 12, 13]],[[14, 15, 16]], [[17, 18, 19]]]) + >>> y = np.array([[[11, 12, 13, 14]],[[15, 16, 17, 18]], [[19, 20, 21, 22]]]) >>> kdtw_pairwise_distance(X, y) - array([[0.45953361, 0.45953361, 0.45953361], - [0.45953361, 0.45953361, 0.45953361], - [0.45953361, 0.45953361, 0.45953361]]) + array([[0.90035627, 0.90035627, 0.90035627], + [0.90035627, 0.90035627, 0.90035627], + [0.90035627, 0.90035627, 0.90035627]]) >>> X = np.array([[[1, 2, 3]],[[4, 5, 6]], [[7, 8, 9]]]) >>> y_univariate = np.array([[11, 12, 13],[14, 15, 16], [17, 18, 19]]) - >>> kdtw_pairwise_distance(X, y_univariate) - array([[0.45953361, 0.45953361, 0.45953361], - [0.45953361, 0.45953361, 0.45953361], - [0.45953361, 0.45953361, 0.45953361]]) + >>> kdtw_pairwise_distance(X, y_univariate, normalize_dist=False) + array([[0.73384612, 0.73384612, 0.73384612], + [0.73384612, 0.73384612, 0.73384612], + [0.73384612, 0.73384612, 0.73384612]]) """ multivariate_conversion = _is_numpy_list_multivariate(X, y) _X, _ = _convert_collection_to_numba_list(X, "X", multivariate_conversion) @@ -495,8 +495,8 @@ def kdtw_alignment_path( >>> from aeon.distances import kdtw_alignment_path >>> x = np.array([[1, 2, 3, 6]]) >>> y = np.array([[1, 2, 3, 4]]) - >>> kdtw_alignment_path(x, y) - ([(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)], 0.4191434232586494) + >>> kdtw_alignment_path(x, y, normalize_dist=False) + ([(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)], 0.8393218410741822) """ _x = x _y = y