From 02988dbe1ab88d71d11bec98ffba6f79a98e732b Mon Sep 17 00:00:00 2001 From: tanishy7777 Date: Sun, 13 Apr 2025 18:30:35 +0530 Subject: [PATCH 1/2] Fixes sbd distances for multivariate case --- aeon/distances/_distance.py | 2 +- aeon/distances/_sbd.py | 23 ++++++++++++++----- .../expected_distance_results.py | 2 +- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/aeon/distances/_distance.py b/aeon/distances/_distance.py index 33f9141440..dccf1e0bd5 100644 --- a/aeon/distances/_distance.py +++ b/aeon/distances/_distance.py @@ -859,7 +859,7 @@ class DistanceType(Enum): "pairwise_distance": sbd_pairwise_distance, "type": DistanceType.CROSS_CORRELATION, "symmetric": True, - "unequal_support": True, + "unequal_support": False, }, { "name": "shift_scale", diff --git a/aeon/distances/_sbd.py b/aeon/distances/_sbd.py index 1097f27b5a..52cb5e5aa3 100644 --- a/aeon/distances/_sbd.py +++ b/aeon/distances/_sbd.py @@ -104,11 +104,14 @@ def sbd_distance(x: np.ndarray, y: np.ndarray, standardize: bool = True) -> floa return _univariate_sbd_distance(_x, _y, standardize) else: # independent (time series should have the same number of channels!) - nchannels = min(x.shape[0], y.shape[0]) - distance = 0.0 + if x.shape[0] != y.shape[0]: + raise ValueError("x and y must have the same number of channels ") + nchannels = x.shape[0] # both x and y have the same number of channels + norm = np.linalg.norm(x) * np.linalg.norm(y) + distance = np.zeros((2 * x.shape[1] - 1,)) for i in range(nchannels): - distance += _univariate_sbd_distance(x[i], y[i], standardize) - return distance / nchannels + distance += _helper_sbd(x[i], y[i], standardize) + return np.abs(1 - np.max(distance) / norm) raise ValueError("x and y must be 1D or 2D") @@ -240,8 +243,16 @@ def _univariate_sbd_distance(x: np.ndarray, y: np.ndarray, standardize: bool) -> x = (x - np.mean(x)) / np.std(x) y = (y - np.mean(y)) / np.std(y) - with objmode(a="float64[:]"): - a = correlate(x, y, method="fft") + a = _helper_sbd(x, y, standardize) b = np.sqrt(np.dot(x, x) * np.dot(y, y)) return np.abs(1.0 - np.max(a / b)) + + +@njit(cache=True, fastmath=True) +def _helper_sbd(x, y, standardize): + + with objmode(a="float64[:]"): + a = correlate(x, y, method="fft") + + return a diff --git a/aeon/testing/expected_results/expected_distance_results.py b/aeon/testing/expected_results/expected_distance_results.py index 7126c5c624..2145a3208f 100644 --- a/aeon/testing/expected_results/expected_distance_results.py +++ b/aeon/testing/expected_results/expected_distance_results.py @@ -115,7 +115,7 @@ 0.6617308353925114, 0.6617308353925114, 0.5750093257763462, - 0.5263609881742105, + None, 0.0, ], "shift_scale": [ From 18542e0bbe9be8dec67554afff917b47fc3ce1d0 Mon Sep 17 00:00:00 2001 From: tanishy7777 Date: Sun, 13 Apr 2025 19:10:25 +0530 Subject: [PATCH 2/2] Fixes doctest and changes expected distance with params --- aeon/distances/_sbd.py | 4 +++- aeon/testing/expected_results/expected_distance_results.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/aeon/distances/_sbd.py b/aeon/distances/_sbd.py index 52cb5e5aa3..e4860fe07f 100644 --- a/aeon/distances/_sbd.py +++ b/aeon/distances/_sbd.py @@ -107,7 +107,9 @@ def sbd_distance(x: np.ndarray, y: np.ndarray, standardize: bool = True) -> floa if x.shape[0] != y.shape[0]: raise ValueError("x and y must have the same number of channels ") nchannels = x.shape[0] # both x and y have the same number of channels - norm = np.linalg.norm(x) * np.linalg.norm(y) + norm = np.linalg.norm(x.astype(np.float64)) * np.linalg.norm( + y.astype(np.float64) + ) distance = np.zeros((2 * x.shape[1] - 1,)) for i in range(nchannels): distance += _helper_sbd(x[i], y[i], standardize) diff --git a/aeon/testing/expected_results/expected_distance_results.py b/aeon/testing/expected_results/expected_distance_results.py index 2145a3208f..953b8682cd 100644 --- a/aeon/testing/expected_results/expected_distance_results.py +++ b/aeon/testing/expected_results/expected_distance_results.py @@ -185,7 +185,7 @@ [8.602610210695161, 8.645028399102344], [1.750534284134988, 12.516745017325773], ], - "sbd": [[0.2435580798173309, 0.18613277150939772]], + "sbd": [[0.2435580798173309, 0.21430477859140418]], "shift_scale": [ [0.8103639073457298, 5.535457073146429], [0.6519267432870345, 5.491208968546096],