Skip to content

Commit f0d6a9c

Browse files
MAINT Parameter validation for metrics.cluster._supervised (scikit-learn#26258)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 695ad73 commit f0d6a9c

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

sklearn/metrics/cluster/_supervised.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ def homogeneity_completeness_v_measure(labels_true, labels_pred, *, beta=1.0):
476476
477477
Parameters
478478
----------
479-
labels_true : int array, shape = [n_samples]
479+
labels_true : array-like of shape (n_samples,)
480480
Ground truth class labels to be used as a reference.
481481
482482
labels_pred : array-like of shape (n_samples,)
@@ -532,6 +532,12 @@ def homogeneity_completeness_v_measure(labels_true, labels_pred, *, beta=1.0):
532532
return homogeneity, completeness, v_measure_score
533533

534534

535+
@validate_params(
536+
{
537+
"labels_true": ["array-like"],
538+
"labels_pred": ["array-like"],
539+
}
540+
)
535541
def homogeneity_score(labels_true, labels_pred):
536542
"""Homogeneity metric of a cluster labeling given a ground truth.
537543
@@ -550,7 +556,7 @@ def homogeneity_score(labels_true, labels_pred):
550556
551557
Parameters
552558
----------
553-
labels_true : int array, shape = [n_samples]
559+
labels_true : array-like of shape (n_samples,)
554560
Ground truth class labels to be used as a reference.
555561
556562
labels_pred : array-like of shape (n_samples,)
@@ -601,6 +607,12 @@ def homogeneity_score(labels_true, labels_pred):
601607
return homogeneity_completeness_v_measure(labels_true, labels_pred)[0]
602608

603609

610+
@validate_params(
611+
{
612+
"labels_true": ["array-like"],
613+
"labels_pred": ["array-like"],
614+
}
615+
)
604616
def completeness_score(labels_true, labels_pred):
605617
"""Compute completeness metric of a cluster labeling given a ground truth.
606618
@@ -619,7 +631,7 @@ def completeness_score(labels_true, labels_pred):
619631
620632
Parameters
621633
----------
622-
labels_true : int array, shape = [n_samples]
634+
labels_true : array-like of shape (n_samples,)
623635
Ground truth class labels to be used as a reference.
624636
625637
labels_pred : array-like of shape (n_samples,)
@@ -670,6 +682,13 @@ def completeness_score(labels_true, labels_pred):
670682
return homogeneity_completeness_v_measure(labels_true, labels_pred)[1]
671683

672684

685+
@validate_params(
686+
{
687+
"labels_true": ["array-like"],
688+
"labels_pred": ["array-like"],
689+
"beta": [Interval(Real, 0, None, closed="left")],
690+
}
691+
)
673692
def v_measure_score(labels_true, labels_pred, *, beta=1.0):
674693
"""V-measure cluster labeling given a ground truth.
675694
@@ -694,7 +713,7 @@ def v_measure_score(labels_true, labels_pred, *, beta=1.0):
694713
695714
Parameters
696715
----------
697-
labels_true : int array, shape = [n_samples]
716+
labels_true : array-like of shape (n_samples,)
698717
Ground truth class labels to be used as a reference.
699718
700719
labels_pred : array-like of shape (n_samples,)

sklearn/tests/test_public_functions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def _check_function_param_validation(
180180
"sklearn.metrics.balanced_accuracy_score",
181181
"sklearn.metrics.brier_score_loss",
182182
"sklearn.metrics.calinski_harabasz_score",
183+
"sklearn.metrics.completeness_score",
183184
"sklearn.metrics.class_likelihood_ratios",
184185
"sklearn.metrics.classification_report",
185186
"sklearn.metrics.cluster.adjusted_mutual_info_score",
@@ -205,6 +206,7 @@ def _check_function_param_validation(
205206
"sklearn.metrics.get_scorer",
206207
"sklearn.metrics.hamming_loss",
207208
"sklearn.metrics.hinge_loss",
209+
"sklearn.metrics.homogeneity_score",
208210
"sklearn.metrics.jaccard_score",
209211
"sklearn.metrics.label_ranking_average_precision_score",
210212
"sklearn.metrics.label_ranking_loss",
@@ -248,6 +250,7 @@ def _check_function_param_validation(
248250
"sklearn.metrics.roc_auc_score",
249251
"sklearn.metrics.roc_curve",
250252
"sklearn.metrics.top_k_accuracy_score",
253+
"sklearn.metrics.v_measure_score",
251254
"sklearn.metrics.zero_one_loss",
252255
"sklearn.model_selection.cross_validate",
253256
"sklearn.model_selection.permutation_test_score",

0 commit comments

Comments (0)