Skip to content

Commit bd8f5bd

Browse files
authored
ENH Add array_api compatibility to max_error (scikit-learn#29212)
1 parent 7e8ad63 commit bd8f5bd

File tree

4 files changed

+26
-14
lines changed

4 files changed

+26
-14
lines changed

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ Metrics
114114

115115
- :func:`sklearn.metrics.accuracy_score`
116116
- :func:`sklearn.metrics.d2_tweedie_score`
117+
- :func:`sklearn.metrics.max_error`
117118
- :func:`sklearn.metrics.mean_absolute_error`
118119
- :func:`sklearn.metrics.mean_squared_error`
119120
- :func:`sklearn.metrics.mean_tweedie_deviance`

doc/whats_new/v1.6.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,13 @@ See :ref:`array_api` for more details.
3333
**Functions:**
3434

3535
- :func:`sklearn.metrics.d2_tweedie_score` :pr:`29207` by :user:`Emily Chen <EmilyXinyi>`;
36+
- :func:`sklearn.metrics.max_error` :pr:`29212` by :user:`Edoardo Abati <EdAbati>`;
3637
- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati <EdAbati>`;
3738
- :func:`sklearn.metrics.mean_squared_error` :pr:`29142` by :user:`Yaroslav Korobko <Tialo>`;
3839
- :func:`sklearn.metrics.mean_tweedie_deviance` :pr:`28106` by :user:`Thomas Li <lithomas1>`;
3940
- :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati <EdAbati>`.
4041
- :func:`sklearn.metrics.pairwise.paired_cosine_distances` :pr:`29112` by :user:`Edoardo Abati <EdAbati>`.
4142

42-
4343
**Classes:**
4444

4545
- :class:`preprocessing.LabelEncoder` now supports Array API compatible inputs.

sklearn/metrics/_regression.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1290,10 +1290,11 @@ def max_error(y_true, y_pred):
12901290
>>> max_error(y_true, y_pred)
12911291
1
12921292
"""
1293+
xp, _ = get_namespace(y_true, y_pred)
12931294
y_type, y_true, y_pred, _ = _check_reg_targets(y_true, y_pred, None)
12941295
if y_type == "continuous-multioutput":
12951296
raise ValueError("Multioutput not supported in max_error")
1296-
return np.max(np.abs(y_true - y_pred))
1297+
return xp.max(xp.abs(y_true - y_pred))
12971298

12981299

12991300
def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):

sklearn/metrics/tests/test_common.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1827,20 +1827,14 @@ def check_array_api_multiclass_classification_metric(
18271827

18281828

18291829
def check_array_api_regression_metric(metric, array_namespace, device, dtype_name):
1830-
y_true_np = np.array([2, 0, 1, 4], dtype=dtype_name)
1830+
y_true_np = np.array([2.0, 0.0, 1.0, 4.0], dtype=dtype_name)
18311831
y_pred_np = np.array([0.5, 0.5, 2, 2], dtype=dtype_name)
18321832

1833-
check_array_api_metric(
1834-
metric,
1835-
array_namespace,
1836-
device,
1837-
dtype_name,
1838-
a_np=y_true_np,
1839-
b_np=y_pred_np,
1840-
sample_weight=None,
1841-
)
1833+
metric_kwargs = {}
1834+
metric_params = signature(metric).parameters
18421835

1843-
sample_weight = np.array([0.1, 2.0, 1.5, 0.5], dtype=dtype_name)
1836+
if "sample_weight" in metric_params:
1837+
metric_kwargs["sample_weight"] = None
18441838

18451839
check_array_api_metric(
18461840
metric,
@@ -1849,9 +1843,24 @@ def check_array_api_regression_metric(metric, array_namespace, device, dtype_nam
18491843
dtype_name,
18501844
a_np=y_true_np,
18511845
b_np=y_pred_np,
1852-
sample_weight=sample_weight,
1846+
**metric_kwargs,
18531847
)
18541848

1849+
if "sample_weight" in metric_params:
1850+
metric_kwargs["sample_weight"] = np.array(
1851+
[0.1, 2.0, 1.5, 0.5], dtype=dtype_name
1852+
)
1853+
1854+
check_array_api_metric(
1855+
metric,
1856+
array_namespace,
1857+
device,
1858+
dtype_name,
1859+
a_np=y_true_np,
1860+
b_np=y_pred_np,
1861+
**metric_kwargs,
1862+
)
1863+
18551864

18561865
def check_array_api_regression_metric_multioutput(
18571866
metric, array_namespace, device, dtype_name
@@ -1946,6 +1955,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
19461955
check_array_api_regression_metric,
19471956
],
19481957
paired_cosine_distances: [check_array_api_metric_pairwise],
1958+
max_error: [check_array_api_regression_metric],
19491959
}
19501960

19511961

0 commit comments

Comments
 (0)