Skip to content

Commit d03054b

Browse files
authored
FIX Remove median_absolute_error from METRICS_WITHOUT_SAMPLE_WEIGHT (scikit-learn#30787)
1 parent aa8b113 commit d03054b

File tree

3 files changed

+18
-7
lines changed

3 files changed

+18
-7
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+ - :func:`metrics.median_absolute_error` now uses `_averaged_weighted_percentile`
2+   instead of `_weighted_percentile` to calculate the median when `sample_weight` is not
3+   `None`. This is equivalent to using the "averaged_inverted_cdf" instead of
4+   the "inverted_cdf" quantile method, which gives results equivalent to `numpy.median`
5+   if equal weights are used.
6+   By :user:`Lucy Liu <lucyleeow>`

sklearn/metrics/_regression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
_xlogy as xlogy,
2929
)
3030
from ..utils._param_validation import Interval, StrOptions, validate_params
31-
from ..utils.stats import _weighted_percentile
31+
from ..utils.stats import _averaged_weighted_percentile, _weighted_percentile
3232
from ..utils.validation import (
3333
_check_sample_weight,
3434
_num_samples,
@@ -923,7 +923,7 @@ def median_absolute_error(
923923
if sample_weight is None:
924924
output_errors = _median(xp.abs(y_pred - y_true), axis=0)
925925
else:
926-
output_errors = _weighted_percentile(
926+
output_errors = _averaged_weighted_percentile(
927927
xp.abs(y_pred - y_true), sample_weight=sample_weight
928928
)
929929
if isinstance(multioutput, str):

sklearn/metrics/tests/test_common.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,6 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
555555

556556
# No Sample weight support
557557
METRICS_WITHOUT_SAMPLE_WEIGHT = {
558-
"median_absolute_error",
559558
"max_error",
560559
"ovo_roc_auc",
561560
"weighted_ovo_roc_auc",
@@ -1474,9 +1473,10 @@ def test_averaging_multilabel_all_ones(name):
14741473
check_averaging(name, y_true, y_true_binarize, y_pred, y_pred_binarize, y_score)
14751474

14761475

1477-
def check_sample_weight_invariance(name, metric, y1, y2):
1476+
def check_sample_weight_invariance(name, metric, y1, y2, sample_weight=None):
14781477
rng = np.random.RandomState(0)
1479-
sample_weight = rng.randint(1, 10, size=len(y1))
1478+
if sample_weight is None:
1479+
sample_weight = rng.randint(1, 10, size=len(y1))
14801480

14811481
# top_k_accuracy_score always lead to a perfect score for k > 1 in the
14821482
# binary case
@@ -1552,7 +1552,10 @@ def check_sample_weight_invariance(name, metric, y1, y2):
15521552
if not name.startswith("unnormalized"):
15531553
# check that the score is invariant under scaling of the weights by a
15541554
# common factor
1555-
for scaling in [2, 0.3]:
1555+
# Due to numerical instability of floating points in `cumulative_sum` in
1556+
# `median_absolute_error`, it is not always equivalent when scaling by a float.
1557+
scaling_values = [2] if name == "median_absolute_error" else [2, 0.3]
1558+
for scaling in scaling_values:
15561559
assert_allclose(
15571560
weighted_score,
15581561
metric(y1, y2, sample_weight=sample_weight * scaling),
@@ -1584,8 +1587,10 @@ def test_regression_sample_weight_invariance(name):
15841587
# regression
15851588
y_true = random_state.random_sample(size=(n_samples,))
15861589
y_pred = random_state.random_sample(size=(n_samples,))
1590+
sample_weight = np.arange(len(y_true))
15871591
metric = ALL_METRICS[name]
1588-
check_sample_weight_invariance(name, metric, y_true, y_pred)
1592+
1593+
check_sample_weight_invariance(name, metric, y_true, y_pred, sample_weight)
15891594

15901595

15911596
@pytest.mark.parametrize(

0 commit comments

Comments (0)