Skip to content

[ENH]: Add kdtw distance implementation #2827

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
10 changes: 10 additions & 0 deletions aeon/distances/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@
"soft_dtw_pairwise_distance",
"soft_dtw_alignment_path",
"soft_dtw_cost_matrix",
"kdtw_distance",
"kdtw_alignment_path",
"kdtw_cost_matrix",
"kdtw_pairwise_distance",
]

from aeon.distances._distance import (
Expand Down Expand Up @@ -157,6 +161,12 @@
wdtw_distance,
wdtw_pairwise_distance,
)
from aeon.distances.kernel import (
kdtw_alignment_path,
kdtw_cost_matrix,
kdtw_distance,
kdtw_pairwise_distance,
)
from aeon.distances.mindist._dft_sfa import mindist_dft_sfa_distance
from aeon.distances.mindist._paa_sax import mindist_paa_sax_distance
from aeon.distances.mindist._sax import mindist_sax_distance
Expand Down
25 changes: 25 additions & 0 deletions aeon/distances/_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@
wdtw_distance,
wdtw_pairwise_distance,
)
from aeon.distances.kernel import (
kdtw_alignment_path,
kdtw_cost_matrix,
kdtw_distance,
kdtw_pairwise_distance,
)
from aeon.distances.mindist import (
mindist_dft_sfa_distance,
mindist_dft_sfa_pairwise_distance,
Expand Down Expand Up @@ -109,6 +115,9 @@ class DistanceKwargs(TypedDict, total=False):
m: int
max_shift: Optional[int]
gamma: float
sigma: float
normalize_input: bool
normalize_dist: bool


DistanceFunction = Callable[[np.ndarray, np.ndarray, Any], float]
Expand Down Expand Up @@ -469,6 +478,7 @@ def get_distance_function(method: Union[str, DistanceFunction]) -> DistanceFunct
'sbd' distances.sbd_distance
'shift_scale' distances.shift_scale_invariant_distance
'soft_dtw' distances.soft_dtw_distance
'kdtw' distances.kdtw_distance
=============== ========================================

Parameters
Expand Down Expand Up @@ -528,6 +538,7 @@ def get_pairwise_distance_function(
'sbd' distances.sbd_pairwise_distance
'shift_scale' distances.shift_scale_invariant_pairwise_distance
'soft_dtw' distances.soft_dtw_pairwise_distance
'kdtw' distances.kdtw_pairwise_distance
=============== ========================================

Parameters
Expand Down Expand Up @@ -582,6 +593,7 @@ def get_alignment_path_function(method: str) -> AlignmentPathFunction:
'twe' distances.twe_alignment_path
'lcss' distances.lcss_alignment_path
'soft_dtw' distances.soft_dtw_alignment_path
'kdtw' distances.kdtw_alignment_path
=============== ========================================

Parameters
Expand Down Expand Up @@ -631,6 +643,7 @@ def get_cost_matrix_function(method: str) -> CostMatrixFunction:
'twe' distances.twe_cost_matrix
'lcss' distances.lcss_cost_matrix
'soft_dtw' distances.soft_dtw_cost_matrix
'kdtw' distances.kdtw_cost_matrix
=============== ========================================

Parameters
Expand Down Expand Up @@ -685,6 +698,7 @@ class DistanceType(Enum):

POINTWISE = "pointwise"
ELASTIC = "elastic"
KERNEL = "kernel"
CROSS_CORRELATION = "cross-correlation"
MIN_DISTANCE = "min-dist"
MATRIX_PROFILE = "matrix-profile"
Expand Down Expand Up @@ -909,6 +923,16 @@ class DistanceType(Enum):
"symmetric": True,
"unequal_support": True,
},
{
"name": "kdtw",
"distance": kdtw_distance,
"pairwise_distance": kdtw_pairwise_distance,
"cost_matrix": kdtw_cost_matrix,
"alignment_path": kdtw_alignment_path,
"type": DistanceType.KERNEL,
"symmetric": True,
"unequal_support": True,
},
]

DISTANCES_DICT = {d["name"]: d for d in DISTANCES}
Expand All @@ -922,6 +946,7 @@ class DistanceType(Enum):
]

ELASTIC_DISTANCES = [d["name"] for d in DISTANCES if d["type"] == DistanceType.ELASTIC]
KERNEL_DISTANCES = [d["name"] for d in DISTANCES if d["type"] == DistanceType.KERNEL]
POINTWISE_DISTANCES = [
d["name"] for d in DISTANCES if d["type"] == DistanceType.POINTWISE
]
Expand Down
24 changes: 15 additions & 9 deletions aeon/distances/elastic/_alignment_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@


@njit(cache=True, fastmath=True)
def compute_min_return_path(cost_matrix: np.ndarray) -> list[tuple]:
def compute_min_return_path(
cost_matrix: np.ndarray, larger_is_better: bool = False
) -> list[tuple]:
"""Compute the minimum return path through a cost matrix.

Parameters
----------
cost_matrix : np.ndarray, of shape (n_timepoints_x, n_timepoints_y)
Cost matrix.
larger_is_better : bool, default=False
If True, the path will be computed for the maximum cost instead of the minimum.

Returns
-------
Expand All @@ -32,15 +36,17 @@ def compute_min_return_path(cost_matrix: np.ndarray) -> list[tuple]:
elif j == 0:
i -= 1
else:
min_index = np.argmin(
np.array(
[
cost_matrix[i - 1, j - 1],
cost_matrix[i - 1, j],
cost_matrix[i, j - 1],
]
)
costs = np.array(
[
cost_matrix[i - 1, j - 1],
cost_matrix[i - 1, j],
cost_matrix[i, j - 1],
]
)
if larger_is_better:
min_index = np.argmax(costs)
else:
min_index = np.argmin(costs)

if min_index == 0:
i, j = i - 1, j - 1
Expand Down
20 changes: 20 additions & 0 deletions aeon/distances/elastic/tests/test_cost_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,26 @@ def _validate_cost_matrix_result(
assert_almost_equal(curr_distance, distance_result)
elif name == "soft_dtw":
assert_almost_equal(abs(cost_matrix_result[-1, -1]), distance_result)
elif name == "kdtw":
# distance is normalized by default, so we need to do this here as well:
from aeon.distances.kernel._kdtw import (
_kdtw_cost_to_distance,
_normalize_time_series,
)

_x = x
_y = y
if x.ndim == 1:
_x = x.reshape((1, x.shape[0]))
if y.ndim == 1:
_y = y.reshape((1, y.shape[0]))

_x = _normalize_time_series(_x)
_y = _normalize_time_series(_y)
d = _kdtw_cost_to_distance(
cost_matrix_result, _x, _y, gamma=0.125, epsilon=1e-20, normalize_dist=True
)
assert_almost_equal(d, distance_result)
else:
assert_almost_equal(cost_matrix_result[-1, -1], distance_result)

Expand Down
15 changes: 15 additions & 0 deletions aeon/distances/kernel/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Kernel distances."""

__all__ = [
"kdtw_distance",
"kdtw_alignment_path",
"kdtw_cost_matrix",
"kdtw_pairwise_distance",
]

from aeon.distances.kernel._kdtw import (
kdtw_alignment_path,
kdtw_cost_matrix,
kdtw_distance,
kdtw_pairwise_distance,
)
Loading
Loading