Skip to content

Commit e609337

Browse files
committed
[ENH] Auto-set tenant in python CloudClient, add AdminCloudClient
1 parent 202f0b2 commit e609337

File tree

4 files changed

+209
-102
lines changed

4 files changed

+209
-102
lines changed

chromadb/__init__.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from typing import Dict, Optional, Union
22
import logging
33
from chromadb.api.client import Client as ClientCreator
4-
from chromadb.api.client import AdminClient as AdminClientCreator
4+
from chromadb.api.client import (
5+
AdminClient as AdminClientCreator,
6+
AdminCloudClient as AdminCloudClientCreator,
7+
)
58
from chromadb.api.async_client import AsyncClient as AsyncClientCreator
69
from chromadb.auth.token_authn import TokenTransportHeader
710
import chromadb.config
811
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings
9-
from chromadb.api import AdminAPI, AsyncClientAPI, ClientAPI
12+
from chromadb.api import AdminAPI, AsyncClientAPI, ClientAPI, AdminCloudAPI
1013
from chromadb.api.models.Collection import Collection
1114
from chromadb.api.types import (
1215
CollectionMetadata,
@@ -27,6 +30,7 @@
2730
UpdateCollectionMetadata,
2831
)
2932
from pathlib import Path
33+
import warnings
3034

3135
# Re-export types from chromadb.types
3236
__all__ = [
@@ -311,19 +315,39 @@ def CloudClient(
311315
Creates a client to connect to a tennant and database on the Chroma cloud.
312316
313317
Args:
314-
tenant: The tenant to use for this client.
315318
database: The database to use for this client.
316319
api_key: The api key to use for this client.
317320
"""
321+
if tenant is not None:
322+
warnings.warn(
323+
"The 'tenant' parameter is deprecated and will be removed in a future version.",
324+
DeprecationWarning,
325+
stacklevel=2,
326+
)
327+
328+
if cloud_port != 8000:
329+
warnings.warn(
330+
"The 'cloud_port' parameter is deprecated and will be removed in a future version.",
331+
DeprecationWarning,
332+
stacklevel=2,
333+
)
334+
335+
if enable_ssl is not True:
336+
warnings.warn(
337+
"The 'enable_ssl' parameter is deprecated and will be removed in a future version.",
338+
DeprecationWarning,
339+
stacklevel=2,
340+
)
341+
318342
required_args = [
319-
CloudClientArg(name="tenant", env_var="CHROMA_TENANT", value=tenant),
320343
CloudClientArg(name="database", env_var="CHROMA_DATABASE", value=database),
321344
CloudClientArg(name="api_key", env_var="CHROMA_API_KEY", value=api_key),
322345
]
323346

324347
# If any of tenant, database, or api_key is not provided, try to load it from the environment variable
325348
if not all([arg.value for arg in required_args]):
326349
import os
350+
327351
for arg in required_args:
328352
arg.value = arg.value or os.environ.get(arg.env_var)
329353

@@ -338,26 +362,45 @@ def CloudClient(
338362
settings = Settings()
339363

340364
# Make sure paramaters are the correct types -- users can pass anything.
341-
tenant = str(tenant)
342365
database = str(database)
343366
api_key = str(api_key)
344367
cloud_host = str(cloud_host)
345-
cloud_port = int(cloud_port)
346-
enable_ssl = bool(enable_ssl)
347368

348369
settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
349370
settings.chroma_server_host = cloud_host
350-
settings.chroma_server_http_port = cloud_port
371+
settings.chroma_server_http_port = 443
351372
# Always use SSL for cloud
352-
settings.chroma_server_ssl_enabled = enable_ssl
373+
settings.chroma_server_ssl_enabled = True
353374

354375
settings.chroma_client_auth_provider = (
355376
"chromadb.auth.token_authn.TokenAuthClientProvider"
356377
)
357378
settings.chroma_client_auth_credentials = api_key
358379
settings.chroma_auth_token_transport_header = TokenTransportHeader.X_CHROMA_TOKEN
380+
settings.chroma_overwrite_singleton_tenant_database_access_from_auth = True
359381

360-
return ClientCreator(tenant=tenant, database=database, settings=settings)
382+
return ClientCreator(database=database, settings=settings)
383+
384+
385+
def AdminCloudClient(
386+
api_key: str,
387+
settings: Optional[Settings] = None,
388+
*,
389+
cloud_host: str = "api.trychroma.com",
390+
) -> AdminCloudAPI:
391+
if settings is None:
392+
settings = Settings()
393+
settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
394+
settings.chroma_server_host = cloud_host
395+
settings.chroma_server_http_port = 443
396+
settings.chroma_server_ssl_enabled = True
397+
settings.chroma_client_auth_provider = (
398+
"chromadb.auth.token_authn.TokenAuthClientProvider"
399+
)
400+
settings.chroma_client_auth_credentials = api_key
401+
settings.chroma_auth_token_transport_header = TokenTransportHeader.X_CHROMA_TOKEN
402+
settings.chroma_overwrite_singleton_tenant_database_access_from_auth = True
403+
return AdminCloudClientCreator(settings=settings)
361404

362405

363406
def Client(

chromadb/api/__init__.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,63 @@ def get_tenant(self, name: str) -> Tenant:
562562
pass
563563

564564

565+
class AdminCloudAPI(ABC):
566+
@abstractmethod
567+
def create_database(self, name: str) -> None:
568+
"""Create a new database. Raises an error if the database already exists.
569+
570+
Args:
571+
name: The name of the database to create.
572+
573+
"""
574+
pass
575+
576+
@abstractmethod
577+
def get_database(self, name: str) -> Database:
578+
"""Get a database. Raises an error if the database does not exist.
579+
580+
Args:
581+
name: The name of the database to get.
582+
583+
"""
584+
pass
585+
586+
@abstractmethod
587+
def delete_database(self, name: str) -> None:
588+
"""Delete a database. Raises an error if the database does not exist.
589+
590+
Args:
591+
name: The name of the database to delete.
592+
593+
"""
594+
pass
595+
596+
@abstractmethod
597+
def list_databases(
598+
self,
599+
limit: Optional[int] = None,
600+
offset: Optional[int] = None,
601+
) -> Sequence[Database]:
602+
"""List all databases for a tenant. Raises an error if the tenant does not exist.
603+
604+
Args:
605+
limit: The maximum number of databases to return.
606+
offset: The offset to start from.
607+
608+
"""
609+
pass
610+
611+
@abstractmethod
612+
def get_tenant(self, name: str) -> Tenant:
613+
"""Get a tenant. Raises an error if the tenant does not exist.
614+
615+
Args:
616+
name: The name of the tenant to get.
617+
618+
"""
619+
pass
620+
621+
565622
class ServerAPI(BaseAPI, AdminAPI, Component):
566623
"""An API instance that extends the relevant Base API methods by passing
567624
in a tenant and database. This is the root component of the Chroma System"""

chromadb/api/client.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from overrides import override
55
import httpx
6-
from chromadb.api import AdminAPI, ClientAPI, ServerAPI
6+
from chromadb.api import AdminAPI, ClientAPI, ServerAPI, AdminCloudAPI
77
from chromadb.api.collection_configuration import (
88
CreateCollectionConfiguration,
99
UpdateCollectionConfiguration,
@@ -525,3 +525,59 @@ def from_system(
525525
SharedSystemClient._populate_data_from_system(system)
526526
instance = cls(settings=system.settings)
527527
return instance
528+
529+
530+
class AdminCloudClient(SharedSystemClient, AdminCloudAPI):
531+
_server: ServerAPI
532+
tenant: str
533+
534+
def __init__(self, settings: Settings = Settings()) -> None:
535+
super().__init__(settings)
536+
self._server = self._system.instance(ServerAPI)
537+
538+
user_identity = self._server.get_user_identity()
539+
540+
maybe_tenant, _ = maybe_set_tenant_and_database(
541+
user_identity,
542+
overwrite_singleton_tenant_database_access_from_auth=settings.chroma_overwrite_singleton_tenant_database_access_from_auth,
543+
user_provided_tenant=None,
544+
user_provided_database=None,
545+
)
546+
if maybe_tenant:
547+
self.tenant = maybe_tenant
548+
else:
549+
raise ValueError("Could not get tenant from user identity")
550+
551+
@override
552+
def create_database(self, name: str) -> None:
553+
return self._server.create_database(name=name, tenant=self.tenant)
554+
555+
@override
556+
def get_database(self, name: str) -> Database:
557+
return self._server.get_database(name=name, tenant=self.tenant)
558+
559+
@override
560+
def delete_database(self, name: str) -> None:
561+
return self._server.delete_database(name=name, tenant=self.tenant)
562+
563+
@override
564+
def list_databases(
565+
self,
566+
limit: Optional[int] = None,
567+
offset: Optional[int] = None,
568+
) -> Sequence[Database]:
569+
return self._server.list_databases(limit, offset, tenant=self.tenant)
570+
571+
@override
572+
def get_tenant(self, name: str) -> Tenant:
573+
return self._server.get_tenant(name=name)
574+
575+
@classmethod
576+
@override
577+
def from_system(
578+
cls,
579+
system: System,
580+
) -> "AdminCloudClient":
581+
SharedSystemClient._populate_data_from_system(system)
582+
instance = cls(settings=system.settings)
583+
return instance
Lines changed: 42 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,101 +1,52 @@
1-
import multiprocessing
2-
from typing import Any, Dict, Generator, Optional, Tuple
31
import pytest
2+
from unittest.mock import patch
43
from chromadb import CloudClient
5-
from chromadb.api import ServerAPI
6-
from chromadb.auth.token_authn import TokenTransportHeader
7-
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
84
from chromadb.errors import ChromaAuthError
5+
from chromadb.auth import UserIdentity
6+
from chromadb.types import Tenant, Database
7+
from uuid import uuid4
8+
9+
10+
def test_valid_key() -> None:
11+
with patch(
12+
"chromadb.api.fastapi.FastAPI.get_user_identity"
13+
) as mock_get_user_identity, patch(
14+
"chromadb.api.client.AdminClient.get_tenant"
15+
) as mock_get_tenant, patch(
16+
"chromadb.api.client.AdminClient.get_database"
17+
) as mock_get_database, patch(
18+
"chromadb.api.fastapi.FastAPI.heartbeat"
19+
) as mock_heartbeat:
20+
mock_get_user_identity.return_value = UserIdentity(
21+
user_id="test_user", tenant="default_tenant", databases=["testdb"]
22+
)
23+
mock_get_tenant.return_value = Tenant(name="default_tenant")
24+
mock_get_database.return_value = Database(
25+
id=uuid4(), name="testdb", tenant="default_tenant"
26+
)
27+
mock_heartbeat.return_value = 1234567890
928

10-
from chromadb.test.conftest import _await_server, _run_server, find_free_port
11-
12-
TOKEN_TRANSPORT_HEADER = TokenTransportHeader.X_CHROMA_TOKEN
13-
TEST_CLOUD_HOST = "localhost"
14-
15-
16-
@pytest.fixture(scope="module")
17-
def valid_token() -> str:
18-
return "valid_token"
19-
20-
21-
@pytest.fixture(scope="module")
22-
def mock_cloud_server(valid_token: str) -> Generator[System, None, None]:
23-
chroma_server_authn_provider: str = (
24-
"chromadb.auth.token_authn.TokenAuthenticationServerProvider"
25-
)
26-
chroma_server_authn_credentials: str = valid_token
27-
chroma_auth_token_transport_header: str = TOKEN_TRANSPORT_HEADER
28-
29-
port = find_free_port()
30-
31-
args: Tuple[
32-
int,
33-
bool,
34-
Optional[str],
35-
Optional[str],
36-
Optional[str],
37-
Optional[str],
38-
Optional[str],
39-
Optional[str],
40-
Optional[str],
41-
Optional[Dict[str, Any]],
42-
] = (
43-
port,
44-
False,
45-
None,
46-
chroma_server_authn_provider,
47-
None,
48-
chroma_server_authn_credentials,
49-
chroma_auth_token_transport_header,
50-
None,
51-
None,
52-
None,
53-
)
54-
ctx = multiprocessing.get_context("spawn")
55-
proc = ctx.Process(target=_run_server, args=args, daemon=True)
56-
proc.start()
57-
58-
settings = Settings(
59-
chroma_api_impl="chromadb.api.fastapi.FastAPI",
60-
chroma_server_host=TEST_CLOUD_HOST,
61-
chroma_server_http_port=port,
62-
chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
63-
chroma_client_auth_credentials=valid_token,
64-
chroma_auth_token_transport_header=TOKEN_TRANSPORT_HEADER,
65-
)
29+
client = CloudClient(database="testdb", api_key="valid_token")
6630

67-
system = System(settings)
68-
api = system.instance(ServerAPI)
69-
system.start()
70-
_await_server(api)
71-
yield system
72-
system.stop()
73-
proc.kill()
31+
assert client.get_user_identity().user_id == "test_user"
32+
assert client.get_user_identity().tenant == "default_tenant"
33+
assert client.get_user_identity().databases == ["testdb"]
7434

35+
settings = client.get_settings()
36+
assert settings.chroma_client_auth_credentials == "valid_token"
37+
assert (
38+
settings.chroma_client_auth_provider
39+
== "chromadb.auth.token_authn.TokenAuthClientProvider"
40+
)
7541

76-
def test_valid_key(mock_cloud_server: System, valid_token: str) -> None:
77-
valid_client = CloudClient(
78-
tenant=DEFAULT_TENANT,
79-
database=DEFAULT_DATABASE,
80-
api_key=valid_token,
81-
cloud_host=TEST_CLOUD_HOST,
82-
cloud_port=mock_cloud_server.settings.chroma_server_http_port or 8000,
83-
enable_ssl=False,
84-
)
42+
assert client.heartbeat() == 1234567890
8543

86-
assert valid_client.heartbeat()
8744

45+
def test_invalid_key() -> None:
46+
with patch(
47+
"chromadb.api.fastapi.FastAPI.get_user_identity"
48+
) as mock_get_user_identity:
49+
mock_get_user_identity.side_effect = ChromaAuthError("Authentication failed")
8850

89-
def test_invalid_key(mock_cloud_server: System, valid_token: str) -> None:
90-
# Try to connect to the default tenant and database with an invalid token
91-
invalid_token = valid_token + "_invalid"
92-
with pytest.raises(ChromaAuthError):
93-
client = CloudClient(
94-
tenant=DEFAULT_TENANT,
95-
database=DEFAULT_DATABASE,
96-
api_key=invalid_token,
97-
cloud_host=TEST_CLOUD_HOST,
98-
cloud_port=mock_cloud_server.settings.chroma_server_http_port or 8000,
99-
enable_ssl=False,
100-
)
101-
client.heartbeat()
51+
with pytest.raises(ChromaAuthError):
52+
CloudClient(database="testdb", api_key="invalid_token")

0 commit comments

Comments
 (0)