Skip to content

Commit bb96354

Browse files
Merge pull request #7 from vadimzaliva/development
+ DBSCAN and data generator. Improves KNN API
2 parents 6602de0 + c43990e commit bb96354

File tree

11 files changed

+556
-53
lines changed

11 files changed

+556
-53
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ nalgebra = { version = "0.22.0", optional = true }
2424
num-traits = "0.2.12"
2525
num = "0.3.0"
2626
rand = "0.7.3"
27+
rand_distr = "0.3.0"
2728
serde = { version = "1.0.115", features = ["derive"] }
2829
serde_derive = "1.0.115"
2930

src/algorithm/neighbour/cover_tree.rs

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
100100
/// Find k nearest neighbors of `p`
101101
/// * `p` - look for k nearest points to `p`
102102
/// * `k` - the number of nearest neighbors to return
103-
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
103+
pub fn find(&self, p: &T, k: usize) -> Result<Vec<(usize, F, &T)>, Failed> {
104104
if k <= 0 {
105105
return Err(Failed::because(FailedError::FindFailed, "k should be > 0"));
106106
}
@@ -164,20 +164,74 @@ impl<T: Debug + PartialEq, F: RealNumber, D: Distance<T, F>> CoverTree<T, F, D>
164164
current_cover_set = next_cover_set;
165165
}
166166

167-
let mut neighbors: Vec<(usize, F)> = Vec::new();
167+
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
168168
let upper_bound = *heap.peek();
169169
for ds in zero_set {
170170
if ds.0 <= upper_bound {
171171
let v = self.get_data_value(ds.1.idx);
172172
if !self.identical_excluded || v != p {
173-
neighbors.push((ds.1.idx, ds.0));
173+
neighbors.push((ds.1.idx, ds.0, &v));
174174
}
175175
}
176176
}
177177

178178
Ok(neighbors.into_iter().take(k).collect())
179179
}
180180

181+
/// Find all nearest neighbors within radius `radius` from `p`
182+
/// * `p` - look for k nearest points to `p`
183+
/// * `radius` - radius of the search
184+
pub fn find_radius(&self, p: &T, radius: F) -> Result<Vec<(usize, F, &T)>, Failed> {
185+
if radius <= F::zero() {
186+
return Err(Failed::because(
187+
FailedError::FindFailed,
188+
"radius should be > 0",
189+
));
190+
}
191+
192+
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
193+
194+
let mut current_cover_set: Vec<(F, &Node<F>)> = Vec::new();
195+
let mut zero_set: Vec<(F, &Node<F>)> = Vec::new();
196+
197+
let e = self.get_data_value(self.root.idx);
198+
let mut d = self.distance.distance(&e, p);
199+
current_cover_set.push((d, &self.root));
200+
201+
while !current_cover_set.is_empty() {
202+
let mut next_cover_set: Vec<(F, &Node<F>)> = Vec::new();
203+
for par in current_cover_set {
204+
let parent = par.1;
205+
for c in 0..parent.children.len() {
206+
let child = &parent.children[c];
207+
if c == 0 {
208+
d = par.0;
209+
} else {
210+
d = self.distance.distance(self.get_data_value(child.idx), p);
211+
}
212+
213+
if d <= radius + child.max_dist {
214+
if !child.children.is_empty() {
215+
next_cover_set.push((d, child));
216+
} else if d <= radius {
217+
zero_set.push((d, child));
218+
}
219+
}
220+
}
221+
}
222+
current_cover_set = next_cover_set;
223+
}
224+
225+
for ds in zero_set {
226+
let v = self.get_data_value(ds.1.idx);
227+
if !self.identical_excluded || v != p {
228+
neighbors.push((ds.1.idx, ds.0, &v));
229+
}
230+
}
231+
232+
Ok(neighbors)
233+
}
234+
181235
fn new_leaf(&self, idx: usize) -> Node<F> {
182236
Node {
183237
idx: idx,
@@ -417,6 +471,11 @@ mod tests {
417471
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
418472
let knn: Vec<usize> = knn.iter().map(|v| v.0).collect();
419473
assert_eq!(vec!(3, 4, 5), knn);
474+
475+
let mut knn = tree.find_radius(&5, 2.0).unwrap();
476+
knn.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
477+
let knn: Vec<i32> = knn.iter().map(|v| *v.2).collect();
478+
assert_eq!(vec!(3, 4, 5, 6, 7), knn);
420479
}
421480

422481
#[test]

src/algorithm/neighbour/linear_search.rs

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use std::cmp::{Ordering, PartialOrd};
2626
use std::marker::PhantomData;
2727

2828
use crate::algorithm::sort::heap_select::HeapSelection;
29-
use crate::error::Failed;
29+
use crate::error::{Failed, FailedError};
3030
use crate::math::distance::Distance;
3131
use crate::math::num::RealNumber;
3232

@@ -53,9 +53,12 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
5353
/// Find k nearest neighbors
5454
/// * `from` - look for k nearest points to `from`
5555
/// * `k` - the number of nearest neighbors to return
56-
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F)>, Failed> {
56+
pub fn find(&self, from: &T, k: usize) -> Result<Vec<(usize, F, &T)>, Failed> {
5757
if k < 1 || k > self.data.len() {
58-
panic!("k should be >= 1 and <= length(data)");
58+
return Err(Failed::because(
59+
FailedError::FindFailed,
60+
"k should be >= 1 and <= length(data)",
61+
));
5962
}
6063

6164
let mut heap = HeapSelection::<KNNPoint<F>>::with_capacity(k);
@@ -80,9 +83,33 @@ impl<T, F: RealNumber, D: Distance<T, F>> LinearKNNSearch<T, F, D> {
8083
Ok(heap
8184
.get()
8285
.into_iter()
83-
.flat_map(|x| x.index.map(|i| (i, x.distance)))
86+
.flat_map(|x| x.index.map(|i| (i, x.distance, &self.data[i])))
8487
.collect())
8588
}
89+
90+
/// Find all nearest neighbors within radius `radius` from `p`
91+
/// * `p` - look for k nearest points to `p`
92+
/// * `radius` - radius of the search
93+
pub fn find_radius(&self, from: &T, radius: F) -> Result<Vec<(usize, F, &T)>, Failed> {
94+
if radius <= F::zero() {
95+
return Err(Failed::because(
96+
FailedError::FindFailed,
97+
"radius should be > 0",
98+
));
99+
}
100+
101+
let mut neighbors: Vec<(usize, F, &T)> = Vec::new();
102+
103+
for i in 0..self.data.len() {
104+
let d = self.distance.distance(&from, &self.data[i]);
105+
106+
if d <= radius {
107+
neighbors.push((i, d, &self.data[i]));
108+
}
109+
}
110+
111+
Ok(neighbors)
112+
}
86113
}
87114

88115
#[derive(Debug)]
@@ -134,6 +161,16 @@ mod tests {
134161

135162
assert_eq!(vec!(0, 1, 2), found_idxs1);
136163

164+
let mut found_idxs1: Vec<i32> = algorithm1
165+
.find_radius(&5, 3.0)
166+
.unwrap()
167+
.iter()
168+
.map(|v| *v.2)
169+
.collect();
170+
found_idxs1.sort();
171+
172+
assert_eq!(vec!(2, 3, 4, 5, 6, 7, 8), found_idxs1);
173+
137174
let data2 = vec![
138175
vec![1., 1.],
139176
vec![2., 2.],

src/algorithm/neighbour/mod.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,68 @@
2929
//! <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
3030
//! <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
3131
32+
use crate::algorithm::neighbour::cover_tree::CoverTree;
33+
use crate::algorithm::neighbour::linear_search::LinearKNNSearch;
34+
use crate::error::Failed;
35+
use crate::math::distance::Distance;
36+
use crate::math::num::RealNumber;
37+
use serde::{Deserialize, Serialize};
38+
3239
pub(crate) mod bbd_tree;
3340
/// tree data structure for fast nearest neighbor search
3441
pub mod cover_tree;
3542
/// very simple algorithm that sequentially checks each element of the list until a match is found or the whole list has been searched.
3643
pub mod linear_search;
44+
45+
/// Both the KNN classifier and the KNN regressor benefit from underlying search algorithms that help to speed up queries.
46+
/// `KNNAlgorithmName` lists the supported search algorithms, see [KNN algorithms](../algorithm/neighbour/index.html)
47+
#[derive(Serialize, Deserialize, Debug, Clone)]
48+
pub enum KNNAlgorithmName {
49+
/// Heap Search algorithm, see [`LinearSearch`](../algorithm/neighbour/linear_search/index.html)
50+
LinearSearch,
51+
/// Cover Tree Search algorithm, see [`CoverTree`](../algorithm/neighbour/cover_tree/index.html)
52+
CoverTree,
53+
}
54+
55+
/// Search structure produced by `KNNAlgorithmName::fit`; one variant per supported algorithm.
#[derive(Serialize, Deserialize, Debug)]
56+
pub(crate) enum KNNAlgorithm<T: RealNumber, D: Distance<Vec<T>, T>> {
57+
/// Search backed by a `LinearKNNSearch` over `Vec<T>` points.
LinearSearch(LinearKNNSearch<Vec<T>, T, D>),
58+
/// Search backed by a `CoverTree` over `Vec<T>` points.
CoverTree(CoverTree<Vec<T>, T, D>),
59+
}
60+
61+
impl KNNAlgorithmName {
62+
pub(crate) fn fit<T: RealNumber, D: Distance<Vec<T>, T>>(
63+
&self,
64+
data: Vec<Vec<T>>,
65+
distance: D,
66+
) -> Result<KNNAlgorithm<T, D>, Failed> {
67+
match *self {
68+
KNNAlgorithmName::LinearSearch => {
69+
LinearKNNSearch::new(data, distance).map(|a| KNNAlgorithm::LinearSearch(a))
70+
}
71+
KNNAlgorithmName::CoverTree => {
72+
CoverTree::new(data, distance).map(|a| KNNAlgorithm::CoverTree(a))
73+
}
74+
}
75+
}
76+
}
77+
78+
impl<T: RealNumber, D: Distance<Vec<T>, T>> KNNAlgorithm<T, D> {
79+
pub fn find(&self, from: &Vec<T>, k: usize) -> Result<Vec<(usize, T, &Vec<T>)>, Failed> {
80+
match *self {
81+
KNNAlgorithm::LinearSearch(ref linear) => linear.find(from, k),
82+
KNNAlgorithm::CoverTree(ref cover) => cover.find(from, k),
83+
}
84+
}
85+
86+
pub fn find_radius(
87+
&self,
88+
from: &Vec<T>,
89+
radius: T,
90+
) -> Result<Vec<(usize, T, &Vec<T>)>, Failed> {
91+
match *self {
92+
KNNAlgorithm::LinearSearch(ref linear) => linear.find_radius(from, radius),
93+
KNNAlgorithm::CoverTree(ref cover) => cover.find_radius(from, radius),
94+
}
95+
}
96+
}

0 commit comments

Comments
 (0)