|
27 | 27 | from ..neighbors import NearestNeighbors
|
28 | 28 | from ..utils import check_random_state
|
29 | 29 | from ..utils._openmp_helpers import _openmp_effective_n_threads
|
30 |
| -from ..utils._param_validation import Interval, StrOptions |
31 |
| -from ..utils.validation import check_non_negative |
| 30 | +from ..utils._param_validation import Interval, StrOptions, validate_params |
| 31 | +from ..utils.validation import _num_samples, check_non_negative |
32 | 32 |
|
33 | 33 | # mypy error: Module 'sklearn.manifold' has no attribute '_utils'
|
34 | 34 | # mypy error: Module 'sklearn.manifold' has no attribute '_barnes_hut_tsne'
|
@@ -446,6 +446,15 @@ def _gradient_descent(
|
446 | 446 | return p, error, i
|
447 | 447 |
|
448 | 448 |
|
| 449 | +@validate_params( |
| 450 | + { |
| 451 | + "X": ["array-like", "sparse matrix"], |
| 452 | + "X_embedded": ["array-like", "sparse matrix"], |
| 453 | + "n_neighbors": [Interval(Integral, 1, None, closed="left")], |
| 454 | + "metric": [StrOptions(set(_VALID_METRICS) | {"precomputed"}), callable], |
| 455 | + }, |
| 456 | + prefer_skip_nested_validation=True, |
| 457 | +) |
449 | 458 | def trustworthiness(X, X_embedded, *, n_neighbors=5, metric="euclidean"):
|
450 | 459 | r"""Indicate to what extent the local structure is retained.
|
451 | 460 |
|
@@ -504,7 +513,7 @@ def trustworthiness(X, X_embedded, *, n_neighbors=5, metric="euclidean"):
|
504 | 513 | Local Structure. Proceedings of the Twelfth International Conference on
|
505 | 514 | Artificial Intelligence and Statistics, PMLR 5:384-391, 2009.
|
506 | 515 | """
|
507 |
| - n_samples = X.shape[0] |
| 516 | + n_samples = _num_samples(X) |
508 | 517 | if n_neighbors >= n_samples / 2:
|
509 | 518 | raise ValueError(
|
510 | 519 | f"n_neighbors ({n_neighbors}) should be less than n_samples / 2"
|
|
0 commit comments