Skip to content

Commit dc7f01d

Browse files
Mec-iSmorenol
authored and committed
Implement fastpair (#142)
* initial fastpair implementation * FastPair initial implementation * implement fastpair * Add random test * Add bench for fastpair * Refactor with constructor for FastPair * Add serialization for PairwiseDistance * Add fp_bench feature for fastpair bench
1 parent eb4b49d commit dc7f01d

File tree

5 files changed

+669
-0
lines changed

5 files changed

+669
-0
lines changed

Cargo.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ default = ["datasets"]
1717
ndarray-bindings = ["ndarray"]
1818
nalgebra-bindings = ["nalgebra"]
1919
datasets = []
20+
fp_bench = []
2021

2122
[dependencies]
2223
ndarray = { version = "0.15", optional = true }
@@ -26,6 +27,7 @@ num = "0.4"
2627
rand = "0.8"
2728
rand_distr = "0.4"
2829
serde = { version = "1", features = ["derive"], optional = true }
30+
itertools = "0.10.3"
2931

3032
[target.'cfg(target_arch = "wasm32")'.dependencies]
3133
getrandom = { version = "0.2", features = ["js"] }
@@ -46,3 +48,8 @@ harness = false
4648
name = "naive_bayes"
4749
harness = false
4850
required-features = ["ndarray-bindings", "nalgebra-bindings"]
51+
52+
[[bench]]
53+
name = "fastpair"
54+
harness = false
55+
required-features = ["fp_bench"]

benches/fastpair.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
2+
3+
// to run this bench you have to change the declaration in mod.rs ---> pub mod fastpair;
4+
use smartcore::algorithm::neighbour::fastpair::FastPair;
5+
use smartcore::linalg::naive::dense_matrix::*;
6+
use std::time::Duration;
7+
8+
/// Benchmark payload for the FastPair closest-pair search.
///
/// Generates a random matrix via `DenseMatrix::rand(n, m)`, builds a
/// `FastPair` over it and calls `closest_pair()`, discarding the result.
///
/// NOTE(review): matrix generation and `FastPair` construction happen inside
/// this function, so they are included in whatever the caller measures.
///
/// # Panics
///
/// Panics if `FastPair::new` returns an error.
fn closest_pair_bench(n: usize, m: usize) {
    let x = DenseMatrix::<f64>::rand(n, m);
    let fastpair =
        FastPair::new(&x).expect("FastPair::new failed on a randomly generated matrix");

    fastpair.closest_pair();
}
15+
16+
/// Benchmark payload for the brute-force closest-pair search.
///
/// Generates a random matrix via `DenseMatrix::rand(n, m)`, builds a
/// `FastPair` over it and calls `closest_pair_brute()`, discarding the result.
///
/// NOTE(review): matrix generation and `FastPair` construction happen inside
/// this function, so they are included in whatever the caller measures.
///
/// # Panics
///
/// Panics if `FastPair::new` returns an error.
fn closest_pair_brute_bench(n: usize, m: usize) {
    let x = DenseMatrix::<f64>::rand(n, m);
    let fastpair =
        FastPair::new(&x).expect("FastPair::new failed on a randomly generated matrix");

    fastpair.closest_pair_brute();
}
23+
24+
fn bench_fastpair(c: &mut Criterion) {
25+
let mut group = c.benchmark_group("FastPair");
26+
27+
// with full samples size (100) the test will take too long
28+
group.significance_level(0.1).sample_size(30);
29+
// increase from default 5.0 secs
30+
group.measurement_time(Duration::from_secs(60));
31+
32+
for n_samples in [100_usize, 1000_usize].iter() {
33+
for n_features in [10_usize, 100_usize, 1000_usize].iter() {
34+
group.bench_with_input(
35+
BenchmarkId::from_parameter(format!(
36+
"fastpair --- n_samples: {}, n_features: {}",
37+
n_samples, n_features
38+
)),
39+
n_samples,
40+
|b, _| b.iter(|| closest_pair_bench(*n_samples, *n_features)),
41+
);
42+
group.bench_with_input(
43+
BenchmarkId::from_parameter(format!(
44+
"brute --- n_samples: {}, n_features: {}",
45+
n_samples, n_features
46+
)),
47+
n_samples,
48+
|b, _| b.iter(|| closest_pair_brute_bench(*n_samples, *n_features)),
49+
);
50+
}
51+
}
52+
group.finish();
53+
}
54+
55+
// Declare the `benches` group containing `bench_fastpair`, then hand that
// group to `criterion_main!` (this bench target is built with the default
// test harness disabled, so the entry point presumably comes from here).
criterion_group!(benches, bench_fastpair);
56+
criterion_main!(benches);

src/algorithm/neighbour/distances.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//!
2+
//! Dissimilarities for vector-vector distance
3+
//!
4+
//! Representing distances as pairwise dissimilarities, so to build a
5+
//! graph of closest neighbours. This representation can be reused for
6+
//! different implementations (initially used in this library for FastPair).
7+
use std::cmp::{Eq, Ordering, PartialOrd};
8+
9+
#[cfg(feature = "serde")]
10+
use serde::{Deserialize, Serialize};
11+
12+
use crate::math::num::RealNumber;
13+
14+
///
15+
/// The edge of the subgraph is defined by `PairwiseDistance`.
16+
/// The calling algorithm can store a list of distsances as
17+
/// a list of these structures.
18+
///
19+
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
20+
#[derive(Debug, Clone, Copy)]
21+
pub struct PairwiseDistance<T: RealNumber> {
22+
/// index of the vector in the original `Matrix` or list
23+
pub node: usize,
24+
25+
/// index of the closest neighbor in the original `Matrix` or same list
26+
pub neighbour: Option<usize>,
27+
28+
/// measure of distance, according to the algorithm distance function
29+
/// if the distance is None, the edge has value "infinite" or max distance
30+
/// each algorithm has to match
31+
pub distance: Option<T>,
32+
}
33+
34+
impl<T: RealNumber> Eq for PairwiseDistance<T> {}
35+
36+
impl<T: RealNumber> PartialEq for PairwiseDistance<T> {
37+
fn eq(&self, other: &Self) -> bool {
38+
self.node == other.node
39+
&& self.neighbour == other.neighbour
40+
&& self.distance == other.distance
41+
}
42+
}
43+
44+
impl<T: RealNumber> PartialOrd for PairwiseDistance<T> {
45+
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
46+
self.distance.partial_cmp(&other.distance)
47+
}
48+
}

0 commit comments

Comments
 (0)