Skip to content

Commit 02dce94

Browse files
MAINT parameter validation for metrics.pairwise.pairwise_kernels (scikit-learn#26665)
1 parent 50f17a3 commit 02dce94

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

sklearn/metrics/pairwise.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2279,6 +2279,19 @@ def kernel_metrics():
22792279
}
22802280

22812281

2282+
@validate_params(
2283+
{
2284+
"X": ["array-like", "sparse matrix"],
2285+
"Y": ["array-like", "sparse matrix", None],
2286+
"metric": [
2287+
StrOptions(set(PAIRWISE_KERNEL_FUNCTIONS) | {"precomputed"}),
2288+
callable,
2289+
],
2290+
"filter_params": ["boolean"],
2291+
"n_jobs": [Integral, None],
2292+
},
2293+
prefer_skip_nested_validation=True,
2294+
)
22822295
def pairwise_kernels(
22832296
X, Y=None, metric="linear", *, filter_params=False, n_jobs=None, **kwds
22842297
):
@@ -2303,18 +2316,19 @@ def pairwise_kernels(
23032316
23042317
Parameters
23052318
----------
2306-
X : ndarray of shape (n_samples_X, n_samples_X) or (n_samples_X, n_features)
2319+
X : {array-like, sparse matrix} of shape (n_samples_X, n_samples_X) or \
2320+
(n_samples_X, n_features)
23072321
Array of pairwise kernels between samples, or a feature array.
23082322
The shape of the array should be (n_samples_X, n_samples_X) if
23092323
metric == "precomputed" and (n_samples_X, n_features) otherwise.
23102324
2311-
Y : ndarray of shape (n_samples_Y, n_features), default=None
2325+
Y : {array-like, sparse matrix} of shape (n_samples_Y, n_features), default=None
23122326
A second feature array only if X has shape (n_samples_X, n_features).
23132327
23142328
metric : str or callable, default="linear"
23152329
The metric to use when calculating kernel between instances in a
23162330
feature array. If metric is a string, it must be one of the metrics
2317-
in pairwise.PAIRWISE_KERNEL_FUNCTIONS.
2331+
in ``pairwise.PAIRWISE_KERNEL_FUNCTIONS``.
23182332
If metric is "precomputed", X is assumed to be a kernel matrix.
23192333
Alternatively, if metric is a callable function, it is called on each
23202334
pair of instances (rows) and the resulting value recorded. The callable
@@ -2365,7 +2379,5 @@ def pairwise_kernels(
23652379
func = PAIRWISE_KERNEL_FUNCTIONS[metric]
23662380
elif callable(metric):
23672381
func = partial(_pairwise_callable, metric=metric, **kwds)
2368-
else:
2369-
raise ValueError("Unknown kernel %r" % metric)
23702382

23712383
return _parallel_pairwise(X, Y, func, n_jobs, **kwds)

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ def _check_function_param_validation(
246246
"sklearn.metrics.pairwise.paired_distances",
247247
"sklearn.metrics.pairwise.paired_euclidean_distances",
248248
"sklearn.metrics.pairwise.paired_manhattan_distances",
249+
"sklearn.metrics.pairwise.pairwise_kernels",
249250
"sklearn.metrics.pairwise.polynomial_kernel",
250251
"sklearn.metrics.pairwise.rbf_kernel",
251252
"sklearn.metrics.pairwise.sigmoid_kernel",

0 commit comments

Comments
 (0)