Skip to content

Commit 2c0b2b5

Browse files
authored
Merge pull request #10 from zhao-lang/remove-indexarc
Remove indexarc
2 parents 331ab0e + f577976 commit 2c0b2b5

File tree

7 files changed

+206
-166
lines changed

7 files changed

+206
-166
lines changed

cmd.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ redis-cli hnsw.search test1 5 ${data}
1717
for i in {1..100}
1818
do
1919
redis-cli hnsw.node.del test1 node${i-1}
20-
sleep 0.1
20+
# sleep 0.01
2121
done
2222

2323
redis-cli hnsw.del test1

src/hnsw/hnsw.rs renamed to src/hnsw/core.rs

Lines changed: 54 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use super::metrics;
2-
use crate::types;
32

43
use num::Float;
54
use ordered_float::OrderedFloat;
@@ -13,7 +12,14 @@ use std::fmt;
1312
use std::hash::{Hash, Hasher};
1413
use std::rc::Rc;
1514
use std::sync::{Arc, RwLock, Weak};
16-
use std::thread;
15+
// use std::thread;
16+
17+
struct SelectParams {
18+
m: usize,
19+
lc: usize,
20+
extend_candidates: bool,
21+
keep_pruned_connections: bool,
22+
}
1723

1824
#[derive(Debug)]
1925
pub enum HNSWError {
@@ -48,7 +54,7 @@ pub struct SearchResult<T: Float, R: Float> {
4854
impl<T: Float, R: Float> SearchResult<T, R> {
4955
fn new(sim: OrderedFloat<R>, name: &str, data: &[T]) -> Self {
5056
SearchResult {
51-
sim: sim,
57+
sim,
5258
name: name.to_owned(),
5359
data: data.to_vec(),
5460
}
@@ -108,7 +114,7 @@ where
108114
self.neighbors
109115
.iter()
110116
.map(|l| {
111-
l.into_iter()
117+
l.iter()
112118
.map(|n| n.upgrade().read().name.to_owned())
113119
.collect::<Vec<String>>()
114120
})
@@ -249,8 +255,8 @@ where
249255
{
250256
fn new(sim: OrderedFloat<R>, node: Node<T>) -> Self {
251257
let sp = _SimPair {
252-
sim: sim,
253-
node: node,
258+
sim,
259+
node,
254260
};
255261
SimPair(Rc::new(RefCell::new(sp)))
256262
}
@@ -311,7 +317,7 @@ pub struct Index<T: Float, R: Float> {
311317
pub layers: Vec<HashSet<NodeWeak<T>>>, // distinct nodes in each layer
312318
pub nodes: HashMap<String, Node<T>>, // hashmap of nodes
313319
pub enterpoint: Option<NodeWeak<T>>, // enterpoint node
314-
rng_: StdRng, // rng for level generation
320+
pub rng_: StdRng, // rng for level generation
315321
}
316322

317323
impl<T: Float, R: Float> Index<T, R> {
@@ -324,13 +330,13 @@ impl<T: Float, R: Float> Index<T, R> {
324330
) -> Self {
325331
Index {
326332
name: name.to_string(),
327-
mfunc: mfunc,
333+
mfunc,
328334
mfunc_kind: metrics::MetricFuncs::Euclidean,
329-
data_dim: data_dim,
330-
m: m,
335+
data_dim,
336+
m,
331337
m_max: m,
332338
m_max_0: m * 2,
333-
ef_construction: ef_construction,
339+
ef_construction,
334340
level_mult: 1.0 / (1.0 * m as f64).ln(),
335341
node_count: 0,
336342
max_layer: 0,
@@ -371,35 +377,6 @@ impl<T: Float, R: Float> fmt::Debug for Index<T, R> {
371377
}
372378
}
373379

374-
impl From<&types::IndexRedis> for Index<f32, f32> {
375-
fn from(index: &types::IndexRedis) -> Self {
376-
Index {
377-
name: index.name.clone(),
378-
mfunc: match index.mfunc_kind.as_str() {
379-
"Euclidean" => Box::new(metrics::euclidean),
380-
_ => Box::new(metrics::euclidean),
381-
},
382-
mfunc_kind: match index.mfunc_kind.as_str() {
383-
"Euclidean" => metrics::MetricFuncs::Euclidean,
384-
_ => metrics::MetricFuncs::Euclidean,
385-
},
386-
data_dim: index.data_dim,
387-
m: index.m,
388-
m_max: index.m_max,
389-
m_max_0: index.m_max_0,
390-
ef_construction: index.ef_construction,
391-
level_mult: index.level_mult,
392-
node_count: index.node_count,
393-
max_layer: index.max_layer,
394-
// the next 3 need to be populated from redis
395-
layers: Vec::new(),
396-
nodes: HashMap::new(),
397-
enterpoint: None,
398-
rng_: StdRng::from_entropy(),
399-
}
400-
}
401-
}
402-
403380
impl<T, R> Index<T, R>
404381
where
405382
T: Float + Send + Sync + 'static,
@@ -409,7 +386,7 @@ where
409386
&mut self,
410387
name: &str,
411388
data: &[T],
412-
update_fn: fn(String, Node<T>),
389+
update_fn: impl Fn(String, Node<T>),
413390
) -> Result<(), HNSWError> {
414391
if data.len() != self.data_dim {
415392
return Err(format!("data dimension: {} does not match Index", data.len()).into());
@@ -429,7 +406,7 @@ where
429406
return Ok(());
430407
}
431408

432-
if !self.nodes.get(name).is_none() {
409+
if self.nodes.get(name).is_some() {
433410
return Err(format!("Node: {:?} already exists", name).into());
434411
}
435412

@@ -439,7 +416,7 @@ where
439416
pub fn delete_node(
440417
&mut self,
441418
name: &str,
442-
update_fn: fn(String, Node<T>),
419+
update_fn: impl Fn(String, Node<T>),
443420
) -> Result<(), HNSWError> {
444421
let node = match self.nodes.remove(name) {
445422
Some(node) => node,
@@ -467,7 +444,7 @@ where
467444
for n in updated {
468445
let name = n.read().name.clone();
469446
let node = n.clone();
470-
let _ = thread::spawn(move || update_fn(name, node));
447+
update_fn(name, node);
471448
}
472449

473450
// update enterpoint if necessary
@@ -515,7 +492,7 @@ where
515492
&mut self,
516493
name: &str,
517494
data: &[T],
518-
update_fn: fn(String, Node<T>),
495+
update_fn: impl Fn(String, Node<T>),
519496
) -> Result<(), HNSWError> {
520497
let l = self.gen_random_level();
521498
let l_max = self.max_layer;
@@ -547,7 +524,13 @@ where
547524
let mut updated = HashSet::new();
548525
for lc in (0..(min(l_max, l) + 1)).rev() {
549526
w = self.search_level(data, &ep.upgrade(), self.ef_construction, lc);
550-
let mut neighbors = self.select_neighbors(query, &w, self.m, lc, true, true, None);
527+
let params = SelectParams{
528+
m: self.m,
529+
lc,
530+
extend_candidates: true,
531+
keep_pruned_connections: true
532+
};
533+
let mut neighbors = self.select_neighbors(query, &w, params, None);
551534
self.connect_neighbors(query, &neighbors, lc);
552535

553536
// add node to list of nodes to be updated in redis
@@ -578,8 +561,14 @@ where
578561

579562
let m_max = if lc == 0 { self.m_max_0 } else { self.m_max };
580563
if econn.len() > m_max {
564+
let params = SelectParams{
565+
m: m_max,
566+
lc,
567+
extend_candidates: true,
568+
keep_pruned_connections: true
569+
};
581570
let enewconn =
582-
self.select_neighbors(&er.node, &econn, m_max, lc, true, true, None);
571+
self.select_neighbors(&er.node, &econn, params, None);
583572
let up = self.update_node_connections(&er.node, &enewconn, &econn, lc, None);
584573
for u in up {
585574
updated.insert(u);
@@ -594,7 +583,7 @@ where
594583
for n in updated {
595584
let name = n.read().name.clone();
596585
let node = n.clone();
597-
let _ = thread::spawn(move || update_fn(name, node));
586+
update_fn(name, node);
598587
}
599588

600589
// new enterpoint if we're in a higher layer
@@ -692,18 +681,15 @@ where
692681
&self,
693682
query: &Node<T>,
694683
c: &BinaryHeap<SimPair<T, R>>,
695-
m: usize,
696-
lc: usize,
697-
extend_candidates: bool,
698-
keep_pruned_connections: bool,
684+
params: SelectParams,
699685
ignored_node: Option<&Node<T>>,
700686
) -> BinaryHeap<SimPair<T, R>> {
701-
let mut r: BinaryHeap<SimPair<T, R>> = BinaryHeap::with_capacity(m);
687+
let mut r: BinaryHeap<SimPair<T, R>> = BinaryHeap::with_capacity(params.m);
702688
let mut w = c.clone();
703689
let mut wd = BinaryHeap::new();
704690

705691
// extend candidates by their neighbors
706-
if extend_candidates {
692+
if params.extend_candidates {
707693
let mut ccopy = c.clone();
708694

709695
let mut v = HashSet::with_capacity(ccopy.capacity());
@@ -716,7 +702,7 @@ where
716702
while !ccopy.is_empty() {
717703
let epair = ccopy.pop().unwrap();
718704

719-
for eneighbor in &epair.read().node.read().neighbors[lc] {
705+
for eneighbor in &epair.read().node.read().neighbors[params.lc] {
720706
let eneighbor = eneighbor.upgrade();
721707
if eneighbor == *query
722708
|| (ignored_node.is_some() && eneighbor == *ignored_node.unwrap())
@@ -738,7 +724,7 @@ where
738724
}
739725
}
740726

741-
while !w.is_empty() && r.len() < m {
727+
while !w.is_empty() && r.len() < params.m {
742728
let epair = w.pop().unwrap();
743729
let enr = epair.read();
744730

@@ -755,8 +741,8 @@ where
755741
}
756742

757743
// add back some of the discarded connections
758-
if keep_pruned_connections {
759-
while !wd.is_empty() && r.len() < m {
744+
if params.keep_pruned_connections {
745+
while !wd.is_empty() && r.len() < params.m {
760746
let ppair = wd.pop().unwrap();
761747
{
762748
let pr = ppair.read();
@@ -813,11 +799,8 @@ where
813799
updated.insert(npr.node.clone());
814800
// if new neighbor exists in the old set then we remove it from
815801
// the set of neighbors to be removed
816-
match rmconn.iter().position(|n| n.read().node == npr.node) {
817-
Some(index) => {
818-
rmconn.remove(index);
819-
}
820-
None => (),
802+
if let Some(index) = rmconn.iter().position(|n| n.read().node == npr.node) {
803+
rmconn.remove(index);
821804
}
822805
}
823806

@@ -864,7 +847,13 @@ where
864847
}
865848

866849
let m_max = if lc == 0 { self.m_max_0 } else { self.m_max };
867-
nnewconn = self.select_neighbors(&n, &nconn, m_max, lc, true, true, Some(node));
850+
let params = SelectParams{
851+
m: m_max,
852+
lc,
853+
extend_candidates: true,
854+
keep_pruned_connections: true
855+
};
856+
nnewconn = self.select_neighbors(&n, &nconn, params, Some(node));
868857
}
869858
updated.insert(n.clone());
870859
let up = self.update_node_connections(&n, &nnewconn, &nconn, lc, Some(node));
@@ -896,7 +885,7 @@ where
896885
let cnr = cr.node.read();
897886
res.push(SearchResult::new(
898887
cr.sim,
899-
&((&cnr.name).split(".").collect::<Vec<&str>>())
888+
&((&cnr.name).split('.').collect::<Vec<&str>>())
900889
.last()
901890
.unwrap(),
902891
&cnr.data,

src/hnsw/hnsw_tests.rs renamed to src/hnsw/core_tests.rs

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1-
use crate::hnsw::hnsw::*;
1+
use crate::hnsw::core::*;
22
use crate::hnsw::metrics::euclidean;
33
use std::sync::Arc;
44
use std::{thread, time};
55

66
#[test]
77
fn hnsw_test() {
8+
let n = 100;
9+
let data_dim = 4;
10+
811
// index creation
9-
let mut index: Index<f32, f32> = Index::new("foo", Box::new(euclidean), 4, 5, 16);
12+
let mut index: Index<f32, f32> = Index::new("foo", Box::new(euclidean), data_dim, 5, 16);
1013
assert_eq!(&index.name, "foo");
11-
assert_eq!(index.data_dim, 4);
14+
assert_eq!(index.data_dim, data_dim);
1215
assert_eq!(index.m, 5);
1316
assert_eq!(index.ef_construction, 16);
1417
assert_eq!(index.node_count, 0);
@@ -18,15 +21,15 @@ fn hnsw_test() {
1821
let mock_fn = |_s: String, _n: Node<f32>| {};
1922

2023
// add node
21-
for i in 0..100 {
24+
for i in 0..n {
2225
let name = format!("node{}", i);
23-
let data = vec![i as f32; 4];
26+
let data = vec![i as f32; data_dim];
2427
index.add_node(&name, &data, mock_fn).unwrap();
2528
}
26-
// sleep for a brief period to make sure all threads are done
27-
let ten_millis = time::Duration::from_millis(10);
28-
thread::sleep(ten_millis);
29-
for i in 0..100 {
29+
// // sleep for a brief period to make sure all threads are done
30+
// let ten_millis = time::Duration::from_millis(10);
31+
// thread::sleep(ten_millis);
32+
for i in 0..n {
3033
let node_name = format!("node{}", i);
3134
let node = index.nodes.get(&node_name).unwrap();
3235
let sc = Arc::strong_count(&node.0);
@@ -35,7 +38,7 @@ fn hnsw_test() {
3538
}
3639
assert_eq!(sc, 1);
3740
}
38-
assert_eq!(index.node_count, 100);
41+
assert_eq!(index.node_count, n);
3942
assert_ne!(index.enterpoint, None);
4043

4144
// search
@@ -50,11 +53,11 @@ fn hnsw_test() {
5053
assert_eq!(res[4].sim.into_inner(), -16.0);
5154

5255
// delete node
53-
for i in 0..100 {
56+
for i in 0..n {
5457
let node_name = format!("node{}", i);
5558
let node = index.nodes.get(&node_name).unwrap().clone();
5659
index.delete_node(&node_name, mock_fn).unwrap();
57-
assert_eq!(index.node_count, 100 - i - 1);
60+
assert_eq!(index.node_count, n - i - 1);
5861
assert_eq!(index.nodes.get(&node_name).is_none(), true);
5962
for l in &index.layers {
6063
assert_eq!(l.contains(&node.downgrade()), false);
@@ -66,9 +69,9 @@ fn hnsw_test() {
6669
}
6770
}
6871
}
69-
// sleep for a brief period to make sure all threads are done
70-
let ten_millis = time::Duration::from_millis(10);
71-
thread::sleep(ten_millis);
72+
// // sleep for a brief period to make sure all threads are done
73+
// let ten_millis = time::Duration::from_millis(10);
74+
// thread::sleep(ten_millis);
7275
let sc = Arc::strong_count(&node.0);
7376
if sc > 1 {
7477
println!("Delete {:?}", node);

0 commit comments

Comments
 (0)