Skip to content

Commit 90e4ad3

Browse files
authored
[TST] Test rust client backward compatibility (#3867)
## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - N/A - New functionality - Test rust client backward compatibility with python clients across legacy versions ## Test plan *How are these changes tested?* - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?*
1 parent 608e333 commit 90e4ad3

File tree

2 files changed

+107
-1
lines changed

2 files changed

+107
-1
lines changed

bin/rust_python_compat_test.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import json
2+
import multiprocessing
3+
import os
4+
import packaging
5+
import re
6+
import shutil
7+
import subprocess
8+
import sys
9+
import tempfile
10+
import tqdm
11+
import urllib
12+
13+
from chromadb import RustClient
14+
from chromadb.config import Settings
15+
from chromadb.segment.impl.manager.local import LocalSegmentManager
16+
from chromadb.test.property.test_cross_version_persist import api_import_for_version
17+
from chromadb.test.utils.cross_version import install_version, switch_to_version
18+
from packaging import version
19+
from typing import List
20+
from urllib import request
21+
22+
persist_size = 10000
23+
batch_size = 100
24+
collection_name = "rust_py_compat_test"
25+
26+
version_re = re.compile(r"^[0-9]+\.[0-9]+\.[0-9]+$")
27+
28+
def versions() -> List[str]:
29+
"""Returns the pinned minimum version and the latest version of chromadb."""
30+
url = "https://pypi.org/pypi/chromadb/json"
31+
data = json.load(request.urlopen(request.Request(url)))
32+
versions = list(data["releases"].keys())
33+
# Older versions on pypi contain "devXYZ" suffixes
34+
versions = [v for v in versions if version_re.match(v) and version.Version(v) >= version.Version("0.5.3")]
35+
versions.sort(key=version.Version)
36+
return versions
37+
38+
def persist_with_old_version(ver: str, path: str):
39+
print(f"Installing ChromaDB {ver}")
40+
install_version(ver, {})
41+
old_modules = switch_to_version(ver, ["pydantic", "numpy", "tokenizers"])
42+
43+
print(f"Initializing client {ver}")
44+
settings = Settings(
45+
chroma_api_impl="chromadb.api.segment.SegmentAPI",
46+
chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
47+
chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
48+
chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
49+
chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
50+
allow_reset=True,
51+
is_persistent=True,
52+
persist_directory=path,
53+
)
54+
if version.Version(ver) <= version.Version("0.4.14"):
55+
settings.chroma_telemetry_impl = "chromadb.telemetry.posthog.Posthog"
56+
system = old_modules.config.System(settings)
57+
api = system.instance(api_import_for_version(old_modules, ver))
58+
system.start()
59+
api.reset()
60+
if version.Version(ver) >= version.Version("0.5.4"):
61+
api = old_modules.api.client.Client.from_system(system)
62+
63+
print(f"Persisting data with old client to {path}")
64+
coll = api.create_collection(collection_name)
65+
for start in tqdm.tqdm(range(0, persist_size // 2, batch_size)):
66+
id_vals = range(start, start + batch_size)
67+
documents = [f"DOC-{i}" for i in id_vals]
68+
embeddings = [[i, i] for i in id_vals]
69+
ids = [str(i) for i in id_vals]
70+
metadatas = [{"int": i, "float": i / 2.0, "str": f"<{i}>"} for i in id_vals]
71+
coll.add(ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas)
72+
assert coll.count() == persist_size // 2
73+
system.instance(LocalSegmentManager).stop()
74+
for start in tqdm.tqdm(range(persist_size // 2, persist_size, batch_size)):
75+
id_vals = range(start, start + batch_size)
76+
documents = [f"DOC-{i}" for i in id_vals]
77+
embeddings = [[i, i] for i in id_vals]
78+
ids = [str(i) for i in id_vals]
79+
metadatas = [{"int": i, "float": i / 2.0, "str": f"<{i}>"} for i in id_vals]
80+
coll.add(ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas)
81+
82+
def verify_collection_content(path: str):
83+
print("Loading collection from rust client")
84+
client = RustClient(path=path)
85+
coll = client.get_collection(collection_name)
86+
87+
print("Verifying collection content")
88+
assert coll.count() == persist_size
89+
records = coll.get(include=["documents", "embeddings", "metadatas"])
90+
assert records["ids"] == [str(i) for i in range(persist_size)]
91+
assert records["documents"] == [f"DOC-{i}" for i in range(persist_size)]
92+
assert all(emb[0] == emb[1] == i for i, emb in enumerate(records["embeddings"]))
93+
94+
if __name__ == "__main__":
95+
for ver in versions():
96+
path = tempfile.gettempdir() + "/" + collection_name
97+
ctx = multiprocessing.get_context("spawn")
98+
proc_handle = ctx.Process(
99+
target=persist_with_old_version,
100+
args=(ver, path),
101+
)
102+
proc_handle.start()
103+
proc_handle.join()
104+
if proc_handle.exitcode == 0:
105+
verify_collection_content(path)
106+
shutil.rmtree(path, ignore_errors=True)

chromadb/test/property/test_cross_version_persist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def persist_generated_data_with_old_version(
238238

239239
@given(
240240
collection_strategy=collection_st,
241-
embeddings_strategy=strategies.recordsets(collection_st),
241+
embeddings_strategy=strategies.recordsets(collection_st, max_size=200),
242242
)
243243
@settings(deadline=None)
244244
def test_cycle_versions(

0 commit comments

Comments
 (0)