Skip to content

[ENH] Auto-set tenant in python CloudClient, add AdminCloudClient #5026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 60 additions & 15 deletions chromadb/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from typing import Dict, Optional, Union
import logging
from chromadb.api.client import Client as ClientCreator
from chromadb.api.client import AdminClient as AdminClientCreator
from chromadb.api.client import (
AdminClient as AdminClientCreator,
AdminCloudClient as AdminCloudClientCreator,
)
from chromadb.api.async_client import AsyncClient as AsyncClientCreator
from chromadb.auth.token_authn import TokenTransportHeader
import chromadb.config
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings
from chromadb.api import AdminAPI, AsyncClientAPI, ClientAPI
from chromadb.api import AdminAPI, AsyncClientAPI, ClientAPI, AdminCloudAPI
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
CollectionMetadata,
Expand All @@ -27,6 +30,8 @@
UpdateCollectionMetadata,
)
from pathlib import Path
import warnings
import os

# Re-export types from chromadb.types
__all__ = [
Expand Down Expand Up @@ -311,19 +316,36 @@ def CloudClient(
Creates a client to connect to a tennant and database on the Chroma cloud.

Args:
tenant: The tenant to use for this client.
database: The database to use for this client.
api_key: The api key to use for this client.
"""
if tenant is not None:
warnings.warn(
"The 'tenant' parameter is deprecated and will be removed in a future version.",
DeprecationWarning,
stacklevel=2,
)

if cloud_port != 8000:
warnings.warn(
"The 'cloud_port' parameter is deprecated and will be removed in a future version.",
DeprecationWarning,
stacklevel=2,
)

if enable_ssl is not True:
warnings.warn(
"The 'enable_ssl' parameter is deprecated and will be removed in a future version.",
DeprecationWarning,
stacklevel=2,
)

required_args = [
CloudClientArg(name="tenant", env_var="CHROMA_TENANT", value=tenant),
CloudClientArg(name="database", env_var="CHROMA_DATABASE", value=database),
CloudClientArg(name="api_key", env_var="CHROMA_API_KEY", value=api_key),
]

# If any of tenant, database, or api_key is not provided, try to load it from the environment variable
# If api_key is not provided, try to load it from the environment variable
if not all([arg.value for arg in required_args]):
import os
for arg in required_args:
arg.value = arg.value or os.environ.get(arg.env_var)

Expand All @@ -338,26 +360,49 @@ def CloudClient(
settings = Settings()

# Make sure paramaters are the correct types -- users can pass anything.
tenant = str(tenant)
database = str(database)
database = database or os.environ.get("CHROMA_DATABASE")
if database is not None:
database = str(database)
api_key = str(api_key)
cloud_host = str(cloud_host)
cloud_port = int(cloud_port)
enable_ssl = bool(enable_ssl)

settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
settings.chroma_server_host = cloud_host
settings.chroma_server_http_port = cloud_port
# Always use SSL for cloud
settings.chroma_server_ssl_enabled = enable_ssl
settings.chroma_server_http_port = 443
settings.chroma_server_ssl_enabled = True

settings.chroma_client_auth_provider = (
"chromadb.auth.token_authn.TokenAuthClientProvider"
)
settings.chroma_client_auth_credentials = api_key
settings.chroma_auth_token_transport_header = TokenTransportHeader.X_CHROMA_TOKEN
settings.chroma_overwrite_singleton_tenant_database_access_from_auth = True

return ClientCreator(tenant=tenant, database=database, settings=settings)
if database is None:
return ClientCreator(settings=settings, user_supplied_db=False, is_cloud=True)

return ClientCreator(database=database, settings=settings, is_cloud=True)


def AdminCloudClient(
api_key: str,
settings: Optional[Settings] = None,
*,
cloud_host: str = "api.trychroma.com",
) -> AdminCloudAPI:
if settings is None:
settings = Settings()
settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"
settings.chroma_server_host = cloud_host
settings.chroma_server_http_port = 443
settings.chroma_server_ssl_enabled = True
settings.chroma_client_auth_provider = (
"chromadb.auth.token_authn.TokenAuthClientProvider"
)
settings.chroma_client_auth_credentials = api_key
settings.chroma_auth_token_transport_header = TokenTransportHeader.X_CHROMA_TOKEN
settings.chroma_overwrite_singleton_tenant_database_access_from_auth = True
return AdminCloudClientCreator(settings=settings)


def Client(
Expand Down
57 changes: 57 additions & 0 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,63 @@ def get_tenant(self, name: str) -> Tenant:
pass


class AdminCloudAPI(ABC):
@abstractmethod
def create_database(self, name: str) -> None:
"""Create a new database. Raises an error if the database already exists.

Args:
name: The name of the database to create.

"""
pass

@abstractmethod
def get_database(self, name: str) -> Database:
"""Get a database. Raises an error if the database does not exist.

Args:
name: The name of the database to get.

"""
pass

@abstractmethod
def delete_database(self, name: str) -> None:
"""Delete a database. Raises an error if the database does not exist.

Args:
name: The name of the database to delete.

"""
pass

@abstractmethod
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Sequence[Database]:
"""List all databases for a tenant. Raises an error if the tenant does not exist.

Args:
limit: The maximum number of databases to return.
offset: The offset to start from.

"""
pass

@abstractmethod
def get_tenant(self, name: str) -> Tenant:
"""Get a tenant. Raises an error if the tenant does not exist.

Args:
name: The name of the tenant to get.

"""
pass


class ServerAPI(BaseAPI, AdminAPI, Component):
"""An API instance that extends the relevant Base API methods by passing
in a tenant and database. This is the root component of the Chroma System"""
Expand Down
69 changes: 66 additions & 3 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from overrides import override
import httpx
from chromadb.api import AdminAPI, ClientAPI, ServerAPI
from chromadb.api import AdminAPI, ClientAPI, ServerAPI, AdminCloudAPI
from chromadb.api.collection_configuration import (
CreateCollectionConfiguration,
UpdateCollectionConfiguration,
Expand Down Expand Up @@ -61,6 +61,9 @@ def __init__(
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
settings: Settings = Settings(),
*,
is_cloud: bool = False,
user_supplied_db: bool = True,
) -> None:
super().__init__(settings=settings)
self.tenant = tenant
Expand All @@ -75,8 +78,10 @@ def __init__(
user_identity,
overwrite_singleton_tenant_database_access_from_auth=settings.chroma_overwrite_singleton_tenant_database_access_from_auth,
user_provided_tenant=tenant,
user_provided_database=database,
user_provided_database=database if user_supplied_db else None,
)
if is_cloud and not maybe_database:
raise ValueError("Could not get database. Please provide a database name.")
if maybe_tenant:
self.tenant = maybe_tenant
if maybe_database:
Expand Down Expand Up @@ -455,7 +460,9 @@ def set_database(self, database: str) -> None:
self._validate_tenant_database(tenant=self.tenant, database=database)
self.database = database

def _validate_tenant_database(self, tenant: str, database: str) -> None:
def _validate_tenant_database(
self, tenant: str, database: str, is_cloud: bool = False
) -> None:
try:
self._admin_client.get_tenant(name=tenant)
except httpx.ConnectError:
Expand Down Expand Up @@ -525,3 +532,59 @@ def from_system(
SharedSystemClient._populate_data_from_system(system)
instance = cls(settings=system.settings)
return instance


class AdminCloudClient(SharedSystemClient, AdminCloudAPI):
_server: ServerAPI
tenant: str

def __init__(self, settings: Settings = Settings()) -> None:
super().__init__(settings)
self._server = self._system.instance(ServerAPI)

user_identity = self._server.get_user_identity()

maybe_tenant, _ = maybe_set_tenant_and_database(
user_identity,
overwrite_singleton_tenant_database_access_from_auth=settings.chroma_overwrite_singleton_tenant_database_access_from_auth,
user_provided_tenant=None,
user_provided_database=None,
)
if maybe_tenant:
self.tenant = maybe_tenant
else:
raise ValueError("Could not get tenant from user identity")

@override
def create_database(self, name: str) -> None:
return self._server.create_database(name=name, tenant=self.tenant)

@override
def get_database(self, name: str) -> Database:
return self._server.get_database(name=name, tenant=self.tenant)

@override
def delete_database(self, name: str) -> None:
return self._server.delete_database(name=name, tenant=self.tenant)

@override
def list_databases(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> Sequence[Database]:
return self._server.list_databases(limit, offset, tenant=self.tenant)

@override
def get_tenant(self, name: str) -> Tenant:
return self._server.get_tenant(name=name)

@classmethod
@override
def from_system(
cls,
system: System,
) -> "AdminCloudClient":
SharedSystemClient._populate_data_from_system(system)
instance = cls(settings=system.settings)
return instance
6 changes: 4 additions & 2 deletions chromadb/auth/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ def _singleton_tenant_database_if_applicable(
user_databases = user_identity.databases
if user_tenant and user_tenant != "*":
tenant = user_tenant
if user_databases and len(user_databases) == 1 and user_databases[0] != "*":
database = user_databases[0]
if user_databases:
user_databases_set = set(user_databases)
if len(user_databases_set) == 1 and "*" not in user_databases_set:
database = list(user_databases_set)[0]
return tenant, database


Expand Down
Loading
Loading