Skip to content

Commit 20b8d0b

Browse files
lucyleeowogriselbetatim
authored
Add array API tests for pairwise_distances (scikit-learn#31658)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Tim Head <betatim@gmail.com>
1 parent 20d33d5 commit 20b8d0b

File tree

4 files changed

+48
-1
lines changed

4 files changed

+48
-1
lines changed

doc/modules/array_api.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,12 @@ Metrics
156156
- :func:`sklearn.metrics.pairwise.chi2_kernel`
157157
- :func:`sklearn.metrics.pairwise.cosine_similarity`
158158
- :func:`sklearn.metrics.pairwise.cosine_distances`
159+
- :func:`sklearn.metrics.pairwise.pairwise_distances` (only supports "cosine", "euclidean" and "l2" metrics)
159160
- :func:`sklearn.metrics.pairwise.euclidean_distances` (see :ref:`device_support_for_float64`)
160161
- :func:`sklearn.metrics.pairwise.linear_kernel`
161162
- :func:`sklearn.metrics.pairwise.paired_cosine_distances`
162163
- :func:`sklearn.metrics.pairwise.paired_euclidean_distances`
163-
- :func:`sklearn.metrics.pairwise.pairwise_kernels` (supports all metrics except :func:`sklearn.metrics.pairwise.laplacian_kernel`)
164+
- :func:`sklearn.metrics.pairwise.pairwise_kernels` (supports all `sklearn.pairwise.PAIRWISE_KERNEL_FUNCTIONS` except :func:`sklearn.metrics.pairwise.laplacian_kernel`)
164165
- :func:`sklearn.metrics.pairwise.polynomial_kernel`
165166
- :func:`sklearn.metrics.pairwise.rbf_kernel` (see :ref:`device_support_for_float64`)
166167
- :func:`sklearn.metrics.pairwise.sigmoid_kernel`

doc/whats_new/upcoming_changes/sklearn.metrics/29822.enhancement.rst renamed to doc/whats_new/upcoming_changes/array-api/29822.enhancement.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,8 @@
22
compatible inputs, when the underling `metric` does (the only metric NOT currently
33
supported is :func:`sklearn.metrics.pairwise.laplacian_kernel`).
44
By :user:`Emily Chen <EmilyXinyi>` and :user:`Lucy Liu <lucyleeow>`.
5+
6+
- :func:`metrics.pairwise.pairwise_distances` now supports Array API
7+
compatible inputs, when the underlying `metric` does (currently
8+
"cosine", "euclidean" and "l2").
9+
By :user:`Emily Chen <EmilyXinyi>` and :user:`Lucy Liu <lucyleeow>`.

sklearn/metrics/tests/test_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
linear_kernel,
6666
paired_cosine_distances,
6767
paired_euclidean_distances,
68+
pairwise_distances,
6869
pairwise_kernels,
6970
polynomial_kernel,
7071
rbf_kernel,
@@ -2282,6 +2283,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
22822283
roc_curve: [
22832284
check_array_api_binary_classification_metric,
22842285
],
2286+
pairwise_distances: [check_array_api_metric_pairwise],
22852287
}
22862288

22872289

sklearn/metrics/tests/test_pairwise.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,45 @@ def test_pairwise_distances_for_dense_data(global_dtype):
151151
assert_allclose(S, S2)
152152

153153

154+
@pytest.mark.parametrize(
155+
"array_namespace, device, dtype_name",
156+
yield_namespace_device_dtype_combinations(),
157+
ids=_get_namespace_device_dtype_ids,
158+
)
159+
@pytest.mark.parametrize("metric", ["cosine", "euclidean"])
160+
def test_pairwise_distances_array_api(array_namespace, device, dtype_name, metric):
161+
# Test array API support in pairwise_distances.
162+
xp = _array_api_for_tests(array_namespace, device)
163+
164+
rng = np.random.RandomState(0)
165+
# Euclidean distance should be equivalent to calling the function.
166+
X_np = rng.random_sample((5, 4)).astype(dtype_name, copy=False)
167+
Y_np = rng.random_sample((5, 4)).astype(dtype_name, copy=False)
168+
X_xp = xp.asarray(X_np, device=device)
169+
Y_xp = xp.asarray(Y_np, device=device)
170+
171+
with config_context(array_api_dispatch=True):
172+
# Test with Y=None
173+
D_xp = pairwise_distances(X_xp, metric=metric)
174+
D_xp_np = _convert_to_numpy(D_xp, xp=xp)
175+
assert get_namespace(D_xp)[0].__name__ == xp.__name__
176+
assert D_xp.device == X_xp.device
177+
assert D_xp.dtype == X_xp.dtype
178+
179+
D_np = pairwise_distances(X_np, metric=metric)
180+
assert_allclose(D_xp_np, D_np)
181+
182+
# Test with Y=Y_np/Y_xp
183+
D_xp = pairwise_distances(X_xp, Y=Y_xp, metric=metric)
184+
D_xp_np = _convert_to_numpy(D_xp, xp=xp)
185+
assert get_namespace(D_xp)[0].__name__ == xp.__name__
186+
assert D_xp.device == X_xp.device
187+
assert D_xp.dtype == X_xp.dtype
188+
189+
D_np = pairwise_distances(X_np, Y=Y_np, metric=metric)
190+
assert_allclose(D_xp_np, D_np)
191+
192+
154193
@pytest.mark.parametrize("coo_container", COO_CONTAINERS)
155194
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
156195
@pytest.mark.parametrize("bsr_container", BSR_CONTAINERS)

0 commit comments

Comments
 (0)