Skip to content

Commit a627a07

Browse files
saihttamjnothman
authored andcommitted
[MRG+1] Fix for test_neighbors_metrics() check kd-tree algorithm when possible (scikit-learn#7339)
* Fix for test_neighbors_metrics() check kd-tree algorithm when possible * Improving check in test_neighbors_metrics by using a dictionary for the results and checking for existence of key for kd_tree * Fix flake8 problem due to line too long * remove enumerate, since not needed when using results dict
1 parent f916449 commit a627a07

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -942,7 +942,7 @@ def test_neighbors_metrics(n_samples=20, n_features=3,
942942
test = rng.rand(n_query_pts, n_features)
943943

944944
for metric, metric_params in metrics:
945-
results = []
945+
results = {}
946946
p = metric_params.pop('p', 2)
947947
for algorithm in algorithms:
948948
# KD tree doesn't support all metrics
@@ -953,16 +953,19 @@ def test_neighbors_metrics(n_samples=20, n_features=3,
953953
algorithm=algorithm,
954954
metric=metric, metric_params=metric_params)
955955
continue
956-
957956
neigh = neighbors.NearestNeighbors(n_neighbors=n_neighbors,
958957
algorithm=algorithm,
959958
metric=metric, p=p,
960959
metric_params=metric_params)
961960
neigh.fit(X)
962-
results.append(neigh.kneighbors(test, return_distance=True))
963-
964-
assert_array_almost_equal(results[0][0], results[1][0])
965-
assert_array_almost_equal(results[0][1], results[1][1])
961+
results[algorithm] = neigh.kneighbors(test, return_distance=True)
962+
assert_array_almost_equal(results['brute'][0], results['ball_tree'][0])
963+
assert_array_almost_equal(results['brute'][1], results['ball_tree'][1])
964+
if 'kd_tree' in results:
965+
assert_array_almost_equal(results['brute'][0],
966+
results['kd_tree'][0])
967+
assert_array_almost_equal(results['brute'][1],
968+
results['kd_tree'][1])
966969

967970

968971
def test_callable_metric():

0 commit comments

Comments
 (0)