@@ -555,7 +555,6 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
555
555
556
556
# No Sample weight support
557
557
METRICS_WITHOUT_SAMPLE_WEIGHT = {
558
- "median_absolute_error" ,
559
558
"max_error" ,
560
559
"ovo_roc_auc" ,
561
560
"weighted_ovo_roc_auc" ,
@@ -1474,9 +1473,10 @@ def test_averaging_multilabel_all_ones(name):
1474
1473
check_averaging (name , y_true , y_true_binarize , y_pred , y_pred_binarize , y_score )
1475
1474
1476
1475
1477
- def check_sample_weight_invariance (name , metric , y1 , y2 ):
1476
+ def check_sample_weight_invariance (name , metric , y1 , y2 , sample_weight = None ):
1478
1477
rng = np .random .RandomState (0 )
1479
- sample_weight = rng .randint (1 , 10 , size = len (y1 ))
1478
+ if sample_weight is None :
1479
+ sample_weight = rng .randint (1 , 10 , size = len (y1 ))
1480
1480
1481
1481
# top_k_accuracy_score always lead to a perfect score for k > 1 in the
1482
1482
# binary case
@@ -1552,7 +1552,10 @@ def check_sample_weight_invariance(name, metric, y1, y2):
1552
1552
if not name .startswith ("unnormalized" ):
1553
1553
# check that the score is invariant under scaling of the weights by a
1554
1554
# common factor
1555
- for scaling in [2 , 0.3 ]:
1555
+ # Due to numerical instability of floating points in `cumulative_sum` in
1556
+ # `median_absolute_error`, it is not always equivalent when scaling by a float.
1557
+ scaling_values = [2 ] if name == "median_absolute_error" else [2 , 0.3 ]
1558
+ for scaling in scaling_values :
1556
1559
assert_allclose (
1557
1560
weighted_score ,
1558
1561
metric (y1 , y2 , sample_weight = sample_weight * scaling ),
@@ -1584,8 +1587,10 @@ def test_regression_sample_weight_invariance(name):
1584
1587
# regression
1585
1588
y_true = random_state .random_sample (size = (n_samples ,))
1586
1589
y_pred = random_state .random_sample (size = (n_samples ,))
1590
+ sample_weight = np .arange (len (y_true ))
1587
1591
metric = ALL_METRICS [name ]
1588
- check_sample_weight_invariance (name , metric , y_true , y_pred )
1592
+
1593
+ check_sample_weight_invariance (name , metric , y_true , y_pred , sample_weight )
1589
1594
1590
1595
1591
1596
@pytest .mark .parametrize (
0 commit comments