
Commit 89a769d

[ENH] Add parallel writes to local hnsw (#3866)
## Description of changes

- Improvements & Bug fixes
  - Uses rayon to write in batches to hnsw, mirroring the Python implementation
  - Miscellaneous cleanup
  - Disables OpenTelemetry in the config to cut log spam; we should eventually turn it back on
- New functionality
  - ...

## Test plan

Existing tests and manual verification.

- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes

None
1 parent 90e4ad3 commit 89a769d
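
The core pattern in this change (see rust/segment/src/local_hnsw.rs below) is to group log records per HNSW label on a single thread, then add them to the index from rayon's thread pool and surface the first error after the parallel pass. The following is a minimal, self-contained sketch of that shape; `ThreadSafeIndex` and `apply_batch` are hypothetical stand-ins for illustration, not the committed Chroma types.

```rust
use std::collections::HashMap;
use std::sync::Mutex;

use rayon::iter::{IntoParallelIterator, ParallelIterator};

// Hypothetical stand-in for an HNSW index whose `add` takes `&self`, so it can
// be called concurrently from rayon workers; the real index is not a Mutex'd map.
struct ThreadSafeIndex {
    points: Mutex<HashMap<usize, Vec<f32>>>,
}

impl ThreadSafeIndex {
    fn add(&self, label: usize, embedding: &[f32]) -> Result<(), String> {
        self.points.lock().unwrap().insert(label, embedding.to_vec());
        Ok(())
    }
}

// Phase 1 (single-threaded, done by the caller): group records per label.
// Phase 2 (parallel): write each group into the index from rayon's thread pool
// and surface the first error, if any.
fn apply_batch(
    index: &ThreadSafeIndex,
    batch: HashMap<u32, Vec<(u32, Vec<f32>)>>,
) -> Result<(), String> {
    batch
        .into_par_iter()
        .map(|(_, records)| {
            for (label, embedding) in records {
                index.add(label as usize, &embedding)?;
            }
            Ok(())
        })
        .find_any(|result: &Result<(), String>| result.is_err())
        .unwrap_or(Ok(()))
}

fn main() -> Result<(), String> {
    let index = ThreadSafeIndex {
        points: Mutex::new(HashMap::new()),
    };
    let mut batch = HashMap::new();
    batch.insert(0u32, vec![(0u32, vec![0.0, 1.0]), (1u32, vec![1.0, 0.0])]);
    apply_batch(&index, batch)?;
    assert_eq!(index.points.lock().unwrap().len(), 2);
    Ok(())
}
```

Because rayon's `map` runs on a thread pool, the index's `add` must be callable through a shared reference; the real HNSW index provides that, while the sketch approximates it with a `Mutex`.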

File tree

7 files changed: +93 additions, -56 deletions


Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 1 addition & 1 deletion
@@ -47,6 +47,7 @@ regex = "1.11.1"
 pyo3 = { version = "0.23.3", features = ["abi3-py39"] }
 tower-http = { version = "0.6.2", features = ["trace", "cors"] }
 bytemuck = "1.21.0"
+rayon = "1.10.0"
 validator = { version = "0.19", features = ["derive"] }
 rust-embed = { version = "8.5.0", features = ["include-exclude", "debug-embed"] }
 hnswlib = { version = "0.8.0", git = "https://github.com/chroma-core/hnswlib.git" }
@@ -81,7 +82,6 @@ proptest-state-machine = "0.3.0"
 proptest-derive = "0.5.1"
 rand = "0.8.5"
 rand_xorshift = "0.3.0"
-rayon = "1.10.0"
 shuttle = "0.7.1"
 tempfile = "3.14.0"
 itertools = "0.13.0"

chromadb/api/rust.py

Lines changed: 2 additions & 3 deletions
@@ -73,8 +73,6 @@ def __init__(self, system: System):
     def start(self) -> None:
         # Construct the SqliteConfig
         # TOOD: We should add a "config converter"
-        # TODO: How to name this file?
-        # TODO: proper path handling
         if self._system.settings.require("is_persistent"):
             persist_path = self._system.settings.require("persist_directory")
             sqlite_persist_path = persist_path + "/chroma.sqlite3"
@@ -500,7 +498,8 @@ def _delete(
             CollectionDeleteEvent(
                 # NOTE: the delete amount is not observable from python
                 # TODO: Fix this when posthog is pushed into Rust frontend
-                collection_uuid=str(collection_id), delete_amount=0
+                collection_uuid=str(collection_id),
+                delete_amount=0,
             )
         )
 

Lines changed: 0 additions & 3 deletions
@@ -1,4 +1 @@
 persist_path: "./chroma"
-open_telemetry:
-  service_name: "rust-frontend-service"
-  endpoint: "http://otel-collector:4317"

rust/python_bindings/src/bindings.rs

Lines changed: 0 additions & 7 deletions
@@ -59,7 +59,6 @@ impl Bindings {
         hnsw_cache_size: usize,
         persist_path: Option<String>,
     ) -> ChromaPyResult<Self> {
-        // TODO: runtime config
        let runtime = tokio::runtime::Runtime::new().unwrap();
        let _guard = runtime.enter();
        let system = System::new();
@@ -115,7 +114,6 @@ impl Bindings {
    }

    /// Returns the current eopch time in ns
-    /// TODO(hammadb): This should proxy to ServerAPI
    #[allow(dead_code)]
    fn heartbeat(&self) -> ChromaPyResult<u128> {
        let duration_since_epoch = std::time::SystemTime::now()
@@ -129,11 +127,6 @@ impl Bindings {
        self.frontend.clone().get_max_batch_size()
    }

-    // TODO(hammadb): Determine our pattern for optional arguments in python
-    // options include using Option or passing defaults from python
-    // or using pyargs annotations such as
-    // #[pyargs(limit = "None", offset = "None")]
-
    ////////////////////////////// Admin API //////////////////////////////

    fn create_database(&self, name: String, tenant: String, _py: Python<'_>) -> ChromaPyResult<()> {

rust/segment/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ thiserror = { workspace = true }
 tracing = { workspace = true }
 uuid = { workspace = true }
 serde_json = { workspace = true }
+rayon = { workspace = true }

 chroma-blockstore = { workspace = true }
 chroma-cache = { workspace = true }

rust/segment/src/local_hnsw.rs

Lines changed: 88 additions & 42 deletions
@@ -5,9 +5,10 @@ use chroma_error::{ChromaError, ErrorCodes};
 use chroma_index::{HnswIndex, HnswIndexConfig, Index, IndexConfig, PersistentIndex};
 use chroma_sqlite::{db::SqliteDb, table::MaxSeqId};
 use chroma_types::{
-    operator::RecordDistance, Chunk, HnswParametersFromSegmentError, LogRecord, Operation, Segment,
-    SegmentUuid, SingleNodeHnswParameters,
+    operator::RecordDistance, Chunk, HnswParametersFromSegmentError, LogRecord, Operation,
+    OperationRecord, Segment, SegmentUuid, SingleNodeHnswParameters,
 };
+use rayon::iter::{IntoParallelIterator, ParallelIterator};
 use sea_query::{Expr, OnConflict, Query, SqliteQueryBuilder};
 use sea_query_binder::SqlxBinder;
 use serde::{Deserialize, Serialize};
@@ -520,6 +521,9 @@ impl LocalHnswSegmentWriter {
             return Ok(next_label);
         }
         let mut max_seq_id = u64::MIN;
+        // In order to insert into hnsw index in parallel, we need to collect all the embeddings
+        let mut hnsw_batch: HashMap<u32, Vec<(u32, &OperationRecord)>> =
+            HashMap::with_capacity(log_chunk.len());
         for (log, _) in log_chunk.iter() {
             guard.num_elements_since_last_persist += 1;
             max_seq_id = max_seq_id.max(log.log_offset as u64);
@@ -528,7 +532,7 @@ impl LocalHnswSegmentWriter {
                     // only update if the id is not already present
                     if !guard.id_map.id_to_label.contains_key(&log.record.id) {
                         match &log.record.embedding {
-                            Some(embedding) => {
+                            Some(_embedding) => {
                                 guard
                                     .id_map
                                     .id_to_label
@@ -537,17 +541,15 @@ impl LocalHnswSegmentWriter {
                                     .id_map
                                     .label_to_id
                                     .insert(next_label, log.record.id.clone());
-                                let index_len = guard.index.len_with_deleted();
-                                let index_capacity = guard.index.capacity();
-                                if index_len + 1 > index_capacity {
-                                    guard.index.resize(index_capacity * 2).map_err(|_| {
-                                        LocalHnswSegmentWriterError::HnswIndexResizeError
-                                    })?;
-                                }
-                                guard
-                                    .index
-                                    .add(next_label as usize, embedding.as_slice())
-                                    .map_err(|_| LocalHnswSegmentWriterError::HnwsIndexAddError)?;
+                                let records_for_label = match hnsw_batch.get_mut(&next_label) {
+                                    Some(records) => records,
+                                    None => {
+                                        hnsw_batch.insert(next_label, Vec::new());
+                                        // SAFETY: We just inserted the key. We have exclusive access to the map.
+                                        hnsw_batch.get_mut(&next_label).unwrap()
+                                    }
+                                };
+                                records_for_label.push((next_label, &log.record));
                                 next_label += 1;
                             }
                             None => {
@@ -558,29 +560,32 @@ impl LocalHnswSegmentWriter {
                 }
                 Operation::Update => {
                     if let Some(label) = guard.id_map.id_to_label.get(&log.record.id).cloned() {
-                        if let Some(embedding) = &log.record.embedding {
-                            let index_len = guard.index.len_with_deleted();
-                            let index_capacity = guard.index.capacity();
-                            if index_len + 1 > index_capacity {
-                                guard.index.resize(index_capacity * 2).map_err(|_| {
-                                    LocalHnswSegmentWriterError::HnswIndexResizeError
-                                })?;
-                            }
-                            guard
-                                .index
-                                .add(label as usize, embedding.as_slice())
-                                .map_err(|_| LocalHnswSegmentWriterError::HnwsIndexAddError)?;
+                        if let Some(_embedding) = &log.record.embedding {
+                            let records_for_label = match hnsw_batch.get_mut(&label) {
+                                Some(records) => records,
+                                None => {
+                                    hnsw_batch.insert(label, Vec::new());
+                                    // SAFETY: We just inserted the key. We have exclusive access to the map.
+                                    hnsw_batch.get_mut(&label).unwrap()
+                                }
+                            };
+                            records_for_label.push((label, &log.record));
                         }
                     }
                 }
                 Operation::Delete => {
                     if let Some(label) = guard.id_map.id_to_label.get(&log.record.id).cloned() {
                         guard.id_map.id_to_label.remove(&log.record.id);
                         guard.id_map.label_to_id.remove(&label);
-                        guard
-                            .index
-                            .delete(label as usize)
-                            .map_err(|_| LocalHnswSegmentWriterError::HnswIndexDeleteError)?;
+                        let records_for_label = match hnsw_batch.get_mut(&label) {
+                            Some(records) => records,
+                            None => {
+                                hnsw_batch.insert(label, Vec::new());
+                                // SAFETY: We just inserted the key. We have exclusive access to the map.
+                                hnsw_batch.get_mut(&label).unwrap()
+                            }
+                        };
+                        records_for_label.push((label, &log.record));
                     }
                 }
                 Operation::Upsert => {
@@ -593,7 +598,7 @@ impl LocalHnswSegmentWriter {
                         }
                     };
                     match &log.record.embedding {
-                        Some(embedding) => {
+                        Some(_embedding) => {
                             guard
                                 .id_map
                                 .id_to_label
@@ -602,17 +607,15 @@ impl LocalHnswSegmentWriter {
                                 .id_map
                                 .label_to_id
                                 .insert(label, log.record.id.clone());
-                            let index_len = guard.index.len_with_deleted();
-                            let index_capacity = guard.index.capacity();
-                            if index_len + 1 > index_capacity {
-                                guard.index.resize(index_capacity * 2).map_err(|_| {
-                                    LocalHnswSegmentWriterError::HnswIndexResizeError
-                                })?;
-                            }
-                            guard
-                                .index
-                                .add(label as usize, embedding.as_slice())
-                                .map_err(|_| LocalHnswSegmentWriterError::HnwsIndexAddError)?;
+                            let records_for_label = match hnsw_batch.get_mut(&label) {
+                                Some(records) => records,
+                                None => {
+                                    hnsw_batch.insert(label, Vec::new());
+                                    // SAFETY: We just inserted the key. We have exclusive access to the map.
+                                    hnsw_batch.get_mut(&label).unwrap()
+                                }
+                            };
+                            records_for_label.push((label, &log.record));
                             if update_label {
                                 next_label += 1;
                             }
@@ -624,6 +627,49 @@ impl LocalHnswSegmentWriter {
                 }
             }
         }
+
+        // Add to hnsw index in parallel using rayon.
+        // Resize the index if needed
+        let index_len = guard.index.len_with_deleted();
+        let index_capacity = guard.index.capacity();
+        if index_len + hnsw_batch.len() >= index_capacity {
+            let needed_capacity = (index_len + hnsw_batch.len()).next_power_of_two();
+            guard
+                .index
+                .resize(needed_capacity)
+                .map_err(|_| LocalHnswSegmentWriterError::HnswIndexResizeError)?;
+        }
+        let index_for_pool = &guard.index;
+
+        hnsw_batch
+            .into_par_iter()
+            .map(|(_, records)| {
+                for (label, log_record) in records {
+                    match log_record.operation {
+                        Operation::Add | Operation::Upsert | Operation::Update => {
+                            let embedding = log_record.embedding.as_ref().expect(
+                                "Add, update or upsert should have an embedding at this point",
+                            );
+                            match index_for_pool.add(label as usize, embedding) {
+                                Ok(_) => {}
+                                Err(_e) => {
+                                    return Err(LocalHnswSegmentWriterError::HnwsIndexAddError);
+                                }
+                            }
+                        }
+                        Operation::Delete => match index_for_pool.delete(label as usize) {
+                            Ok(_) => {}
+                            Err(_e) => {
+                                return Err(LocalHnswSegmentWriterError::HnswIndexDeleteError);
+                            }
+                        },
+                    }
+                }
+                Ok(())
+            })
+            .find_any(|result| result.is_err())
+            .unwrap_or(Ok(()))?;
+
         guard.id_map.total_elements_added = next_label - 1;
         if guard.num_elements_since_last_persist >= guard.sync_threshold as u64 {
             guard = persist(guard).await?;

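A note on the resize step in the new parallel block above: because the batch size is known before any insert, the index is grown at most once per batch, to the next power of two that fits, rather than doubling capacity on every single add as the old per-record path did. A small sketch of that arithmetic (plain `usize` math, not the actual index API):

```rust
// Returns the capacity to grow to before a batch insert, or None if the current
// capacity already fits; mirrors the one-shot resize policy in the diff above.
fn capacity_for_batch(index_len: usize, batch_len: usize, capacity: usize) -> Option<usize> {
    if index_len + batch_len >= capacity {
        Some((index_len + batch_len).next_power_of_two())
    } else {
        None
    }
}

fn main() {
    // 1000 existing elements plus a 300-element batch in a 1024-capacity index
    // grow the index once, to 2048, instead of resizing on each insert.
    assert_eq!(capacity_for_batch(1000, 300, 1024), Some(2048));
    // A batch that still fits needs no resize at all.
    assert_eq!(capacity_for_batch(100, 10, 1024), None);
}
```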