Skip to content

Commit 9cabb12

Browse files
MAINT Parameters validation for sklearn.manifold.trustworthiness (scikit-learn#26276)
Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr>
1 parent 6f2cf7c commit 9cabb12

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

sklearn/manifold/_t_sne.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from ..neighbors import NearestNeighbors
2828
from ..utils import check_random_state
2929
from ..utils._openmp_helpers import _openmp_effective_n_threads
30-
from ..utils._param_validation import Interval, StrOptions
31-
from ..utils.validation import check_non_negative
30+
from ..utils._param_validation import Interval, StrOptions, validate_params
31+
from ..utils.validation import _num_samples, check_non_negative
3232

3333
# mypy error: Module 'sklearn.manifold' has no attribute '_utils'
3434
# mypy error: Module 'sklearn.manifold' has no attribute '_barnes_hut_tsne'
@@ -446,6 +446,15 @@ def _gradient_descent(
446446
return p, error, i
447447

448448

449+
@validate_params(
450+
{
451+
"X": ["array-like", "sparse matrix"],
452+
"X_embedded": ["array-like", "sparse matrix"],
453+
"n_neighbors": [Interval(Integral, 1, None, closed="left")],
454+
"metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable],
455+
},
456+
prefer_skip_nested_validation=True,
457+
)
449458
def trustworthiness(X, X_embedded, *, n_neighbors=5, metric="euclidean"):
450459
r"""Indicate to what extent the local structure is retained.
451460
@@ -504,7 +513,7 @@ def trustworthiness(X, X_embedded, *, n_neighbors=5, metric="euclidean"):
504513
Local Structure. Proceedings of the Twelfth International Conference on
505514
Artificial Intelligence and Statistics, PMLR 5:384-391, 2009.
506515
"""
507-
n_samples = X.shape[0]
516+
n_samples = _num_samples(X)
508517
if n_neighbors >= n_samples / 2:
509518
raise ValueError(
510519
f"n_neighbors ({n_neighbors}) should be less than n_samples / 2"

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ def _check_function_param_validation(
202202
"sklearn.linear_model.orthogonal_mp",
203203
"sklearn.linear_model.orthogonal_mp_gram",
204204
"sklearn.linear_model.ridge_regression",
205+
"sklearn.manifold.trustworthiness",
205206
"sklearn.metrics.accuracy_score",
206207
"sklearn.manifold.smacof",
207208
"sklearn.metrics.auc",

0 commit comments

Comments
 (0)