Skip to content

Commit 56dbfd0

Browse files
authored
extreme_stable case for mean_tweedie_deviance (scikit-learn#29258)
1 parent a408a59 commit 56dbfd0

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

sklearn/metrics/_regression.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,12 +1284,14 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
12841284
"""Mean Tweedie deviance regression loss."""
12851285
xp, _ = get_namespace(y_true, y_pred)
12861286
p = power
1287+
zero = xp.asarray(0, dtype=y_true.dtype)
12871288
if p < 0:
12881289
# 'Extreme stable', y any real number, y_pred > 0
12891290
dev = 2 * (
1290-
xp.pow(xp.where(y_true > 0, y_true, 0), 2 - p) / ((1 - p) * (2 - p))
1291-
- y_true * xp.pow(y_pred, 1 - p) / (1 - p)
1292-
+ xp.pow(y_pred, 2 - p) / (2 - p)
1291+
xp.pow(xp.where(y_true > 0, y_true, zero), xp.asarray(2 - p))
1292+
/ ((1 - p) * (2 - p))
1293+
- y_true * xp.pow(y_pred, xp.asarray(1 - p)) / (1 - p)
1294+
+ xp.pow(y_pred, xp.asarray(2 - p)) / (2 - p)
12931295
)
12941296
elif p == 0:
12951297
# Normal distribution, y and y_pred any real number
@@ -1302,9 +1304,9 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
13021304
dev = 2 * (xp.log(y_pred / y_true) + y_true / y_pred - 1)
13031305
else:
13041306
dev = 2 * (
1305-
xp.pow(y_true, 2 - p) / ((1 - p) * (2 - p))
1306-
- y_true * xp.pow(y_pred, 1 - p) / (1 - p)
1307-
+ xp.pow(y_pred, 2 - p) / (2 - p)
1307+
xp.pow(y_true, xp.asarray(2 - p)) / ((1 - p) * (2 - p))
1308+
- y_true * xp.pow(y_pred, xp.asarray(1 - p)) / (1 - p)
1309+
+ xp.pow(y_pred, xp.asarray(2 - p)) / (2 - p)
13081310
)
13091311
return float(_average(dev, weights=sample_weight))
13101312

@@ -1384,14 +1386,14 @@ def mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0):
13841386
message = f"Mean Tweedie deviance error with power={power} can only be used on "
13851387
if power < 0:
13861388
# 'Extreme stable', y any real number, y_pred > 0
1387-
if (y_pred <= 0).any():
1389+
if xp.any(y_pred <= 0):
13881390
raise ValueError(message + "strictly positive y_pred.")
13891391
elif power == 0:
13901392
# Normal, y and y_pred can be any real number
13911393
pass
13921394
elif 1 <= power < 2:
13931395
# Poisson and compound Poisson distribution, y >= 0, y_pred > 0
1394-
if (y_true < 0).any() or (y_pred <= 0).any():
1396+
if xp.any(y_true < 0) or xp.any(y_pred <= 0):
13951397
raise ValueError(message + "non-negative y and strictly positive y_pred.")
13961398
elif power >= 2:
13971399
# Gamma and Extreme stable distribution, y and y_pred > 0

sklearn/metrics/tests/test_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1991,6 +1991,8 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
19911991
check_array_api_multilabel_classification_metric,
19921992
],
19931993
mean_tweedie_deviance: [check_array_api_regression_metric],
1994+
partial(mean_tweedie_deviance, power=-0.5): [check_array_api_regression_metric],
1995+
partial(mean_tweedie_deviance, power=1.5): [check_array_api_regression_metric],
19941996
r2_score: [
19951997
check_array_api_regression_metric,
19961998
check_array_api_regression_metric_multioutput,

0 commit comments

Comments
 (0)