Skip to content

Commit ffcd361

Browse files
authored
FEA Add array api support for jaccard score (scikit-learn#31204)
1 parent 27f2af3 commit ffcd361

File tree

4 files changed

+13
-4
lines changed

4 files changed

+13
-4
lines changed

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ Metrics
139139
- :func:`sklearn.metrics.f1_score`
140140
- :func:`sklearn.metrics.fbeta_score`
141141
- :func:`sklearn.metrics.hamming_loss`
142+
- :func:`sklearn.metrics.jaccard_score`
142143
- :func:`sklearn.metrics.max_error`
143144
- :func:`sklearn.metrics.mean_absolute_error`
144145
- :func:`sklearn.metrics.mean_absolute_percentage_error`
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- :func:`sklearn.metrics.jaccard_score` now supports Array API compatible inputs.
2+
By :user:`Omar Salman <OmarManzoor>`

sklearn/metrics/_classification.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,9 +1071,10 @@ def jaccard_score(
10711071
numerator = MCM[:, 1, 1]
10721072
denominator = MCM[:, 1, 1] + MCM[:, 0, 1] + MCM[:, 1, 0]
10731073

1074+
xp, _, device_ = get_namespace_and_device(y_true, y_pred)
10741075
if average == "micro":
1075-
numerator = np.array([numerator.sum()])
1076-
denominator = np.array([denominator.sum()])
1076+
numerator = xp.asarray(xp.sum(numerator, keepdims=True), device=device_)
1077+
denominator = xp.asarray(xp.sum(denominator, keepdims=True), device=device_)
10771078

10781079
jaccard = _prf_divide(
10791080
numerator,
@@ -1088,14 +1089,14 @@ def jaccard_score(
10881089
return jaccard
10891090
if average == "weighted":
10901091
weights = MCM[:, 1, 0] + MCM[:, 1, 1]
1091-
if not np.any(weights):
1092+
if not xp.any(weights):
10921093
# numerator is 0, and warning should have already been issued
10931094
weights = None
10941095
elif average == "samples" and sample_weight is not None:
10951096
weights = sample_weight
10961097
else:
10971098
weights = None
1098-
return float(np.average(jaccard, weights=weights))
1099+
return float(_average(jaccard, weights=weights, xp=xp))
10991100

11001101

11011102
@validate_params(

sklearn/metrics/tests/test_common.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2147,6 +2147,11 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
21472147
check_array_api_multiclass_classification_metric,
21482148
check_array_api_multilabel_classification_metric,
21492149
],
2150+
jaccard_score: [
2151+
check_array_api_binary_classification_metric,
2152+
check_array_api_multiclass_classification_metric,
2153+
check_array_api_multilabel_classification_metric,
2154+
],
21502155
multilabel_confusion_matrix: [
21512156
check_array_api_binary_classification_metric,
21522157
check_array_api_multiclass_classification_metric,

0 commit comments

Comments
 (0)