Skip to content

Commit 76d9be2

Browse files
MAINT Parameters validation for sklearn.metrics.pairwise_distances_chunked (scikit-learn#26125)
Co-authored-by: jeremiedbb <jeremiedbb@yahoo.fr>
1 parent 96e13f1 commit 76d9be2

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

sklearn/metrics/pairwise.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1883,6 +1883,17 @@ def _precompute_metric_params(X, Y, metric=None, **kwds):
18831883
return {}
18841884

18851885

1886+
@validate_params(
1887+
{
1888+
"X": ["array-like", "sparse matrix"],
1889+
"Y": ["array-like", "sparse matrix", None],
1890+
"reduce_func": [callable, None],
1891+
"metric": [StrOptions({"precomputed"}.union(_VALID_METRICS)), callable],
1892+
"n_jobs": [Integral, None],
1893+
"working_memory": [Interval(Real, 0, None, closed="left"), None],
1894+
},
1895+
prefer_skip_nested_validation=False, # metric is not validated yet
1896+
)
18861897
def pairwise_distances_chunked(
18871898
X,
18881899
Y=None,
@@ -1903,13 +1914,13 @@ def pairwise_distances_chunked(
19031914
19041915
Parameters
19051916
----------
1906-
X : ndarray of shape (n_samples_X, n_samples_X) or \
1917+
X : {array-like, sparse matrix} of shape (n_samples_X, n_samples_X) or \
19071918
(n_samples_X, n_features)
19081919
Array of pairwise distances between samples, or a feature array.
19091920
The shape of the array should be (n_samples_X, n_samples_X) if
19101921
metric='precomputed' and (n_samples_X, n_features) otherwise.
19111922
1912-
Y : ndarray of shape (n_samples_Y, n_features), default=None
1923+
Y : {array-like, sparse matrix} of shape (n_samples_Y, n_features), default=None
19131924
An optional second feature array. Only allowed if
19141925
metric != "precomputed".
19151926
@@ -1946,7 +1957,7 @@ def pairwise_distances_chunked(
19461957
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
19471958
for more details.
19481959
1949-
working_memory : int, default=None
1960+
working_memory : float, default=None
19501961
The sought maximum memory for temporary distance matrix chunks.
19511962
When None (default), the value of
19521963
``sklearn.get_config()['working_memory']`` is used.

sklearn/metrics/tests/test_pairwise.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -748,9 +748,6 @@ def test_pairwise_distances_chunked(global_dtype):
748748
# "cityblock" uses scikit-learn metric, cityblock (function) is
749749
# scipy.spatial.
750750
check_pairwise_distances_chunked(X, Y, working_memory=1, metric="cityblock")
751-
# Test that a value error is raised if the metric is unknown
752-
with pytest.raises(ValueError):
753-
next(pairwise_distances_chunked(X, Y, metric="blah"))
754751

755752
# Test precomputed returns all at once
756753
D = pairwise_distances(X)

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def _check_function_param_validation(
254254
"sklearn.metrics.pairwise.sigmoid_kernel",
255255
"sklearn.metrics.pairwise_distances",
256256
"sklearn.metrics.pairwise_distances_argmin",
257+
"sklearn.metrics.pairwise_distances_chunked",
257258
"sklearn.metrics.precision_recall_curve",
258259
"sklearn.metrics.precision_recall_fscore_support",
259260
"sklearn.metrics.precision_score",

0 commit comments

Comments
 (0)