Skip to content

Commit f10c171

Browse files
EmilyXinyibetatim
andauthored
array API support for mean_gamma_deviance (scikit-learn#29239)
Co-authored-by: Tim Head <betatim@gmail.com>
1 parent 6595229 commit f10c171

File tree

4 files changed

+5
-2
lines changed

4 files changed

+5
-2
lines changed

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ Metrics
116116
- :func:`sklearn.metrics.d2_tweedie_score`
117117
- :func:`sklearn.metrics.max_error`
118118
- :func:`sklearn.metrics.mean_absolute_error`
119+
- :func:`sklearn.metrics.mean_gamma_deviance`
119120
- :func:`sklearn.metrics.mean_squared_error`
120121
- :func:`sklearn.metrics.mean_tweedie_deviance`
121122
- :func:`sklearn.metrics.pairwise.cosine_similarity`

doc/whats_new/v1.6.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ See :ref:`array_api` for more details.
3535
- :func:`sklearn.metrics.d2_tweedie_score` :pr:`29207` by :user:`Emily Chen <EmilyXinyi>`;
3636
- :func:`sklearn.metrics.max_error` :pr:`29212` by :user:`Edoardo Abati <EdAbati>`;
3737
- :func:`sklearn.metrics.mean_absolute_error` :pr:`27736` by :user:`Edoardo Abati <EdAbati>`;
38+
- :func:`sklearn.metrics.mean_gamma_deviance` :pr:`29239` by :usser:`Emily Chen <EmilyXinyi>`;
3839
- :func:`sklearn.metrics.mean_squared_error` :pr:`29142` by :user:`Yaroslav Korobko <Tialo>`;
3940
- :func:`sklearn.metrics.mean_tweedie_deviance` :pr:`28106` by :user:`Thomas Li <lithomas1>`;
4041
- :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati <EdAbati>`.

sklearn/metrics/_regression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1412,7 +1412,7 @@ def mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0):
14121412
raise ValueError(message + "non-negative y and strictly positive y_pred.")
14131413
elif power >= 2:
14141414
# Gamma and Extreme stable distribution, y and y_pred > 0
1415-
if (y_true <= 0).any() or (y_pred <= 0).any():
1415+
if xp.any(y_true <= 0) or xp.any(y_pred <= 0):
14161416
raise ValueError(message + "strictly positive y and y_pred.")
14171417
else: # pragma: nocover
14181418
# Unreachable statement

sklearn/metrics/tests/test_common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1827,7 +1827,7 @@ def check_array_api_multiclass_classification_metric(
18271827

18281828

18291829
def check_array_api_regression_metric(metric, array_namespace, device, dtype_name):
1830-
y_true_np = np.array([2.0, 0.0, 1.0, 4.0], dtype=dtype_name)
1830+
y_true_np = np.array([2.0, 0.1, 1.0, 4.0], dtype=dtype_name)
18311831
y_pred_np = np.array([0.5, 0.5, 2, 2], dtype=dtype_name)
18321832

18331833
metric_kwargs = {}
@@ -1955,6 +1955,7 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
19551955
check_array_api_regression_metric,
19561956
],
19571957
paired_cosine_distances: [check_array_api_metric_pairwise],
1958+
mean_gamma_deviance: [check_array_api_regression_metric],
19581959
max_error: [check_array_api_regression_metric],
19591960
}
19601961

0 commit comments

Comments
 (0)