Skip to content

Commit 5c21794

Browse files
Add array API support to median_absolute_error (scikit-learn#31406)
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
1 parent 9062822 commit 5c21794

File tree

7 files changed

+102
-10
lines changed

7 files changed

+102
-10
lines changed

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ Metrics
149149
- :func:`sklearn.metrics.mean_squared_error`
150150
- :func:`sklearn.metrics.mean_squared_log_error`
151151
- :func:`sklearn.metrics.mean_tweedie_deviance`
152+
- :func:`sklearn.metrics.median_absolute_error`
152153
- :func:`sklearn.metrics.multilabel_confusion_matrix`
153154
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
154155
- :func:`sklearn.metrics.pairwise.chi2_kernel`
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- :func:`metrics.median_absolute_error` now supports Array API compatible inputs.
2+
By :user:`Lucy Liu <lucyleeow>`.

sklearn/metrics/_regression.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ..utils._array_api import (
2020
_average,
2121
_find_matching_floating_dtype,
22+
_median,
2223
get_namespace,
2324
get_namespace_and_device,
2425
size,
@@ -915,14 +916,15 @@ def median_absolute_error(
915916
>>> median_absolute_error(y_true, y_pred, multioutput=[0.3, 0.7])
916917
0.85
917918
"""
919+
xp, _ = get_namespace(y_true, y_pred, multioutput, sample_weight)
918920
_, y_true, y_pred, sample_weight, multioutput = _check_reg_targets(
919921
y_true, y_pred, sample_weight, multioutput
920922
)
921923
if sample_weight is None:
922-
output_errors = np.median(np.abs(y_pred - y_true), axis=0)
924+
output_errors = _median(xp.abs(y_pred - y_true), axis=0)
923925
else:
924926
output_errors = _weighted_percentile(
925-
np.abs(y_pred - y_true), sample_weight=sample_weight
927+
xp.abs(y_pred - y_true), sample_weight=sample_weight
926928
)
927929
if isinstance(multioutput, str):
928930
if multioutput == "raw_values":
@@ -931,7 +933,7 @@ def median_absolute_error(
931933
# pass None as weights to np.average: uniform mean
932934
multioutput = None
933935

934-
return float(np.average(output_errors, weights=multioutput))
936+
return float(_average(output_errors, weights=multioutput))
935937

936938

937939
def _assemble_r2_explained_variance(

sklearn/metrics/tests/test_common.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2231,6 +2231,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
22312231
check_array_api_regression_metric,
22322232
check_array_api_regression_metric_multioutput,
22332233
],
2234+
median_absolute_error: [
2235+
check_array_api_regression_metric,
2236+
check_array_api_regression_metric_multioutput,
2237+
],
22342238
d2_tweedie_score: [
22352239
check_array_api_regression_metric,
22362240
],
@@ -2275,6 +2279,23 @@ def yield_metric_checker_combinations(metric_checkers=array_api_metric_checkers)
22752279
)
22762280
@pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
22772281
def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func):
2282+
# TODO: Remove once array-api-strict > 2.3.1
2283+
# https://github.com/data-apis/array-api-strict/issues/134 has been fixed but
2284+
# not released yet.
2285+
if (
2286+
getattr(metric, "__name__", None) == "median_absolute_error"
2287+
and array_namespace == "array_api_strict"
2288+
):
2289+
try:
2290+
import array_api_strict
2291+
except ImportError:
2292+
pass
2293+
else:
2294+
if device == array_api_strict.Device("device1"):
2295+
pytest.xfail(
2296+
"`_weighted_percentile` is affected by array_api_strict bug when "
2297+
"indexing with tuple of arrays on non-'CPU_DEVICE' devices."
2298+
)
22782299
check_func(metric, array_namespace, device, dtype_name)
22792300

22802301

sklearn/utils/_array_api.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,30 @@ def _average(a, axis=None, weights=None, normalize=True, xp=None):
669669
return sum_ / scale
670670

671671

672+
def _median(x, axis=None, keepdims=False, xp=None):
673+
# XXX: `median` is not included in the array API spec, but is implemented
674+
# in most array libraries, and all that we support (as of May 2025).
675+
# TODO: consider simplifying this code to use scipy instead once the oldest
676+
# supported SciPy version provides `scipy.stats.quantile` with native array API
677+
# support (likely scipy 1.6 at the time of writing). Proper benchmarking of
678+
# either option with popular array namespaces is required to evaluate the
679+
# impact of this choice.
680+
xp, _, device = get_namespace_and_device(x, xp=xp)
681+
682+
# `torch.median` takes the lower of the two medians when `x` has even number
683+
# of elements, thus we use `torch.quantile(q=0.5)`, which gives mean of the two
684+
if array_api_compat.is_torch_namespace(xp):
685+
return xp.quantile(x, q=0.5, dim=axis, keepdim=keepdims)
686+
687+
if hasattr(xp, "median"):
688+
return xp.median(x, axis=axis, keepdims=keepdims)
689+
690+
# Intended mostly for array-api-strict (which as no "median", as per the spec)
691+
# as `_convert_to_numpy` does not necessarily work for all array types.
692+
x_np = _convert_to_numpy(x, xp=xp)
693+
return xp.asarray(numpy.median(x_np, axis=axis, keepdims=keepdims), device=device)
694+
695+
672696
def _xlogy(x, y, xp=None):
673697
# TODO: Remove this once https://github.com/scipy/scipy/issues/21736 is fixed
674698
xp, _, device_ = get_namespace_and_device(x, y, xp=xp)

sklearn/utils/tests/test_array_api.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
_is_numpy_namespace,
2020
_isin,
2121
_max_precision_float_dtype,
22+
_median,
2223
_nanmax,
2324
_nanmean,
2425
_nanmin,
@@ -603,3 +604,33 @@ def test_sparse_device(csr_container, dispatch):
603604
assert device(a, numpy.array([1])) is None
604605
assert get_namespace_and_device(a, b)[2] is None
605606
assert get_namespace_and_device(a, numpy.array([1]))[2] is None
607+
608+
609+
@pytest.mark.parametrize(
610+
"namespace, device, dtype_name",
611+
yield_namespace_device_dtype_combinations(),
612+
ids=_get_namespace_device_dtype_ids,
613+
)
614+
@pytest.mark.parametrize("axis", [None, 0, 1])
615+
def test_median(namespace, device, dtype_name, axis):
616+
# Note: depending on the value of `axis`, this test will compare median
617+
# computations on arrays of even (4) or odd (5) numbers of elements, hence
618+
# will test for median computation with and without interpolation to check
619+
# that array API namespaces yield consistent results even when the median is
620+
# not mathematically uniquely defined.
621+
xp = _array_api_for_tests(namespace, device)
622+
rng = numpy.random.RandomState(0)
623+
624+
X_np = rng.uniform(low=0.0, high=1.0, size=(5, 4)).astype(dtype_name)
625+
result_np = numpy.median(X_np, axis=axis)
626+
627+
X_xp = xp.asarray(X_np, device=device)
628+
with config_context(array_api_dispatch=True):
629+
result_xp = _median(X_xp, axis=axis)
630+
631+
if xp.__name__ != "array_api_strict":
632+
# We covert array-api-strict arrays to numpy arrays as `median` is not
633+
# part of the Array API spec
634+
assert get_namespace(result_xp)[0] == xp
635+
assert result_xp.device == X_xp.device
636+
assert_allclose(result_np, _convert_to_numpy(result_xp, xp=xp))

sklearn/utils/validation.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@
1818

1919
from .. import get_config as _get_config
2020
from ..exceptions import DataConversionWarning, NotFittedError, PositiveSpectrumWarning
21-
from ..utils._array_api import _asarray_with_order, _is_numpy_namespace, get_namespace
21+
from ..utils._array_api import (
22+
_asarray_with_order,
23+
_is_numpy_namespace,
24+
_max_precision_float_dtype,
25+
get_namespace,
26+
get_namespace_and_device,
27+
)
2228
from ..utils.deprecation import _deprecate_force_all_finite
2329
from ..utils.fixes import ComplexWarning, _preserve_dia_indices_dtype
2430
from ._isfinite import FiniteStatus, cy_isfinite
@@ -390,7 +396,8 @@ def _num_samples(x):
390396

391397
if not hasattr(x, "__len__") and not hasattr(x, "shape"):
392398
if hasattr(x, "__array__"):
393-
x = np.asarray(x)
399+
xp, _ = get_namespace(x)
400+
x = xp.asarray(x)
394401
else:
395402
raise TypeError(message)
396403

@@ -2167,20 +2174,24 @@ def _check_sample_weight(
21672174
sample_weight : ndarray of shape (n_samples,)
21682175
Validated sample weight. It is guaranteed to be "C" contiguous.
21692176
"""
2170-
n_samples = _num_samples(X)
2177+
xp, _, device = get_namespace_and_device(sample_weight, X)
21712178

2172-
xp, _ = get_namespace(X)
2179+
n_samples = _num_samples(X)
21732180

2174-
if dtype is not None and dtype not in [xp.float32, xp.float64]:
2175-
dtype = xp.float64
2181+
max_float_type = _max_precision_float_dtype(xp, device)
2182+
float_dtypes = (
2183+
[xp.float32] if max_float_type == xp.float32 else [xp.float64, xp.float32]
2184+
)
2185+
if dtype is not None and dtype not in float_dtypes:
2186+
dtype = max_float_type
21762187

21772188
if sample_weight is None:
21782189
sample_weight = xp.ones(n_samples, dtype=dtype)
21792190
elif isinstance(sample_weight, numbers.Number):
21802191
sample_weight = xp.full(n_samples, sample_weight, dtype=dtype)
21812192
else:
21822193
if dtype is None:
2183-
dtype = [xp.float64, xp.float32]
2194+
dtype = float_dtypes
21842195
sample_weight = check_array(
21852196
sample_weight,
21862197
accept_sparse=False,

0 commit comments

Comments
 (0)