Skip to content

Commit 1813b4a

Browse files
authored
array API support for cosine_distances (scikit-learn#29265)
1 parent cc97b80 commit 1813b4a

File tree

5 files changed

+22
-2
lines changed

5 files changed

+22
-2
lines changed

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ Metrics
123123
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
124124
- :func:`sklearn.metrics.pairwise.chi2_kernel`
125125
- :func:`sklearn.metrics.pairwise.cosine_similarity`
126+
- :func:`sklearn.metrics.pairwise.cosine_distances`
126127
- :func:`sklearn.metrics.pairwise.euclidean_distances` (see :ref:`device_support_for_float64`)
127128
- :func:`sklearn.metrics.pairwise.paired_cosine_distances`
128129
- :func:`sklearn.metrics.pairwise.rbf_kernel` (see :ref:`device_support_for_float64`)

doc/whats_new/v1.6.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ See :ref:`array_api` for more details.
4343
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko <Tialo>`;
4444
- :func:`sklearn.metrics.pairwise.chi2_kernel` :pr:`29267` by :user:`Yaroslav Korobko <Tialo>`;
4545
- :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati <EdAbati>`;
46+
- :func:`sklearn.metrics.pairwise.cosine_distances` :pr:`29265` by :user:`Emily Chen <EmilyXinyi>`;
4647
- :func:`sklearn.metrics.pairwise.euclidean_distances` :pr:`29433` by :user:`Omar Salman <OmarManzoor>`;
4748
- :func:`sklearn.metrics.pairwise.paired_cosine_distances` :pr:`29112` by :user:`Edoardo Abati <EdAbati>`;
4849
- :func:`sklearn.metrics.pairwise.rbf_kernel` :pr:`29433` by :user:`Omar Salman <OmarManzoor>`.

sklearn/metrics/pairwise.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
gen_even_slices,
2323
)
2424
from ..utils._array_api import (
25+
_clip,
2526
_fill_or_add_to_diagonal,
2627
_find_matching_floating_dtype,
2728
_is_numpy_namespace,
@@ -1139,15 +1140,17 @@ def cosine_distances(X, Y=None):
11391140
array([[1. , 1. ],
11401141
[0.42..., 0.18...]])
11411142
"""
1143+
xp, _ = get_namespace(X, Y)
1144+
11421145
# 1.0 - cosine_similarity(X, Y) without copy
11431146
S = cosine_similarity(X, Y)
11441147
S *= -1
11451148
S += 1
1146-
np.clip(S, 0, 2, out=S)
1149+
S = _clip(S, 0, 2, xp)
11471150
if X is Y or Y is None:
11481151
# Ensure that distances between vectors and themselves are set to 0.0.
11491152
# This may not be the case due to floating point rounding errors.
1150-
np.fill_diagonal(S, 0.0)
1153+
_fill_or_add_to_diagonal(S, 0.0, xp, add_value=False)
11511154
return S
11521155

11531156

sklearn/metrics/tests/test_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from sklearn.metrics.pairwise import (
5555
additive_chi2_kernel,
5656
chi2_kernel,
57+
cosine_distances,
5758
cosine_similarity,
5859
euclidean_distances,
5960
paired_cosine_distances,
@@ -2016,6 +2017,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
20162017
mean_gamma_deviance: [check_array_api_regression_metric],
20172018
max_error: [check_array_api_regression_metric],
20182019
chi2_kernel: [check_array_api_metric_pairwise],
2020+
cosine_distances: [check_array_api_metric_pairwise],
20192021
euclidean_distances: [check_array_api_metric_pairwise],
20202022
rbf_kernel: [check_array_api_metric_pairwise],
20212023
}

sklearn/utils/_array_api.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,19 @@ def _nanmax(X, axis=None, xp=None):
791791
return X
792792

793793

794+
def _clip(S, min_val, max_val, xp):
795+
# TODO: remove this method and change all usage once we move to array api 2023.12
796+
# https://data-apis.org/array-api/2023.12/API_specification/generated/array_api.clip.html#clip
797+
if _is_numpy_namespace(xp):
798+
return numpy.clip(S, min_val, max_val)
799+
else:
800+
min_arr = xp.asarray(min_val, dtype=S.dtype)
801+
max_arr = xp.asarray(max_val, dtype=S.dtype)
802+
S = xp.where(S < min_arr, min_arr, S)
803+
S = xp.where(S > max_arr, max_arr, S)
804+
return S
805+
806+
794807
def _asarray_with_order(
795808
array, dtype=None, order=None, copy=None, *, xp=None, device=None
796809
):

0 commit comments

Comments
 (0)