Skip to content

Commit a1027b5

Browse files
Tialobetatim
andauthored
ENH Add Array API compatibility for additive_chi2_kernel (scikit-learn#29144)
Co-authored-by: Tim Head <betatim@gmail.com>
1 parent f10c171 commit a1027b5

File tree

4 files changed

+26
-8
lines changed

4 files changed

+26
-8
lines changed

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ Metrics
119119
- :func:`sklearn.metrics.mean_gamma_deviance`
120120
- :func:`sklearn.metrics.mean_squared_error`
121121
- :func:`sklearn.metrics.mean_tweedie_deviance`
122+
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
122123
- :func:`sklearn.metrics.pairwise.cosine_similarity`
123124
- :func:`sklearn.metrics.pairwise.paired_cosine_distances`
124125
- :func:`sklearn.metrics.r2_score`

doc/whats_new/v1.6.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ See :ref:`array_api` for more details.
3838
- :func:`sklearn.metrics.mean_gamma_deviance` :pr:`29239` by :usser:`Emily Chen <EmilyXinyi>`;
3939
- :func:`sklearn.metrics.mean_squared_error` :pr:`29142` by :user:`Yaroslav Korobko <Tialo>`;
4040
- :func:`sklearn.metrics.mean_tweedie_deviance` :pr:`28106` by :user:`Thomas Li <lithomas1>`;
41-
- :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati <EdAbati>`.
41+
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko <Tialo>`;
42+
- :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati <EdAbati>`;
4243
- :func:`sklearn.metrics.pairwise.paired_cosine_distances` :pr:`29112` by :user:`Edoardo Abati <EdAbati>`.
4344

4445
**Classes:**

sklearn/metrics/pairwise.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1718,7 +1718,7 @@ def additive_chi2_kernel(X, Y=None):
17181718
17191719
Returns
17201720
-------
1721-
kernel : ndarray of shape (n_samples_X, n_samples_Y)
1721+
kernel : array-like of shape (n_samples_X, n_samples_Y)
17221722
The kernel matrix.
17231723
17241724
See Also
@@ -1750,15 +1750,26 @@ def additive_chi2_kernel(X, Y=None):
17501750
array([[-1., -2.],
17511751
[-2., -1.]])
17521752
"""
1753+
xp, _ = get_namespace(X, Y)
17531754
X, Y = check_pairwise_arrays(X, Y, accept_sparse=False)
1754-
if (X < 0).any():
1755+
if xp.any(X < 0):
17551756
raise ValueError("X contains negative values.")
1756-
if Y is not X and (Y < 0).any():
1757+
if Y is not X and xp.any(Y < 0):
17571758
raise ValueError("Y contains negative values.")
17581759

1759-
result = np.zeros((X.shape[0], Y.shape[0]), dtype=X.dtype)
1760-
_chi2_kernel_fast(X, Y, result)
1761-
return result
1760+
if _is_numpy_namespace(xp):
1761+
result = np.zeros((X.shape[0], Y.shape[0]), dtype=X.dtype)
1762+
_chi2_kernel_fast(X, Y, result)
1763+
return result
1764+
else:
1765+
dtype = _find_matching_floating_dtype(X, Y, xp=xp)
1766+
xb = X[:, None, :]
1767+
yb = Y[None, :, :]
1768+
nom = -((xb - yb) ** 2)
1769+
denom = xb + yb
1770+
nom = xp.where(denom == 0, xp.asarray(0, dtype=dtype), nom)
1771+
denom = xp.where(denom == 0, xp.asarray(1, dtype=dtype), denom)
1772+
return xp.sum(nom / denom, axis=2)
17621773

17631774

17641775
@validate_params(

sklearn/metrics/tests/test_common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,11 @@
5151
zero_one_loss,
5252
)
5353
from sklearn.metrics._base import _average_binary_score
54-
from sklearn.metrics.pairwise import cosine_similarity, paired_cosine_distances
54+
from sklearn.metrics.pairwise import (
55+
additive_chi2_kernel,
56+
cosine_similarity,
57+
paired_cosine_distances,
58+
)
5559
from sklearn.preprocessing import LabelBinarizer
5660
from sklearn.utils import shuffle
5761
from sklearn.utils._array_api import (
@@ -1955,6 +1959,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
19551959
check_array_api_regression_metric,
19561960
],
19571961
paired_cosine_distances: [check_array_api_metric_pairwise],
1962+
additive_chi2_kernel: [check_array_api_metric_pairwise],
19581963
mean_gamma_deviance: [check_array_api_regression_metric],
19591964
max_error: [check_array_api_regression_metric],
19601965
}

0 commit comments

Comments
 (0)