Skip to content

Commit 4be1089

Browse files
MAINT Parameters validation for sklearn.metrics.pairwise.pairwise_distances_argmin_min (scikit-learn#26123)
Co-authored-by: jeremiedbb <jeremiedbb@yahoo.fr>
1 parent d3e11ab commit 4be1089

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

sklearn/metrics/pairwise.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import itertools
1111
import warnings
1212
from functools import partial
13+
from numbers import Integral, Real
1314

1415
import numpy as np
1516
from joblib import effective_n_jobs
@@ -29,11 +30,9 @@
2930
from ..utils._mask import _get_mask
3031
from ..utils._param_validation import (
3132
Hidden,
32-
Integral,
3333
Interval,
3434
MissingValues,
3535
Options,
36-
Real,
3736
StrOptions,
3837
validate_params,
3938
)
@@ -658,6 +657,19 @@ def _argmin_reduce(dist, start):
658657
_NAN_METRICS = ["nan_euclidean"]
659658

660659

660+
@validate_params(
661+
{
662+
"X": ["array-like", "sparse matrix"],
663+
"Y": ["array-like", "sparse matrix"],
664+
"axis": [Options(Integral, {0, 1})],
665+
"metric": [
666+
StrOptions(set(_VALID_METRICS).union(ArgKmin.valid_metrics())),
667+
callable,
668+
],
669+
"metric_kwargs": [dict, None],
670+
},
671+
prefer_skip_nested_validation=False, # metric is not validated yet
672+
)
661673
def pairwise_distances_argmin_min(
662674
X, Y, *, axis=1, metric="euclidean", metric_kwargs=None
663675
):

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ def _check_function_param_validation(
247247
"sklearn.metrics.pairwise.paired_distances",
248248
"sklearn.metrics.pairwise.paired_euclidean_distances",
249249
"sklearn.metrics.pairwise.paired_manhattan_distances",
250+
"sklearn.metrics.pairwise.pairwise_distances_argmin_min",
250251
"sklearn.metrics.pairwise.pairwise_kernels",
251252
"sklearn.metrics.pairwise.polynomial_kernel",
252253
"sklearn.metrics.pairwise.rbf_kernel",

0 commit comments

Comments
 (0)