Skip to content

Commit dc6c01c

Browse files
authored
array API support for mean_absolute_percentage_error (scikit-learn#29300)
1 parent 1813b4a commit dc6c01c

File tree

4 files changed

+21
-6
lines changed

4 files changed

+21
-6
lines changed

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ Metrics
117117
- :func:`sklearn.metrics.d2_tweedie_score`
118118
- :func:`sklearn.metrics.max_error`
119119
- :func:`sklearn.metrics.mean_absolute_error`
120+
- :func:`sklearn.metrics.mean_absolute_percentage_error`
120121
- :func:`sklearn.metrics.mean_gamma_deviance`
121122
- :func:`sklearn.metrics.mean_squared_error`
122123
- :func:`sklearn.metrics.mean_tweedie_deviance`

doc/whats_new/v1.6.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ See :ref:`array_api` for more details.
3737
- :func:`sklearn.metrics.max_error` :pr:`29212` by :user:`Edoardo Abati <EdAbati>`;
3838
- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati <EdAbati>`
3939
and :pr:`29143` by :user:`Tialo <Tialo>` and :user:`Loïc Estève <lesteve>`;
40-
- :func:`sklearn.metrics.mean_gamma_deviance` :pr:`29239` by :usser:`Emily Chen <EmilyXinyi>`;
40+
- :func:`sklearn.metrics.mean_absolute_percentage_error` :pr:`29300` by :user:`Emily Chen <EmilyXinyi>`;
41+
- :func:`sklearn.metrics.mean_gamma_deviance` :pr:`29239` by :user:`Emily Chen <EmilyXinyi>`;
4142
- :func:`sklearn.metrics.mean_squared_error` :pr:`29142` by :user:`Yaroslav Korobko <Tialo>`;
4243
- :func:`sklearn.metrics.mean_tweedie_deviance` :pr:`28106` by :user:`Thomas Li <lithomas1>`;
4344
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko <Tialo>`;

sklearn/metrics/_regression.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -395,21 +395,30 @@ def mean_absolute_percentage_error(
395395
>>> mean_absolute_percentage_error(y_true, y_pred)
396396
112589990684262.48
397397
"""
398+
input_arrays = [y_true, y_pred, sample_weight, multioutput]
399+
xp, _ = get_namespace(*input_arrays)
400+
dtype = _find_matching_floating_dtype(y_true, y_pred, sample_weight, xp=xp)
401+
398402
y_type, y_true, y_pred, multioutput = _check_reg_targets(
399403
y_true, y_pred, multioutput
400404
)
401405
check_consistent_length(y_true, y_pred, sample_weight)
402-
epsilon = np.finfo(np.float64).eps
403-
mape = np.abs(y_pred - y_true) / np.maximum(np.abs(y_true), epsilon)
404-
output_errors = np.average(mape, weights=sample_weight, axis=0)
406+
epsilon = xp.asarray(xp.finfo(xp.float64).eps, dtype=dtype)
407+
y_true_abs = xp.asarray(xp.abs(y_true), dtype=dtype)
408+
mape = xp.asarray(xp.abs(y_pred - y_true), dtype=dtype) / xp.maximum(
409+
y_true_abs, epsilon
410+
)
411+
output_errors = _average(mape, weights=sample_weight, axis=0)
405412
if isinstance(multioutput, str):
406413
if multioutput == "raw_values":
407414
return output_errors
408415
elif multioutput == "uniform_average":
409-
# pass None as weights to np.average: uniform mean
416+
# pass None as weights to _average: uniform mean
410417
multioutput = None
411418

412-
return np.average(output_errors, weights=multioutput)
419+
mean_absolute_percentage_error = _average(output_errors, weights=multioutput)
420+
assert mean_absolute_percentage_error.shape == ()
421+
return float(mean_absolute_percentage_error)
413422

414423

415424
@validate_params(

sklearn/metrics/tests/test_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2016,6 +2016,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
20162016
additive_chi2_kernel: [check_array_api_metric_pairwise],
20172017
mean_gamma_deviance: [check_array_api_regression_metric],
20182018
max_error: [check_array_api_regression_metric],
2019+
mean_absolute_percentage_error: [
2020+
check_array_api_regression_metric,
2021+
check_array_api_regression_metric_multioutput,
2022+
],
20192023
chi2_kernel: [check_array_api_metric_pairwise],
20202024
cosine_distances: [check_array_api_metric_pairwise],
20212025
euclidean_distances: [check_array_api_metric_pairwise],

0 commit comments

Comments
 (0)