Skip to content

Commit c561c11

Browse files
committed
[CLN] Use InternalUpdateConfiguration in Rust, cleanup go code
1 parent a445e12 commit c561c11

File tree

7 files changed

+126
-64
lines changed

7 files changed

+126
-64
lines changed

go/pkg/sysdb/coordinator/table_catalog.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -815,8 +815,8 @@ func (tc *Catalog) updateCollectionConfiguration(
815815

816816
// Update existing configuration with new values
817817
if updateConfig.VectorIndex != nil {
818-
if updateConfig.VectorIndex.Type == "hnsw" && updateConfig.VectorIndex.Hnsw != nil {
819-
if existingConfig.VectorIndex == nil || existingConfig.VectorIndex.Type != "hnsw" {
818+
if updateConfig.VectorIndex.Hnsw != nil {
819+
if existingConfig.VectorIndex == nil || existingConfig.VectorIndex.Hnsw == nil {
820820
return existingConfigJsonStr, nil
821821
}
822822
if updateConfig.VectorIndex.Hnsw.EfSearch != nil {
@@ -837,8 +837,8 @@ func (tc *Catalog) updateCollectionConfiguration(
837837
if updateConfig.VectorIndex.Hnsw.BatchSize != nil {
838838
existingConfig.VectorIndex.Hnsw.BatchSize = *updateConfig.VectorIndex.Hnsw.BatchSize
839839
}
840-
} else if updateConfig.VectorIndex.Type == "spann" && updateConfig.VectorIndex.Spann != nil {
841-
if existingConfig.VectorIndex == nil || existingConfig.VectorIndex.Type != "spann" {
840+
} else if updateConfig.VectorIndex.Spann != nil {
841+
if existingConfig.VectorIndex == nil || existingConfig.VectorIndex.Spann == nil {
842842
return existingConfigJsonStr, nil
843843
}
844844
if updateConfig.VectorIndex.Spann.EfSearch != nil {

rust/frontend/src/server.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ use chroma_types::{
1818
DeleteCollectionRecordsResponse, DeleteDatabaseRequest, DeleteDatabaseResponse,
1919
GetCollectionRequest, GetDatabaseRequest, GetDatabaseResponse, GetRequest, GetResponse,
2020
GetTenantRequest, GetTenantResponse, GetUserIdentityResponse, HeartbeatResponse, IncludeList,
21-
InternalCollectionConfiguration, ListCollectionsRequest, ListCollectionsResponse,
22-
ListDatabasesRequest, ListDatabasesResponse, Metadata, QueryRequest, QueryResponse,
23-
UpdateCollectionConfiguration, UpdateCollectionRecordsResponse, UpdateCollectionResponse,
24-
UpdateMetadata, UpsertCollectionRecordsResponse,
21+
InternalCollectionConfiguration, InternalUpdateCollectionConfiguration, ListCollectionsRequest,
22+
ListCollectionsResponse, ListDatabasesRequest, ListDatabasesResponse, Metadata, QueryRequest,
23+
QueryResponse, UpdateCollectionConfiguration, UpdateCollectionRecordsResponse,
24+
UpdateCollectionResponse, UpdateMetadata, UpsertCollectionRecordsResponse,
2525
};
2626
use chroma_types::{ForkCollectionResponse, RawWhereFields};
2727
use mdac::{Rule, Scorecard, ScorecardTicket};
@@ -1091,13 +1091,18 @@ async fn update_collection(
10911091
let collection_id =
10921092
CollectionUuid::from_str(&collection_id).map_err(|_| ValidationError::CollectionId)?;
10931093

1094+
let configuration = match payload.new_configuration {
1095+
Some(c) => Some(InternalUpdateCollectionConfiguration::try_from(c)?),
1096+
None => None,
1097+
};
1098+
10941099
let request = chroma_types::UpdateCollectionRequest::try_new(
10951100
collection_id,
10961101
payload.new_name,
10971102
payload
10981103
.new_metadata
10991104
.map(CollectionMetadataUpdate::UpdateMetadata),
1100-
payload.new_configuration,
1105+
configuration,
11011106
)?;
11021107

11031108
server.frontend.update_collection(request).await?;

rust/python_bindings/src/bindings.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ use chroma_types::{
1919
CountResponse, CreateCollectionRequest, CreateDatabaseRequest, CreateTenantRequest, Database,
2020
DeleteCollectionRequest, DeleteDatabaseRequest, GetCollectionRequest, GetDatabaseRequest,
2121
GetResponse, GetTenantRequest, GetTenantResponse, HeartbeatError, IncludeList,
22-
InternalCollectionConfiguration, KnnIndex, ListCollectionsRequest, ListDatabasesRequest,
23-
Metadata, QueryResponse, UpdateCollectionConfiguration, UpdateCollectionRequest,
24-
UpdateMetadata, WrappedSerdeJsonError,
22+
InternalCollectionConfiguration, InternalUpdateCollectionConfiguration, KnnIndex,
23+
ListCollectionsRequest, ListDatabasesRequest, Metadata, QueryResponse,
24+
UpdateCollectionConfiguration, UpdateCollectionRequest, UpdateMetadata, WrappedSerdeJsonError,
2525
};
2626
use pyo3::{exceptions::PyValueError, pyclass, pyfunction, pymethods, types::PyAnyMethods, Python};
2727
use std::time::SystemTime;
@@ -344,11 +344,16 @@ impl Bindings {
344344
None => None,
345345
};
346346

347+
let configuration = match configuration_json {
348+
Some(c) => Some(InternalUpdateCollectionConfiguration::try_from(c)?),
349+
None => None,
350+
};
351+
347352
let request = UpdateCollectionRequest::try_new(
348353
collection_id,
349354
new_name,
350355
new_metadata.map(CollectionMetadataUpdate::UpdateMetadata),
351-
configuration_json,
356+
configuration,
352357
)?;
353358

354359
let mut frontend = self.frontend.clone();

rust/sysdb/src/sqlite.rs

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ use chroma_types::{
1212
CreateTenantError, CreateTenantResponse, Database, DatabaseUuid, DeleteCollectionError,
1313
DeleteDatabaseError, DeleteDatabaseResponse, GetCollectionWithSegmentsError,
1414
GetCollectionsError, GetDatabaseError, GetSegmentsError, GetTenantError, GetTenantResponse,
15-
InternalCollectionConfiguration, ListDatabasesError, Metadata, MetadataValue, ResetError,
16-
ResetResponse, Segment, SegmentScope, SegmentType, SegmentUuid, UpdateCollectionConfiguration,
17-
UpdateCollectionError,
15+
InternalCollectionConfiguration, InternalUpdateCollectionConfiguration, ListDatabasesError,
16+
Metadata, MetadataValue, ResetError, ResetResponse, Segment, SegmentScope, SegmentType,
17+
SegmentUuid, UpdateCollectionError,
1818
};
1919
use futures::TryStreamExt;
2020
use sea_query_binder::SqlxBinder;
@@ -356,7 +356,7 @@ impl SqliteSysDb {
356356
name: Option<String>,
357357
metadata: Option<CollectionMetadataUpdate>,
358358
dimension: Option<u32>,
359-
configuration: Option<UpdateCollectionConfiguration>,
359+
configuration: Option<InternalUpdateCollectionConfiguration>,
360360
) -> Result<(), UpdateCollectionError> {
361361
let mut tx = self
362362
.db
@@ -1048,8 +1048,9 @@ mod tests {
10481048
use super::*;
10491049
use chroma_sqlite::db::test_utils::get_new_sqlite_db;
10501050
use chroma_types::{
1051-
SegmentScope, SegmentType, SegmentUuid, UpdateHnswConfiguration, UpdateMetadata,
1052-
UpdateMetadataValue, VectorIndexConfiguration,
1051+
InternalUpdateCollectionConfiguration, SegmentScope, SegmentType, SegmentUuid,
1052+
UpdateHnswConfiguration, UpdateMetadata, UpdateMetadataValue,
1053+
UpdateVectorIndexConfiguration, VectorIndexConfiguration,
10531054
};
10541055

10551056
#[tokio::test]
@@ -1354,13 +1355,14 @@ mod tests {
13541355
Some("new_name".to_string()),
13551356
Some(CollectionMetadataUpdate::UpdateMetadata(metadata)),
13561357
Some(1024),
1357-
Some(UpdateCollectionConfiguration {
1358-
hnsw: Some(UpdateHnswConfiguration {
1359-
ef_search: Some(20),
1360-
num_threads: Some(4),
1361-
..Default::default()
1362-
}),
1363-
spann: None,
1358+
Some(InternalUpdateCollectionConfiguration {
1359+
vector_index: Some(UpdateVectorIndexConfiguration::Hnsw(Some(
1360+
UpdateHnswConfiguration {
1361+
ef_search: Some(10),
1362+
num_threads: Some(2),
1363+
..Default::default()
1364+
},
1365+
))),
13641366
embedding_function: None,
13651367
}),
13661368
)

rust/sysdb/src/sysdb.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ use chroma_types::{
1414
DeleteCollectionError, DeleteDatabaseError, DeleteDatabaseResponse, GetCollectionSizeError,
1515
GetCollectionWithSegmentsError, GetCollectionsError, GetDatabaseError, GetDatabaseResponse,
1616
GetSegmentsError, GetTenantError, GetTenantResponse, InternalCollectionConfiguration,
17-
ListCollectionVersionsError, ListDatabasesError, ListDatabasesResponse, Metadata, ResetError,
18-
ResetResponse, SegmentFlushInfo, SegmentFlushInfoConversionError, SegmentUuid,
19-
UpdateCollectionConfiguration, UpdateCollectionError, VectorIndexConfiguration,
17+
InternalUpdateCollectionConfiguration, ListCollectionVersionsError, ListDatabasesError,
18+
ListDatabasesResponse, Metadata, ResetError, ResetResponse, SegmentFlushInfo,
19+
SegmentFlushInfoConversionError, SegmentUuid, UpdateCollectionError, VectorIndexConfiguration,
2020
};
2121
use chroma_types::{
2222
BatchGetCollectionSoftDeleteStatusError, BatchGetCollectionVersionFilePathsError, Collection,
@@ -299,7 +299,7 @@ impl SysDb {
299299
name: Option<String>,
300300
metadata: Option<CollectionMetadataUpdate>,
301301
dimension: Option<u32>,
302-
configuration: Option<UpdateCollectionConfiguration>,
302+
configuration: Option<InternalUpdateCollectionConfiguration>,
303303
) -> Result<(), UpdateCollectionError> {
304304
match self {
305305
SysDb::Grpc(grpc) => {
@@ -984,7 +984,7 @@ impl GrpcSysDb {
984984
name: Option<String>,
985985
metadata: Option<CollectionMetadataUpdate>,
986986
dimension: Option<u32>,
987-
configuration: Option<UpdateCollectionConfiguration>,
987+
configuration: Option<InternalUpdateCollectionConfiguration>,
988988
) -> Result<(), UpdateCollectionError> {
989989
let mut configuration_json_str = None;
990990
if let Some(configuration) = configuration {

rust/types/src/api_types.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::collection_configuration::InternalCollectionConfiguration;
2-
use crate::collection_configuration::UpdateCollectionConfiguration;
2+
use crate::collection_configuration::InternalUpdateCollectionConfiguration;
33
use crate::error::QueryConversionError;
44
use crate::operator::GetResult;
55
use crate::operator::KnnBatchResult;
@@ -724,15 +724,15 @@ pub struct UpdateCollectionRequest {
724724
pub new_name: Option<String>,
725725
#[validate(custom(function = "validate_non_empty_collection_update_metadata"))]
726726
pub new_metadata: Option<CollectionMetadataUpdate>,
727-
pub new_configuration: Option<UpdateCollectionConfiguration>,
727+
pub new_configuration: Option<InternalUpdateCollectionConfiguration>,
728728
}
729729

730730
impl UpdateCollectionRequest {
731731
pub fn try_new(
732732
collection_id: CollectionUuid,
733733
new_name: Option<String>,
734734
new_metadata: Option<CollectionMetadataUpdate>,
735-
new_configuration: Option<UpdateCollectionConfiguration>,
735+
new_configuration: Option<InternalUpdateCollectionConfiguration>,
736736
) -> Result<Self, ChromaValidationError> {
737737
let request = Self {
738738
collection_id,

rust/types/src/collection_configuration.rs

Lines changed: 80 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -152,39 +152,47 @@ impl InternalCollectionConfiguration {
152152
}
153153
}
154154

155-
pub fn update(&mut self, configuration: &UpdateCollectionConfiguration) {
155+
pub fn update(&mut self, configuration: &InternalUpdateCollectionConfiguration) {
156156
// Update vector_index if it exists in the update configuration
157157

158-
if let Some(hnsw_config) = &configuration.hnsw {
159-
if let VectorIndexConfiguration::Hnsw(current_config) = &mut self.vector_index {
160-
// Update only the non-None fields from the update configuration
161-
if let Some(ef_search) = hnsw_config.ef_search {
162-
current_config.ef_search = ef_search;
163-
}
164-
if let Some(max_neighbors) = hnsw_config.max_neighbors {
165-
current_config.max_neighbors = max_neighbors;
166-
}
167-
if let Some(num_threads) = hnsw_config.num_threads {
168-
current_config.num_threads = num_threads;
169-
}
170-
if let Some(resize_factor) = hnsw_config.resize_factor {
171-
current_config.resize_factor = resize_factor;
172-
}
173-
if let Some(sync_threshold) = hnsw_config.sync_threshold {
174-
current_config.sync_threshold = sync_threshold;
175-
}
176-
if let Some(batch_size) = hnsw_config.batch_size {
177-
current_config.batch_size = batch_size;
178-
}
179-
}
180-
}
181-
if let Some(spann_config) = &configuration.spann {
182-
if let VectorIndexConfiguration::Spann(current_config) = &mut self.vector_index {
183-
if let Some(search_nprobe) = spann_config.search_nprobe {
184-
current_config.search_nprobe = search_nprobe;
158+
if let Some(vector_index) = &configuration.vector_index {
159+
match vector_index {
160+
UpdateVectorIndexConfiguration::Hnsw(hnsw_config) => {
161+
if let VectorIndexConfiguration::Hnsw(current_config) = &mut self.vector_index {
162+
if let Some(update_config) = hnsw_config {
163+
if let Some(ef_search) = update_config.ef_search {
164+
current_config.ef_search = ef_search;
165+
}
166+
if let Some(max_neighbors) = update_config.max_neighbors {
167+
current_config.max_neighbors = max_neighbors;
168+
}
169+
if let Some(num_threads) = update_config.num_threads {
170+
current_config.num_threads = num_threads;
171+
}
172+
if let Some(resize_factor) = update_config.resize_factor {
173+
current_config.resize_factor = resize_factor;
174+
}
175+
if let Some(sync_threshold) = update_config.sync_threshold {
176+
current_config.sync_threshold = sync_threshold;
177+
}
178+
if let Some(batch_size) = update_config.batch_size {
179+
current_config.batch_size = batch_size;
180+
}
181+
}
182+
}
185183
}
186-
if let Some(ef_search) = spann_config.ef_search {
187-
current_config.ef_search = ef_search;
184+
UpdateVectorIndexConfiguration::Spann(spann_config) => {
185+
if let VectorIndexConfiguration::Spann(current_config) = &mut self.vector_index
186+
{
187+
if let Some(update_config) = spann_config {
188+
if let Some(search_nprobe) = update_config.search_nprobe {
189+
current_config.search_nprobe = search_nprobe;
190+
}
191+
if let Some(ef_search) = update_config.ef_search {
192+
current_config.ef_search = ef_search;
193+
}
194+
}
195+
}
188196
}
189197
}
190198
}
@@ -398,6 +406,48 @@ pub struct UpdateCollectionConfiguration {
398406
pub embedding_function: Option<EmbeddingFunctionConfiguration>,
399407
}
400408

409+
#[derive(Deserialize, Serialize, ToSchema, Debug, Clone)]
410+
pub struct InternalUpdateCollectionConfiguration {
411+
pub vector_index: Option<UpdateVectorIndexConfiguration>,
412+
pub embedding_function: Option<EmbeddingFunctionConfiguration>,
413+
}
414+
415+
#[derive(Debug, Error)]
416+
pub enum UpdateCollectionConfigurationToInternalUpdateConfigurationError {
417+
#[error("Multiple vector index configurations provided")]
418+
MultipleVectorIndexConfigurations,
419+
}
420+
421+
impl ChromaError for UpdateCollectionConfigurationToInternalUpdateConfigurationError {
422+
fn code(&self) -> ErrorCodes {
423+
match self {
424+
Self::MultipleVectorIndexConfigurations => ErrorCodes::InvalidArgument,
425+
}
426+
}
427+
}
428+
429+
impl TryFrom<UpdateCollectionConfiguration> for InternalUpdateCollectionConfiguration {
430+
type Error = UpdateCollectionConfigurationToInternalUpdateConfigurationError;
431+
432+
fn try_from(value: UpdateCollectionConfiguration) -> Result<Self, Self::Error> {
433+
match (value.hnsw, value.spann) {
434+
(Some(_), Some(_)) => Err(Self::Error::MultipleVectorIndexConfigurations),
435+
(Some(hnsw), None) => Ok(InternalUpdateCollectionConfiguration {
436+
vector_index: Some(UpdateVectorIndexConfiguration::Hnsw(Some(hnsw))),
437+
embedding_function: value.embedding_function,
438+
}),
439+
(None, Some(spann)) => Ok(InternalUpdateCollectionConfiguration {
440+
vector_index: Some(UpdateVectorIndexConfiguration::Spann(Some(spann))),
441+
embedding_function: value.embedding_function,
442+
}),
443+
(None, None) => Ok(InternalUpdateCollectionConfiguration {
444+
vector_index: None,
445+
embedding_function: value.embedding_function,
446+
}),
447+
}
448+
}
449+
}
450+
401451
#[cfg(test)]
402452
mod tests {
403453
use crate::hnsw_configuration::HnswConfiguration;

0 commit comments

Comments
 (0)