@@ -5,9 +5,10 @@ use chroma_error::{ChromaError, ErrorCodes};
5
5
use chroma_index:: { HnswIndex , HnswIndexConfig , Index , IndexConfig , PersistentIndex } ;
6
6
use chroma_sqlite:: { db:: SqliteDb , table:: MaxSeqId } ;
7
7
use chroma_types:: {
8
- operator:: RecordDistance , Chunk , HnswParametersFromSegmentError , LogRecord , Operation , Segment ,
9
- SegmentUuid , SingleNodeHnswParameters ,
8
+ operator:: RecordDistance , Chunk , HnswParametersFromSegmentError , LogRecord , Operation ,
9
+ OperationRecord , Segment , SegmentUuid , SingleNodeHnswParameters ,
10
10
} ;
11
+ use rayon:: iter:: { IntoParallelIterator , ParallelIterator } ;
11
12
use sea_query:: { Expr , OnConflict , Query , SqliteQueryBuilder } ;
12
13
use sea_query_binder:: SqlxBinder ;
13
14
use serde:: { Deserialize , Serialize } ;
@@ -520,6 +521,9 @@ impl LocalHnswSegmentWriter {
520
521
return Ok ( next_label) ;
521
522
}
522
523
let mut max_seq_id = u64:: MIN ;
524
+ // To insert into the hnsw index in parallel, first group the pending records by label so that each label's operations are applied in log order.
525
+ let mut hnsw_batch: HashMap < u32 , Vec < ( u32 , & OperationRecord ) > > =
526
+ HashMap :: with_capacity ( log_chunk. len ( ) ) ;
523
527
for ( log, _) in log_chunk. iter ( ) {
524
528
guard. num_elements_since_last_persist += 1 ;
525
529
max_seq_id = max_seq_id. max ( log. log_offset as u64 ) ;
@@ -528,7 +532,7 @@ impl LocalHnswSegmentWriter {
528
532
// only update if the id is not already present
529
533
if !guard. id_map . id_to_label . contains_key ( & log. record . id ) {
530
534
match & log. record . embedding {
531
- Some ( embedding ) => {
535
+ Some ( _embedding ) => {
532
536
guard
533
537
. id_map
534
538
. id_to_label
@@ -537,17 +541,15 @@ impl LocalHnswSegmentWriter {
537
541
. id_map
538
542
. label_to_id
539
543
. insert ( next_label, log. record . id . clone ( ) ) ;
540
- let index_len = guard. index . len_with_deleted ( ) ;
541
- let index_capacity = guard. index . capacity ( ) ;
542
- if index_len + 1 > index_capacity {
543
- guard. index . resize ( index_capacity * 2 ) . map_err ( |_| {
544
- LocalHnswSegmentWriterError :: HnswIndexResizeError
545
- } ) ?;
546
- }
547
- guard
548
- . index
549
- . add ( next_label as usize , embedding. as_slice ( ) )
550
- . map_err ( |_| LocalHnswSegmentWriterError :: HnwsIndexAddError ) ?;
544
+ let records_for_label = match hnsw_batch. get_mut ( & next_label) {
545
+ Some ( records) => records,
546
+ None => {
547
+ hnsw_batch. insert ( next_label, Vec :: new ( ) ) ;
548
+ // Cannot panic: the key was inserted just above and we have exclusive access to the map.
549
+ hnsw_batch. get_mut ( & next_label) . unwrap ( )
550
+ }
551
+ } ;
552
+ records_for_label. push ( ( next_label, & log. record ) ) ;
551
553
next_label += 1 ;
552
554
}
553
555
None => {
@@ -558,29 +560,32 @@ impl LocalHnswSegmentWriter {
558
560
}
559
561
Operation :: Update => {
560
562
if let Some ( label) = guard. id_map . id_to_label . get ( & log. record . id ) . cloned ( ) {
561
- if let Some ( embedding) = & log. record . embedding {
562
- let index_len = guard. index . len_with_deleted ( ) ;
563
- let index_capacity = guard. index . capacity ( ) ;
564
- if index_len + 1 > index_capacity {
565
- guard. index . resize ( index_capacity * 2 ) . map_err ( |_| {
566
- LocalHnswSegmentWriterError :: HnswIndexResizeError
567
- } ) ?;
568
- }
569
- guard
570
- . index
571
- . add ( label as usize , embedding. as_slice ( ) )
572
- . map_err ( |_| LocalHnswSegmentWriterError :: HnwsIndexAddError ) ?;
563
+ if let Some ( _embedding) = & log. record . embedding {
564
+ let records_for_label = match hnsw_batch. get_mut ( & label) {
565
+ Some ( records) => records,
566
+ None => {
567
+ hnsw_batch. insert ( label, Vec :: new ( ) ) ;
568
+ // Cannot panic: the key was inserted just above and we have exclusive access to the map.
569
+ hnsw_batch. get_mut ( & label) . unwrap ( )
570
+ }
571
+ } ;
572
+ records_for_label. push ( ( label, & log. record ) ) ;
573
573
}
574
574
}
575
575
}
576
576
Operation :: Delete => {
577
577
if let Some ( label) = guard. id_map . id_to_label . get ( & log. record . id ) . cloned ( ) {
578
578
guard. id_map . id_to_label . remove ( & log. record . id ) ;
579
579
guard. id_map . label_to_id . remove ( & label) ;
580
- guard
581
- . index
582
- . delete ( label as usize )
583
- . map_err ( |_| LocalHnswSegmentWriterError :: HnswIndexDeleteError ) ?;
580
+ let records_for_label = match hnsw_batch. get_mut ( & label) {
581
+ Some ( records) => records,
582
+ None => {
583
+ hnsw_batch. insert ( label, Vec :: new ( ) ) ;
584
+ // Cannot panic: the key was inserted just above and we have exclusive access to the map.
585
+ hnsw_batch. get_mut ( & label) . unwrap ( )
586
+ }
587
+ } ;
588
+ records_for_label. push ( ( label, & log. record ) ) ;
584
589
}
585
590
}
586
591
Operation :: Upsert => {
@@ -593,7 +598,7 @@ impl LocalHnswSegmentWriter {
593
598
}
594
599
} ;
595
600
match & log. record . embedding {
596
- Some ( embedding ) => {
601
+ Some ( _embedding ) => {
597
602
guard
598
603
. id_map
599
604
. id_to_label
@@ -602,17 +607,15 @@ impl LocalHnswSegmentWriter {
602
607
. id_map
603
608
. label_to_id
604
609
. insert ( label, log. record . id . clone ( ) ) ;
605
- let index_len = guard. index . len_with_deleted ( ) ;
606
- let index_capacity = guard. index . capacity ( ) ;
607
- if index_len + 1 > index_capacity {
608
- guard. index . resize ( index_capacity * 2 ) . map_err ( |_| {
609
- LocalHnswSegmentWriterError :: HnswIndexResizeError
610
- } ) ?;
611
- }
612
- guard
613
- . index
614
- . add ( label as usize , embedding. as_slice ( ) )
615
- . map_err ( |_| LocalHnswSegmentWriterError :: HnwsIndexAddError ) ?;
610
+ let records_for_label = match hnsw_batch. get_mut ( & label) {
611
+ Some ( records) => records,
612
+ None => {
613
+ hnsw_batch. insert ( label, Vec :: new ( ) ) ;
614
+ // Cannot panic: the key was inserted just above and we have exclusive access to the map.
615
+ hnsw_batch. get_mut ( & label) . unwrap ( )
616
+ }
617
+ } ;
618
+ records_for_label. push ( ( label, & log. record ) ) ;
616
619
if update_label {
617
620
next_label += 1 ;
618
621
}
@@ -624,6 +627,49 @@ impl LocalHnswSegmentWriter {
624
627
}
625
628
}
626
629
}
630
+
631
+ // Add to hnsw index in parallel using rayon.
632
+ // Resize the index if needed
633
+ let index_len = guard. index . len_with_deleted ( ) ;
634
+ let index_capacity = guard. index . capacity ( ) ;
635
+ if index_len + hnsw_batch. len ( ) >= index_capacity {
636
+ let needed_capacity = ( index_len + hnsw_batch. len ( ) ) . next_power_of_two ( ) ;
637
+ guard
638
+ . index
639
+ . resize ( needed_capacity)
640
+ . map_err ( |_| LocalHnswSegmentWriterError :: HnswIndexResizeError ) ?;
641
+ }
642
+ let index_for_pool = & guard. index ;
643
+
644
+ hnsw_batch
645
+ . into_par_iter ( )
646
+ . map ( |( _, records) | {
647
+ for ( label, log_record) in records {
648
+ match log_record. operation {
649
+ Operation :: Add | Operation :: Upsert | Operation :: Update => {
650
+ let embedding = log_record. embedding . as_ref ( ) . expect (
651
+ "Add, update or upsert should have an embedding at this point" ,
652
+ ) ;
653
+ match index_for_pool. add ( label as usize , embedding) {
654
+ Ok ( _) => { }
655
+ Err ( _e) => {
656
+ return Err ( LocalHnswSegmentWriterError :: HnwsIndexAddError ) ;
657
+ }
658
+ }
659
+ }
660
+ Operation :: Delete => match index_for_pool. delete ( label as usize ) {
661
+ Ok ( _) => { }
662
+ Err ( _e) => {
663
+ return Err ( LocalHnswSegmentWriterError :: HnswIndexDeleteError ) ;
664
+ }
665
+ } ,
666
+ }
667
+ }
668
+ Ok ( ( ) )
669
+ } )
670
+ . find_any ( |result| result. is_err ( ) )
671
+ . unwrap_or ( Ok ( ( ) ) ) ?;
672
+
627
673
guard. id_map . total_elements_added = next_label - 1 ;
628
674
if guard. num_elements_since_last_persist >= guard. sync_threshold as u64 {
629
675
guard = persist ( guard) . await ?;
0 commit comments