Skip to content

Commit 15d5a06

Browse files
authored
TST Improve tests for neighbor models with X=None (scikit-learn#30101)
1 parent 6eb2ef3 commit 15d5a06

File tree

1 file changed

+50
-12
lines changed

1 file changed

+50
-12
lines changed

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2401,35 +2401,73 @@ def _weights(dist):
24012401
"nn_model",
24022402
[
24032403
neighbors.KNeighborsClassifier(n_neighbors=10),
2404-
neighbors.RadiusNeighborsClassifier(radius=5.0),
2404+
neighbors.RadiusNeighborsClassifier(),
24052405
],
24062406
)
2407-
def test_neighbor_classifiers_loocv(nn_model):
2408-
"""Check that `predict` and related functions work fine with X=None"""
2409-
X, y = datasets.make_blobs(n_samples=500, centers=5, n_features=2, random_state=0)
2407+
@pytest.mark.parametrize("algorithm", ALGORITHMS)
2408+
def test_neighbor_classifiers_loocv(nn_model, algorithm):
2409+
"""Check that `predict` and related functions work fine with X=None
2410+
2411+
Calling predict with X=None computes a prediction for each training point
2412+
from the labels of its neighbors (without the label of the data point being
2413+
predicted upon). This is therefore mathematically equivalent to
2414+
leave-one-out cross-validation without having to do any retraining (rebuilding
2415+
a KD-tree or Ball-tree index) or any data reshuffling.
2416+
"""
2417+
X, y = datasets.make_blobs(n_samples=15, centers=5, n_features=2, random_state=0)
2418+
2419+
nn_model = clone(nn_model).set_params(algorithm=algorithm)
2420+
2421+
# Set the radius for RadiusNeighborsClassifier to some percentile of the
2422+
# empirical pairwise distances to avoid trivial test cases and warnings for
2423+
# predictions with no neighbors within the radius.
2424+
if "radius" in nn_model.get_params():
2425+
dists = pairwise_distances(X).ravel()
2426+
dists = dists[dists > 0]
2427+
nn_model.set_params(radius=np.percentile(dists, 80))
24102428

24112429
loocv = cross_val_score(nn_model, X, y, cv=LeaveOneOut())
24122430
nn_model.fit(X, y)
24132431

2414-
assert np.all(loocv == (nn_model.predict(None) == y))
2415-
assert np.mean(loocv) == nn_model.score(None, y)
2432+
assert_allclose(loocv, nn_model.predict(None) == y)
2433+
assert np.mean(loocv) == pytest.approx(nn_model.score(None, y))
2434+
2435+
# Evaluating `nn_model` on its "training" set should lead to a higher
2436+
# accuracy value than leaving out each data point in turn because the
2437+
# former can overfit while the latter cannot by construction.
24162438
assert nn_model.score(None, y) < nn_model.score(X, y)
24172439

24182440

24192441
@pytest.mark.parametrize(
24202442
"nn_model",
24212443
[
24222444
neighbors.KNeighborsRegressor(n_neighbors=10),
2423-
neighbors.RadiusNeighborsRegressor(radius=0.5),
2445+
neighbors.RadiusNeighborsRegressor(),
24242446
],
24252447
)
2426-
def test_neighbor_regressors_loocv(nn_model):
2448+
@pytest.mark.parametrize("algorithm", ALGORITHMS)
2449+
def test_neighbor_regressors_loocv(nn_model, algorithm):
24272450
"""Check that `predict` and related functions work fine with X=None"""
2428-
X, y = datasets.load_diabetes(return_X_y=True)
2451+
X, y = datasets.make_regression(n_samples=15, n_features=2, random_state=0)
24292452

24302453
# Only checking cross_val_predict and not cross_val_score because
2431-
# cross_val_score does not work with LeaveOneOut() for a regressor
2454+
# cross_val_score does not work with LeaveOneOut() for a regressor: the
2455+
# default score method implements R2 score which is not well defined for a
2456+
# single data point.
2457+
#
2458+
# TODO: if score is refactored to evaluate models for other scoring
2459+
# functions, then this test can be extended to check cross_val_score as
2460+
# well.
2461+
nn_model = clone(nn_model).set_params(algorithm=algorithm)
2462+
2463+
# Set the radius for RadiusNeighborsRegressor to some percentile of the
2464+
# empirical pairwise distances to avoid trivial test cases and warnings for
2465+
# predictions with no neighbors within the radius.
2466+
if "radius" in nn_model.get_params():
2467+
dists = pairwise_distances(X).ravel()
2468+
dists = dists[dists > 0]
2469+
nn_model.set_params(radius=np.percentile(dists, 80))
2470+
24322471
loocv = cross_val_predict(nn_model, X, y, cv=LeaveOneOut())
24332472
nn_model.fit(X, y)
2434-
2435-
assert np.all(loocv == nn_model.predict(None))
2473+
assert_allclose(loocv, nn_model.predict(None))

0 commit comments

Comments
 (0)