Skip to content

Commit ff5ff3d

Browse files
committed
[ENH] Add schema support to collection configuration
1 parent 72ff620 commit ff5ff3d

File tree

16 files changed

+345
-40
lines changed

16 files changed

+345
-40
lines changed

chromadb/api/collection_configuration.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
UpdateMetadata,
77
EmbeddingFunction,
88
)
9+
from chromadb.base_types import CollectionSchema, ValueType
910
from chromadb.utils.embedding_functions import (
1011
known_embedding_functions,
1112
register_embedding_function,
@@ -41,6 +42,7 @@ class CollectionConfiguration(TypedDict, total=True):
4142
hnsw: Optional[HNSWConfiguration]
4243
spann: Optional[SpannConfiguration]
4344
embedding_function: Optional[EmbeddingFunction] # type: ignore
45+
schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]]
4446

4547

4648
def load_collection_configuration_from_json_str(
@@ -107,6 +109,7 @@ def load_collection_configuration_from_json(
107109
hnsw=hnsw_config,
108110
spann=spann_config,
109111
embedding_function=ef, # type: ignore
112+
schema=config_json_map.get("schema"),
110113
)
111114

112115

@@ -119,6 +122,7 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
119122
hnsw_config = config.get("hnsw")
120123
spann_config = config.get("spann")
121124
ef = config.get("embedding_function")
125+
schema = config.get("schema")
122126
else:
123127
try:
124128
hnsw_config = config.get_parameter("hnsw").value
@@ -178,6 +182,7 @@ def collection_configuration_to_json(config: CollectionConfiguration) -> Dict[st
178182
"hnsw": hnsw_config,
179183
"spann": spann_config,
180184
"embedding_function": ef_config,
185+
"schema": schema,
181186
}
182187

183188

@@ -258,6 +263,7 @@ class CreateCollectionConfiguration(TypedDict, total=False):
258263
hnsw: Optional[CreateHNSWConfiguration]
259264
spann: Optional[CreateSpannConfiguration]
260265
embedding_function: Optional[EmbeddingFunction] # type: ignore
266+
schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]]
261267

262268

263269
def load_collection_configuration_from_create_collection_configuration(
@@ -402,6 +408,7 @@ def create_collection_configuration_to_json(
402408
"hnsw": hnsw_config,
403409
"spann": spann_config,
404410
"embedding_function": ef_config,
411+
"schema": config.get("schema"),
405412
}
406413

407414

@@ -473,6 +480,7 @@ class UpdateCollectionConfiguration(TypedDict, total=False):
473480
hnsw: Optional[UpdateHNSWConfiguration]
474481
spann: Optional[UpdateSpannConfiguration]
475482
embedding_function: Optional[EmbeddingFunction] # type: ignore
483+
schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]]
476484

477485

478486
def update_collection_configuration_from_legacy_collection_metadata(
@@ -527,8 +535,14 @@ def update_collection_configuration_to_json(
527535
"""Convert an UpdateCollectionConfiguration to a JSON-serializable dict"""
528536
hnsw_config = config.get("hnsw")
529537
spann_config = config.get("spann")
538+
schema = config.get("schema")
530539
ef = config.get("embedding_function")
531-
if hnsw_config is None and spann_config is None and ef is None:
540+
if (
541+
hnsw_config is None
542+
and spann_config is None
543+
and ef is None
544+
and schema is None
545+
):
532546
return {}
533547

534548
if hnsw_config is not None:
@@ -562,6 +576,7 @@ def update_collection_configuration_to_json(
562576
"hnsw": hnsw_config,
563577
"spann": spann_config,
564578
"embedding_function": ef_config,
579+
"schema": schema,
565580
}
566581

567582

@@ -710,13 +725,40 @@ def overwrite_collection_configuration(
710725
else:
711726
updated_embedding_function = update_ef
712727

728+
729+
existing_schema = existing_config.get("schema")
730+
new_diff_schema = update_config.get("schema")
731+
updated_schema: Optional[Dict[str, Dict[ValueType, CollectionSchema]]] = None
732+
if existing_schema is not None:
733+
if new_diff_schema is not None:
734+
updated_schema = overwrite_schema(existing_schema, new_diff_schema)
735+
else:
736+
updated_schema = existing_schema
737+
else:
738+
updated_schema = new_diff_schema
739+
713740
return CollectionConfiguration(
714741
hnsw=updated_hnsw_config,
715742
spann=updated_spann_config,
716743
embedding_function=updated_embedding_function,
744+
schema=updated_schema,
717745
)
718746

719747

748+
def overwrite_schema(
    existing_schema: Dict[str, Dict[ValueType, CollectionSchema]],
    new_diff_schema: Dict[str, Dict[ValueType, CollectionSchema]],
) -> Dict[str, Dict[ValueType, CollectionSchema]]:
    """Merge a schema diff on top of an existing schema.

    For every key in ``new_diff_schema``:
      * if the key already exists, each value-type entry from the diff
        replaces (or adds to) the entry stored under that key;
      * otherwise the key and all of its value-type entries are added.

    Keys and value-type entries that the diff does not mention are kept
    unchanged.

    Neither argument is modified. The returned mapping is a new two-level
    copy: the outer dict and every per-key dict are fresh objects, while the
    leaf schema values themselves are shared with the inputs.

    Args:
        existing_schema: The schema currently stored on the collection.
        new_diff_schema: The partial schema to apply on top of it.

    Returns:
        A new schema dict containing the merged result.
    """
    # Copy both nesting levels up front so that the caller's configuration
    # is never mutated as a side effect of computing the merged schema.
    merged_schema: Dict[str, Dict[ValueType, CollectionSchema]] = {
        key: dict(per_value_type) for key, per_value_type in existing_schema.items()
    }
    for key, per_value_type in new_diff_schema.items():
        # setdefault gives brand-new keys their own dict instead of aliasing
        # the diff's nested dict; update() then overwrites entry by entry.
        merged_schema.setdefault(key, {}).update(per_value_type)
    return merged_schema
760+
761+
720762
def validate_embedding_function_conflict_on_create(
721763
embedding_function: Optional[EmbeddingFunction], # type: ignore
722764
configuration_ef: Optional[EmbeddingFunction], # type: ignore

chromadb/api/models/AsyncCollection.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,22 @@ async def add(
6060
ValueError: If you provide an id that already exists
6161
6262
"""
63-
add_request = self._validate_and_prepare_add_request(
63+
64+
curr_schema = self._model.get_configuration().get("schema")
65+
66+
add_request, new_attributes = self._validate_and_prepare_add_request(
6467
ids=ids,
6568
embeddings=embeddings,
6669
metadatas=metadatas,
6770
documents=documents,
6871
images=images,
6972
uris=uris,
73+
schema=curr_schema,
7074
)
7175

76+
if len(new_attributes.keys()) > 0:
77+
await self.modify(configuration={"schema": new_attributes})
78+
7279
await self._client._add(
7380
collection_id=self.id,
7481
ids=add_request["ids"],
@@ -313,15 +320,20 @@ async def update(
313320
Returns:
314321
None
315322
"""
316-
update_request = self._validate_and_prepare_update_request(
323+
curr_schema = self._model.get_configuration().get("schema")
324+
update_request, new_attributes = self._validate_and_prepare_update_request(
317325
ids=ids,
318326
embeddings=embeddings,
319327
metadatas=metadatas,
320328
documents=documents,
321329
images=images,
322330
uris=uris,
331+
schema=curr_schema,
323332
)
324333

334+
if len(new_attributes.keys()) > 0:
335+
await self.modify(configuration={"schema": new_attributes})
336+
325337
await self._client._update(
326338
collection_id=self.id,
327339
ids=update_request["ids"],
@@ -358,14 +370,18 @@ async def upsert(
358370
Returns:
359371
None
360372
"""
361-
upsert_request = self._validate_and_prepare_upsert_request(
373+
curr_schema = self._model.get_configuration().get("schema")
374+
upsert_request, new_attributes = self._validate_and_prepare_upsert_request(
362375
ids=ids,
363376
embeddings=embeddings,
364377
metadatas=metadatas,
365378
documents=documents,
366379
images=images,
367380
uris=uris,
381+
schema=curr_schema,
368382
)
383+
if len(new_attributes.keys()) > 0:
384+
await self.modify(configuration={"schema": new_attributes})
369385

370386
await self._client._upsert(
371387
collection_id=self.id,

chromadb/api/models/Collection.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,20 @@ def add(
7777
7878
"""
7979

80-
add_request = self._validate_and_prepare_add_request(
80+
curr_schema = self._model.get_configuration().get("schema")
81+
add_request, new_attributes = self._validate_and_prepare_add_request(
8182
ids=ids,
8283
embeddings=embeddings,
8384
metadatas=metadatas,
8485
documents=documents,
8586
images=images,
8687
uris=uris,
88+
schema=curr_schema,
8789
)
8890

91+
if len(new_attributes.keys()) > 0:
92+
self.modify(configuration={"schema": new_attributes})
93+
8994
self._client._add(
9095
collection_id=self.id,
9196
ids=add_request["ids"],
@@ -255,6 +260,7 @@ def modify(
255260
# Note there is a race condition here where the metadata can be updated
256261
# but another thread sees the cached local metadata.
257262
# TODO: fixme
263+
258264
self._client._modify(
259265
id=self.id,
260266
new_name=name,
@@ -317,15 +323,20 @@ def update(
317323
Returns:
318324
None
319325
"""
320-
update_request = self._validate_and_prepare_update_request(
326+
curr_schema = self._model.get_configuration().get("schema")
327+
update_request, new_attributes = self._validate_and_prepare_update_request(
321328
ids=ids,
322329
embeddings=embeddings,
323330
metadatas=metadatas,
324331
documents=documents,
325332
images=images,
326333
uris=uris,
334+
schema=curr_schema,
327335
)
328336

337+
if len(new_attributes.keys()) > 0:
338+
self.modify(configuration={"schema": new_attributes})
339+
329340
self._client._update(
330341
collection_id=self.id,
331342
ids=update_request["ids"],
@@ -362,15 +373,20 @@ def upsert(
362373
Returns:
363374
None
364375
"""
365-
upsert_request = self._validate_and_prepare_upsert_request(
376+
curr_schema = self._model.get_configuration().get("schema")
377+
upsert_request, new_attributes = self._validate_and_prepare_upsert_request(
366378
ids=ids,
367379
embeddings=embeddings,
368380
metadatas=metadatas,
369381
documents=documents,
370382
images=images,
371383
uris=uris,
384+
schema=curr_schema,
372385
)
373386

387+
if len(new_attributes.keys()) > 0:
388+
self.modify(configuration={"schema": new_attributes})
389+
374390
self._client._upsert(
375391
collection_id=self.id,
376392
ids=upsert_request["ids"],

0 commit comments

Comments
 (0)