From 4ddda3efaa74c7e97caf784b73053a7761bdd322 Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Thu, 14 Nov 2024 06:50:24 +0100 Subject: [PATCH 01/21] feat: create the dict cache and the method to go with it --- literalai/api/__init__.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index b771092..0a1a9fa 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -169,6 +169,8 @@ def __init__( self.graphql_endpoint = self.url + "/api/graphql" self.rest_endpoint = self.url + "/api" + self._prompt_cache = {} + @property def headers(self): from literalai.version import __version__ @@ -185,6 +187,40 @@ def headers(self): return h + def _get_prompt_cache_key( + self, + id: Optional[str] = None, + name: Optional[str] = None, + version: Optional[int] = None, + ) -> str: + key = "" + if id: + key = f"id:{id}" + elif name: + key = f"name:{name}" + else: + raise ValueError("Either the `id` or the `name` must be provided.") + + if version: + key += f":version:{version}" + + return key + + def _get_prompt_cache(self, id: Optional[str] = None, name: Optional[str] = None, version: Optional[int] = None): + key = self._get_prompt_cache_key(id, name, version) + # handle the case where I have a version but it's not found in this case look without the version + if key in self._prompt_cache: + return self._prompt_cache.get(key) + else: + key_without_version = self._get_prompt_cache_key(id, name) + return self._prompt_cache.get(key_without_version) + + def _create_prompt_cache(self, prompt: Prompt): + key = self._get_prompt_cache_key(id=prompt.id, name=prompt.name, version=prompt.version) + key_without_version = self._get_prompt_cache_key(id=prompt.id, name=prompt.name) + self._prompt_cache[key] = prompt + self._prompt_cache[key_without_version] = prompt + class LiteralAPI(BaseLiteralAPI): """ From fd9f4625a758da32be8c699ebeadb1874a16251c Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Thu, 14 Nov 2024 09:36:01 +0100 Subject: [PATCH 02/21] feat: get_prompt add caching --- literalai/api/__init__.py | 97 +++++++++++++++++++++++++++------------ 1 file changed, 68 insertions(+), 29 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 0a1a9fa..bfa27a6 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -195,31 +195,41 @@ def _get_prompt_cache_key( ) -> str: key = "" if id: - key = f"id:{id}" + return f"id:{id}" elif name: key = f"name:{name}" + if version: + key += f":version:{version}" + return key else: raise ValueError("Either the `id` or the `name` must be provided.") - if version: - key += f":version:{version}" + def _get_prompt_cache(self, id: Optional[str] = None, name: Optional[str] = None, version: Optional[int] = None) -> Optional[Prompt]: + """Returns the cached prompt, using key in this order: id, name-version, name + """ + key_id = self._get_prompt_cache_key(id=id) + if key_id in self._prompt_cache: + return self._prompt_cache.get(key_id) - return key + key_name_version = self._get_prompt_cache_key(name=name, version=version) + if key_name_version in self._prompt_cache: + return self._prompt_cache.get(key_name_version) - def _get_prompt_cache(self, id: Optional[str] = None, name: Optional[str] = None, version: Optional[int] = None): - key = self._get_prompt_cache_key(id, name, version) - # handle the case where I have a version but it's not found in this case look without the version - if key in self._prompt_cache: - return 
self._prompt_cache.get(key) - else: - key_without_version = self._get_prompt_cache_key(id, name) - return self._prompt_cache.get(key_without_version) + key_name = self._get_prompt_cache_key(name=name) + if key_name in self._prompt_cache: + return self._prompt_cache.get(key_name) def _create_prompt_cache(self, prompt: Prompt): - key = self._get_prompt_cache_key(id=prompt.id, name=prompt.name, version=prompt.version) - key_without_version = self._get_prompt_cache_key(id=prompt.id, name=prompt.name) - self._prompt_cache[key] = prompt - self._prompt_cache[key_without_version] = prompt + """Creates cache for prompt. 3 entries are created/updated: id, name, name:version + """ + key_id = self._get_prompt_cache_key(id=prompt.id) + self._prompt_cache[key_id] = prompt + + key_name = self._get_prompt_cache_key(name=prompt.name) + self._prompt_cache[key_name] = prompt + + key_name_version = self._get_prompt_cache_key(name=prompt.name, version=prompt.version) + self._prompt_cache[key_name_version] = prompt class LiteralAPI(BaseLiteralAPI): @@ -1398,13 +1408,28 @@ def get_prompt( Returns: Prompt: The prompt with the given identifier or name. """ - if id: - return self.gql_helper(*get_prompt_helper(self, id=id)) - elif name: - return self.gql_helper(*get_prompt_helper(self, name=name, version=version)) - else: + if not (id or name): raise ValueError("Either the `id` or the `name` must be provided.") + cached_prompt = self._get_prompt_cache(id, name) + + try: + if id: + prompt = self.gql_helper(*get_prompt_helper(self, id=id)) + elif name: + prompt = self.gql_helper(*get_prompt_helper(self, name=name, version=version)) + + self._create_prompt_cache(prompt) + return prompt + + except Exception as e: + if cached_prompt: + logger.warning("Failed to get prompt from API, returning cached prompt") + logger.error(f"Error: {e}") + return cached_prompt + + raise e + def create_prompt_variant( self, name: str, @@ -2629,16 +2654,30 @@ async def get_prompt( name: Optional[str] = None, version: Optional[int] = None, ) -> Prompt: - sync_api = LiteralAPI(self.api_key, self.url) - if id: - return await self.gql_helper(*get_prompt_helper(sync_api, id=id)) - elif name: - return await self.gql_helper( - *get_prompt_helper(sync_api, name=name, version=version) - ) - else: + if not (id or name): raise ValueError("Either the `id` or the `name` must be provided.") + sync_api = LiteralAPI(self.api_key, self.url) + cached_prompt = self._get_prompt_cache(id, name) + + try: + if id: + prompt = await self.gql_helper(*get_prompt_helper(sync_api, id=id)) + elif name: + prompt = await self.gql_helper( + *get_prompt_helper(sync_api, name=name, version=version) + ) + + self._create_prompt_cache(prompt) + return prompt + + except Exception as e: + if cached_prompt: + logger.warning("Failed to get prompt from API, returning cached prompt") + logger.error(f"Error: {e}") + return cached_prompt + raise e + get_prompt.__doc__ = LiteralAPI.get_prompt.__doc__ async def update_prompt_ab_testing( From 8774a0015402beb62cb939c010745bdd46ce012d Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Thu, 14 Nov 2024 11:06:54 +0100 Subject: [PATCH 03/21] feat: implement caching on get_prompt --- literalai/api/__init__.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index bfa27a6..1b96a59 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -195,14 +195,13 @@ def _get_prompt_cache_key( ) -> str: key = "" if id: - return f"id:{id}" + key = f"id:{id}" 
elif name: key = f"name:{name}" if version: key += f":version:{version}" - return key - else: - raise ValueError("Either the `id` or the `name` must be provided.") + + return key def _get_prompt_cache(self, id: Optional[str] = None, name: Optional[str] = None, version: Optional[int] = None) -> Optional[Prompt]: """Returns the cached prompt, using key in this order: id, name-version, name @@ -1411,7 +1410,7 @@ def get_prompt( if not (id or name): raise ValueError("Either the `id` or the `name` must be provided.") - cached_prompt = self._get_prompt_cache(id, name) + cached_prompt = self._get_prompt_cache(id, name, version) try: if id: From 5d8b5f77e3e0ed98f34eec8514bc4263ce61fd5c Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Thu, 14 Nov 2024 11:10:29 +0100 Subject: [PATCH 04/21] fix: ci --- literalai/api/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 1b96a59..904ee80 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -169,7 +169,7 @@ def __init__( self.graphql_endpoint = self.url + "/api/graphql" self.rest_endpoint = self.url + "/api" - self._prompt_cache = {} + self._prompt_cache: Dict[str, Prompt] = dict() @property def headers(self): @@ -218,6 +218,8 @@ def _get_prompt_cache(self, id: Optional[str] = None, name: Optional[str] = None if key_name in self._prompt_cache: return self._prompt_cache.get(key_name) + return None + def _create_prompt_cache(self, prompt: Prompt): """Creates cache for prompt. 3 entries are created/updated: id, name, name:version """ From e7589c63c8884fb0567fbcdc27234409ff559a7a Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Thu, 14 Nov 2024 12:46:33 +0100 Subject: [PATCH 05/21] feat: add timeout if prompt cached --- literalai/api/__init__.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 904ee80..733807c 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -247,7 +247,7 @@ class LiteralAPI(BaseLiteralAPI): R = TypeVar("R") def make_gql_call( - self, description: str, query: str, variables: Dict[str, Any] + self, description: str, query: str, variables: Dict[str, Any], timeout: Optional[int] = 10 ) -> Dict: def raise_error(error): logger.error(f"Failed to {description}: {error}") @@ -259,7 +259,7 @@ def raise_error(error): self.graphql_endpoint, json={"query": query, "variables": variables}, headers=self.headers, - timeout=10, + timeout=timeout, ) try: @@ -323,8 +323,9 @@ def gql_helper( description: str, variables: Dict, process_response: Callable[..., R], + timeout: Optional[int] = None, ) -> R: - response = self.make_gql_call(description, query, variables) + response = self.make_gql_call(description, query, variables, timeout) return process_response(response) # User API @@ -1413,12 +1414,20 @@ def get_prompt( raise ValueError("Either the `id` or the `name` must be provided.") cached_prompt = self._get_prompt_cache(id, name, version) + timeout = 1 if cached_prompt else None try: if id: - prompt = self.gql_helper(*get_prompt_helper(self, id=id)) + prompt = self.gql_helper( + *get_prompt_helper(self, id=id), timeout=timeout + ) elif name: - prompt = self.gql_helper(*get_prompt_helper(self, name=name, version=version)) + prompt = self.gql_helper( + *get_prompt_helper( + self, name=name, version=version + ), + timeout=timeout + ) self._create_prompt_cache(prompt) return prompt @@ -2660,13 +2669,15 @@ async def 
get_prompt( sync_api = LiteralAPI(self.api_key, self.url) cached_prompt = self._get_prompt_cache(id, name) + timeout = 1 if cached_prompt else None try: if id: - prompt = await self.gql_helper(*get_prompt_helper(sync_api, id=id)) + prompt = await self.gql_helper(*get_prompt_helper(sync_api, id=id), timeout=timeout) elif name: prompt = await self.gql_helper( - *get_prompt_helper(sync_api, name=name, version=version) + *get_prompt_helper(sync_api, name=name, version=version), + timeout=timeout, ) self._create_prompt_cache(prompt) From 723f7fd6d1afdd742d2c1923c07e1de88d95b75e Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Thu, 14 Nov 2024 14:31:49 +0100 Subject: [PATCH 06/21] feat: improve caching --- literalai/api/__init__.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 733807c..fe54d76 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -2,7 +2,6 @@ import os import uuid from typing import ( - TYPE_CHECKING, Any, Callable, Dict, @@ -102,9 +101,6 @@ ) from literalai.prompt_engineering.prompt import Prompt, ProviderSettings -if TYPE_CHECKING: - from typing import Tuple # noqa: F401 - import httpx from literalai.my_types import Environment, PaginatedResponse @@ -170,6 +166,7 @@ def __init__( self.rest_endpoint = self.url + "/api" self._prompt_cache: Dict[str, Prompt] = dict() + self._prompt_storage: Dict[str, Prompt] = dict() @property def headers(self): @@ -206,31 +203,40 @@ def _get_prompt_cache_key( def _get_prompt_cache(self, id: Optional[str] = None, name: Optional[str] = None, version: Optional[int] = None) -> Optional[Prompt]: """Returns the cached prompt, using key in this order: id, name-version, name """ + ref_key = None + key_id = self._get_prompt_cache_key(id=id) if key_id in self._prompt_cache: - return self._prompt_cache.get(key_id) - - key_name_version = self._get_prompt_cache_key(name=name, version=version) - if key_name_version in self._prompt_cache: - return self._prompt_cache.get(key_name_version) + ref_key = self._prompt_cache[key_id] + else: + key_name_version = self._get_prompt_cache_key(name=name, version=version) + if key_name_version in self._prompt_cache: + ref_key = self._prompt_cache[key_name_version] + else: + key_name = self._get_prompt_cache_key(name=name) + if key_name in self._prompt_cache: + ref_key = self._prompt_cache[key_name] - key_name = self._get_prompt_cache_key(name=name) - if key_name in self._prompt_cache: - return self._prompt_cache.get(key_name) + if ref_key and ref_key in self._prompt_storage: + return self._prompt_storage[ref_key] return None def _create_prompt_cache(self, prompt: Prompt): - """Creates cache for prompt. 3 entries are created/updated: id, name, name:version + """Creates cache for prompt. All keys point to the same storage key (prompt.id) + to avoid storing multiple copies of the same prompt. 
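+        For example, a hypothetical prompt with id "p1", name "greet" and version 2
+        is indexed under "id:p1", "name:greet" and "name:greet:version:2", all of
+        which resolve to the single copy stored under "id:p1".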
""" + storage_key = f"id:{prompt.id}" + self._prompt_storage[storage_key] = prompt + key_id = self._get_prompt_cache_key(id=prompt.id) - self._prompt_cache[key_id] = prompt + self._prompt_cache[key_id] = storage_key key_name = self._get_prompt_cache_key(name=prompt.name) - self._prompt_cache[key_name] = prompt + self._prompt_cache[key_name] = storage_key key_name_version = self._get_prompt_cache_key(name=prompt.name, version=prompt.version) - self._prompt_cache[key_name_version] = prompt + self._prompt_cache[key_name_version] = storage_key class LiteralAPI(BaseLiteralAPI): From 32b971f6e2cdb102d0dcd0545071dbdba9a8c969 Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Thu, 14 Nov 2024 15:09:53 +0100 Subject: [PATCH 07/21] feat: improve logging --- literalai/api/__init__.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index fe54d76..6342668 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -165,8 +165,8 @@ def __init__( self.graphql_endpoint = self.url + "/api/graphql" self.rest_endpoint = self.url + "/api" - self._prompt_cache: Dict[str, Prompt] = dict() - self._prompt_storage: Dict[str, Prompt] = dict() + self._prompt_cache: dict[str, str] = dict() + self._prompt_storage: dict[str, Prompt] = dict() @property def headers(self): @@ -1441,7 +1441,6 @@ def get_prompt( except Exception as e: if cached_prompt: logger.warning("Failed to get prompt from API, returning cached prompt") - logger.error(f"Error: {e}") return cached_prompt raise e @@ -2692,7 +2691,6 @@ async def get_prompt( except Exception as e: if cached_prompt: logger.warning("Failed to get prompt from API, returning cached prompt") - logger.error(f"Error: {e}") return cached_prompt raise e From 5bcdce4272ac4ea93874de9d62330007efe4ce2e Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Thu, 14 Nov 2024 15:21:28 +0100 Subject: [PATCH 08/21] fix: ci errors --- literalai/api/__init__.py | 35 +++++++++++++++------------------ literalai/api/prompt_helpers.py | 3 ++- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 6342668..a0d9369 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -253,8 +253,8 @@ class LiteralAPI(BaseLiteralAPI): R = TypeVar("R") def make_gql_call( - self, description: str, query: str, variables: Dict[str, Any], timeout: Optional[int] = 10 - ) -> Dict: + self, description: str, query: str, variables: dict[str, Any], timeout: Optional[int] = 10 + ) -> dict: def raise_error(error): logger.error(f"Failed to {description}: {error}") raise Exception(error) @@ -738,7 +738,7 @@ def upload_file( # Prepare form data form_data = ( {} - ) # type: Dict[str, Union[Tuple[Union[str, None], Any], Tuple[Union[str, None], Any, Any]]] + ) # type: Dict[str, Union[tuple[Union[str, None], Any], tuple[Union[str, None], Any, Any]]] for field_name, field_value in fields.items(): form_data[field_name] = (None, field_value) @@ -1424,16 +1424,9 @@ def get_prompt( try: if id: - prompt = self.gql_helper( - *get_prompt_helper(self, id=id), timeout=timeout - ) + prompt = self.gql_helper(*get_prompt_helper(self, id=id, timeout=timeout)) elif name: - prompt = self.gql_helper( - *get_prompt_helper( - self, name=name, version=version - ), - timeout=timeout - ) + prompt = self.gql_helper(*get_prompt_helper(self, name=name, version=version, timeout=timeout)) self._create_prompt_cache(prompt) return prompt @@ -1527,7 +1520,7 @@ class 
AsyncLiteralAPI(BaseLiteralAPI): R = TypeVar("R") async def make_gql_call( - self, description: str, query: str, variables: Dict[str, Any] + self, description: str, query: str, variables: Dict[str, Any], timeout: Optional[int] = 10 ) -> Dict: def raise_error(error): logger.error(f"Failed to {description}: {error}") @@ -1540,7 +1533,7 @@ def raise_error(error): self.graphql_endpoint, json={"query": query, "variables": variables}, headers=self.headers, - timeout=10, + timeout=timeout, ) try: @@ -1604,8 +1597,9 @@ async def gql_helper( description: str, variables: Dict, process_response: Callable[..., R], + timeout: Optional[int] = 10, ) -> R: - response = await self.make_gql_call(description, query, variables) + response = await self.make_gql_call(description, query, variables, timeout) return process_response(response) async def get_users( @@ -2039,7 +2033,7 @@ async def upload_file( # Prepare form data form_data = ( {} - ) # type: Dict[str, Union[Tuple[Union[str, None], Any], Tuple[Union[str, None], Any, Any]]] + ) # type: dict[str, Union[tuple[Union[str, None], Any], tuple[Union[str, None], Any, Any]]] for field_name, field_value in fields.items(): form_data[field_name] = (None, field_value) @@ -2678,11 +2672,14 @@ async def get_prompt( try: if id: - prompt = await self.gql_helper(*get_prompt_helper(sync_api, id=id), timeout=timeout) + prompt = await self.gql_helper( + *get_prompt_helper(sync_api, id=id, timeout=timeout) + ) elif name: prompt = await self.gql_helper( - *get_prompt_helper(sync_api, name=name, version=version), - timeout=timeout, + *get_prompt_helper( + sync_api, name=name, version=version, timeout=timeout + ) ) self._create_prompt_cache(prompt) diff --git a/literalai/api/prompt_helpers.py b/literalai/api/prompt_helpers.py index 3377f77..5deff89 100644 --- a/literalai/api/prompt_helpers.py +++ b/literalai/api/prompt_helpers.py @@ -61,6 +61,7 @@ def get_prompt_helper( id: Optional[str] = None, name: Optional[str] = None, version: Optional[int] = 0, + timeout: Optional[int] = None, ): variables = {"id": id, "name": name, "version": version} @@ -70,7 +71,7 @@ def process_response(response): description = "get prompt" - return gql.GET_PROMPT_VERSION, description, variables, process_response + return gql.GET_PROMPT_VERSION, description, variables, process_response, timeout def create_prompt_variant_helper( From 2476ebc9a72c6171cd9e3ea03c239eb33cfbb090 Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Fri, 15 Nov 2024 10:55:24 +0100 Subject: [PATCH 09/21] feat: improve the prompt cache class --- literalai/api/__init__.py | 110 +++++++++++++++++++------------------- 1 file changed, 55 insertions(+), 55 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index a0d9369..f314712 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -1,3 +1,4 @@ +from threading import Lock import logging import os import uuid @@ -141,6 +142,59 @@ def handle_bytes(item): return handle_bytes(variables) +class SharedPromptCache: + """ + Thread-safe singleton cache for storing prompts. + Only one instance will exist regardless of how many times it's instantiated. 
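+
+    A minimal usage sketch, assuming `prompt` is an existing Prompt object:
+        cache = SharedPromptCache()
+        cache.put(prompt)            # indexed by id, by name and by (name, version)
+        cache.get(name=prompt.name)  # returns the same object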
+ """ + _instance = None + _lock = Lock() + + def __new__(cls): + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + + cls._instance._prompts: dict[str, Prompt] = {} + cls._instance._name_index: dict[str, str] = {} + cls._instance._name_version_index: dict[tuple[str, int], str] = {} + return cls._instance + + def get( + self, + id: Optional[str] = None, + name: Optional[str] = None, + version: Optional[int] = None + ) -> Optional[Prompt]: + """ + Retrieves a prompt using the most specific criteria provided. + Lookup priority: id, name-version, name + """ + if id: + prompt_id = id + elif name and version: + prompt_id = self._name_version_index.get((name, version)) + elif name: + prompt_id = self._name_index.get(name) + + return self._prompts.get(prompt_id) if prompt_id else None + + def put(self, prompt: Prompt): + with self._lock: + self._prompts[prompt.id] = prompt + self._name_index[prompt.name] = prompt.id + self._name_version_index[(prompt.name, prompt.version)] = prompt.id + + def clear(self) -> None: + """ + Clears all cached promopts and indices. + """ + with self._lock: + self._prompts.clear() + self._name_index.clear() + self._name_version_index.clear() + + class BaseLiteralAPI: def __init__( self, @@ -165,8 +219,7 @@ def __init__( self.graphql_endpoint = self.url + "/api/graphql" self.rest_endpoint = self.url + "/api" - self._prompt_cache: dict[str, str] = dict() - self._prompt_storage: dict[str, Prompt] = dict() + self.prompt_cache = PromptCache() @property def headers(self): @@ -184,59 +237,6 @@ def headers(self): return h - def _get_prompt_cache_key( - self, - id: Optional[str] = None, - name: Optional[str] = None, - version: Optional[int] = None, - ) -> str: - key = "" - if id: - key = f"id:{id}" - elif name: - key = f"name:{name}" - if version: - key += f":version:{version}" - - return key - - def _get_prompt_cache(self, id: Optional[str] = None, name: Optional[str] = None, version: Optional[int] = None) -> Optional[Prompt]: - """Returns the cached prompt, using key in this order: id, name-version, name - """ - ref_key = None - - key_id = self._get_prompt_cache_key(id=id) - if key_id in self._prompt_cache: - ref_key = self._prompt_cache[key_id] - else: - key_name_version = self._get_prompt_cache_key(name=name, version=version) - if key_name_version in self._prompt_cache: - ref_key = self._prompt_cache[key_name_version] - else: - key_name = self._get_prompt_cache_key(name=name) - if key_name in self._prompt_cache: - ref_key = self._prompt_cache[key_name] - - if ref_key and ref_key in self._prompt_storage: - return self._prompt_storage[ref_key] - - return None - - def _create_prompt_cache(self, prompt: Prompt): - """Creates cache for prompt. All keys point to the same storage key (prompt.id) - to avoid storing multiple copies of the same prompt. 
- """ - storage_key = f"id:{prompt.id}" - self._prompt_storage[storage_key] = prompt - - key_id = self._get_prompt_cache_key(id=prompt.id) - self._prompt_cache[key_id] = storage_key - - key_name = self._get_prompt_cache_key(name=prompt.name) - self._prompt_cache[key_name] = storage_key - - key_name_version = self._get_prompt_cache_key(name=prompt.name, version=prompt.version) - self._prompt_cache[key_name_version] = storage_key class LiteralAPI(BaseLiteralAPI): From 0aec701aef663a2d307cb214ed0f68bbfa8ba302 Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Fri, 15 Nov 2024 10:56:39 +0100 Subject: [PATCH 10/21] refactor: remove useless code --- literalai/api/__init__.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index f314712..6efb1bd 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -286,7 +286,7 @@ def raise_error(error): if json.get("data"): if isinstance(json["data"], dict): - for key, value in json["data"].items(): + for value in json["data"].values(): if value and value.get("ok") is False: raise_error( f"""Failed to {description}: { @@ -295,10 +295,6 @@ def raise_error(error): return json - # This should not be reached, exceptions should be thrown beforehands - # Added because of mypy - raise Exception("Unknown error") - def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict: with httpx.Client(follow_redirects=True) as client: response = client.post( @@ -1554,7 +1550,7 @@ def raise_error(error): if json.get("data"): if isinstance(json["data"], dict): - for key, value in json["data"].items(): + for value in json["data"].values(): if value and value.get("ok") is False: raise_error( f"""Failed to {description}: { @@ -1563,10 +1559,6 @@ def raise_error(error): return json - # This should not be reached, exceptions should be thrown beforehands - # Added because of mypy - raise Exception("Unkown error") - async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict: async with httpx.AsyncClient(follow_redirects=True) as client: response = await client.post( From 32b4e4863e29862c0c8f787c3894bd0d2bcd6dde Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Fri, 15 Nov 2024 11:08:42 +0100 Subject: [PATCH 11/21] feat: implement the new SharedCachePrompt class --- literalai/api/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 6efb1bd..2b79a45 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -219,7 +219,7 @@ def __init__( self.graphql_endpoint = self.url + "/api/graphql" self.rest_endpoint = self.url + "/api" - self.prompt_cache = PromptCache() + self.prompt_cache = SharedPromptCache() @property def headers(self): @@ -1415,7 +1415,7 @@ def get_prompt( if not (id or name): raise ValueError("Either the `id` or the `name` must be provided.") - cached_prompt = self._get_prompt_cache(id, name, version) + cached_prompt = self.prompt_cache.get(id, name, version) timeout = 1 if cached_prompt else None try: @@ -1424,7 +1424,7 @@ def get_prompt( elif name: prompt = self.gql_helper(*get_prompt_helper(self, name=name, version=version, timeout=timeout)) - self._create_prompt_cache(prompt) + self.prompt_cache.put(prompt) return prompt except Exception as e: @@ -2659,7 +2659,7 @@ async def get_prompt( raise ValueError("Either the `id` or the `name` must be provided.") sync_api = LiteralAPI(self.api_key, self.url) - cached_prompt = 
self._get_prompt_cache(id, name) + cached_prompt = self.prompt_cache.get(id, name, version) timeout = 1 if cached_prompt else None try: @@ -2674,7 +2674,7 @@ async def get_prompt( ) ) - self._create_prompt_cache(prompt) + self.prompt_cache.put(prompt) return prompt except Exception as e: From f5d460b8f1a4ce1dee3b5abe442a593ae84d7cbd Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Fri, 15 Nov 2024 12:42:40 +0100 Subject: [PATCH 12/21] refactor: improve typing and move some logic --- literalai/api/prompt_helpers.py | 37 ++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/literalai/api/prompt_helpers.py b/literalai/api/prompt_helpers.py index 5deff89..82409ea 100644 --- a/literalai/api/prompt_helpers.py +++ b/literalai/api/prompt_helpers.py @@ -1,10 +1,11 @@ -from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict +from typing import TYPE_CHECKING, Optional, TypedDict, Callable from literalai.observability.generation import GenerationMessage from literalai.prompt_engineering.prompt import Prompt, ProviderSettings if TYPE_CHECKING: from literalai.api import LiteralAPI + from literalai.api import SharedPromptCache from literalai.api import gql @@ -36,9 +37,9 @@ def process_response(response): def create_prompt_helper( api: "LiteralAPI", lineage_id: str, - template_messages: List[GenerationMessage], + template_messages: list[GenerationMessage], settings: Optional[ProviderSettings] = None, - tools: Optional[List[Dict]] = None, + tools: Optional[list[dict]] = None, ): variables = { "lineageId": lineage_id, @@ -62,23 +63,35 @@ def get_prompt_helper( name: Optional[str] = None, version: Optional[int] = 0, timeout: Optional[int] = None, -): + prompt_cache: Optional[SharedPromptCache] = None, +) -> tuple[str, str, dict, Callable]: + """Helper function for getting prompts with caching logic""" + if not (id or name): + raise ValueError("Either the `id` or the `name` must be provided.") + + cached_prompt = None + if prompt_cache: + cached_prompt = prompt_cache.get(id, name, version) + timeout = 1 if cached_prompt else timeout + variables = {"id": id, "name": name, "version": version} def process_response(response): - prompt = response["data"]["promptVersion"] - return Prompt.from_dict(api, prompt) if prompt else None + prompt = Prompt.from_dict(api, response["data"]["prompt"]) + if prompt_cache: + prompt_cache.put(prompt) + return prompt description = "get prompt" - return gql.GET_PROMPT_VERSION, description, variables, process_response, timeout + return gql.GET_PROMPT_VERSION, description, variables, process_response, cached_prompt def create_prompt_variant_helper( from_lineage_id: Optional[str] = None, - template_messages: List[GenerationMessage] = [], + template_messages: list[GenerationMessage] = [], settings: Optional[ProviderSettings] = None, - tools: Optional[List[Dict]] = None, + tools: Optional[list[dict]] = None, ): variables = { "fromLineageId": from_lineage_id, @@ -106,7 +119,7 @@ def get_prompt_ab_testing_helper( ): variables = {"lineageName": name} - def process_response(response) -> List[PromptRollout]: + def process_response(response) -> list[PromptRollout]: response_data = response["data"]["promptLineageRollout"] return list(map(lambda x: x["node"], response_data["edges"])) @@ -115,10 +128,10 @@ def process_response(response) -> List[PromptRollout]: return gql.GET_PROMPT_AB_TESTING, description, variables, process_response -def update_prompt_ab_testing_helper(name: str, rollouts: List[PromptRollout]): +def 
update_prompt_ab_testing_helper(name: str, rollouts: list[PromptRollout]): variables = {"name": name, "rollouts": rollouts} - def process_response(response) -> Dict: + def process_response(response) -> dict: return response["data"]["updatePromptLineageRollout"] description = "update prompt A/B testing" From 49fd1404132c461dec06485af0e8c9a96d9e5278 Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Mon, 18 Nov 2024 12:13:43 +0100 Subject: [PATCH 13/21] feat: adds memory management to the SharedCachePrompt class --- literalai/api/__init__.py | 49 +++++++++++++++++++++++---------- literalai/api/prompt_helpers.py | 13 +++++---- 2 files changed, 42 insertions(+), 20 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 2b79a45..6a6fcee 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -1,6 +1,7 @@ -from threading import Lock import logging +import time import os +from threading import Lock import uuid from typing import ( Any, @@ -144,17 +145,19 @@ def handle_bytes(item): class SharedPromptCache: """ - Thread-safe singleton cache for storing prompts. + Thread-safe singleton cache for storing prompts with memory leak prevention. Only one instance will exist regardless of how many times it's instantiated. + Implements LRU eviction policy when cache reaches maximum size. """ _instance = None _lock = Lock() - def __new__(cls): + def __new__(cls, max_size: int = 1000): with cls._lock: if cls._instance is None: cls._instance = super().__new__(cls) + cls._instance._max_size = max_size cls._instance._prompts: dict[str, Prompt] = {} cls._instance._name_index: dict[str, str] = {} cls._instance._name_version_index: dict[tuple[str, int], str] = {} @@ -168,18 +171,36 @@ def get( ) -> Optional[Prompt]: """ Retrieves a prompt using the most specific criteria provided. + Updates access time for LRU tracking. Lookup priority: id, name-version, name """ + if id and not isinstance(id, str): + raise TypeError("Expected a string for id") + if name and not isinstance(name, str): + raise TypeError("Expected a string for name") + if version and not isinstance(version, int): + raise TypeError("Expected an integer for version") + if id: prompt_id = id elif name and version: prompt_id = self._name_version_index.get((name, version)) elif name: prompt_id = self._name_index.get(name) + else: + return None - return self._prompts.get(prompt_id) if prompt_id else None + if prompt_id and prompt_id in self._prompts: + return self._prompts.get(prompt_id) + return None def put(self, prompt: Prompt): + """ + Stores a prompt in the cache, managing size limits with LRU eviction. 
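+        Re-putting a prompt with an existing id replaces the stored object, and a
+        lookup by name alone resolves to the most recently stored prompt of that name.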
+ """ + if not isinstance(prompt, Prompt): + raise TypeError("Expected a Prompt object") + with self._lock: self._prompts[prompt.id] = prompt self._name_index[prompt.name] = prompt.id @@ -1415,14 +1436,15 @@ def get_prompt( if not (id or name): raise ValueError("Either the `id` or the `name` must be provided.") - cached_prompt = self.prompt_cache.get(id, name, version) - timeout = 1 if cached_prompt else None + get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper( + api=self,id=id, name=name, version=version, prompt_cache=self.prompt_cache + ) try: if id: - prompt = self.gql_helper(*get_prompt_helper(self, id=id, timeout=timeout)) + prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout) elif name: - prompt = self.gql_helper(*get_prompt_helper(self, name=name, version=version, timeout=timeout)) + prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout) self.prompt_cache.put(prompt) return prompt @@ -2659,19 +2681,18 @@ async def get_prompt( raise ValueError("Either the `id` or the `name` must be provided.") sync_api = LiteralAPI(self.api_key, self.url) - cached_prompt = self.prompt_cache.get(id, name, version) - timeout = 1 if cached_prompt else None + get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper( + api=sync_api, id=id, name=name, version=version, prompt_cache=self.prompt_cache + ) try: if id: prompt = await self.gql_helper( - *get_prompt_helper(sync_api, id=id, timeout=timeout) + get_prompt_query, description, variables, process_response, timeout ) elif name: prompt = await self.gql_helper( - *get_prompt_helper( - sync_api, name=name, version=version, timeout=timeout - ) + get_prompt_query, description, variables, process_response, timeout ) self.prompt_cache.put(prompt) diff --git a/literalai/api/prompt_helpers.py b/literalai/api/prompt_helpers.py index 82409ea..68e713a 100644 --- a/literalai/api/prompt_helpers.py +++ b/literalai/api/prompt_helpers.py @@ -5,7 +5,6 @@ if TYPE_CHECKING: from literalai.api import LiteralAPI - from literalai.api import SharedPromptCache from literalai.api import gql @@ -62,14 +61,15 @@ def get_prompt_helper( id: Optional[str] = None, name: Optional[str] = None, version: Optional[int] = 0, - timeout: Optional[int] = None, - prompt_cache: Optional[SharedPromptCache] = None, -) -> tuple[str, str, dict, Callable]: + prompt_cache: Optional["SharedPromptCache"] = None, +) -> tuple[str, str, dict, Callable, int, Optional[Prompt]]: """Helper function for getting prompts with caching logic""" if not (id or name): raise ValueError("Either the `id` or the `name` must be provided.") cached_prompt = None + timeout = 10 + if prompt_cache: cached_prompt = prompt_cache.get(id, name, version) timeout = 1 if cached_prompt else timeout @@ -77,14 +77,15 @@ def get_prompt_helper( variables = {"id": id, "name": name, "version": version} def process_response(response): - prompt = Prompt.from_dict(api, response["data"]["prompt"]) + prompt_version = response["data"]["promptVersion"] + prompt = Prompt.from_dict(api, prompt_version) if prompt_version else None if prompt_cache: prompt_cache.put(prompt) return prompt description = "get prompt" - return gql.GET_PROMPT_VERSION, description, variables, process_response, cached_prompt + return gql.GET_PROMPT_VERSION, description, variables, process_response, timeout, cached_prompt def create_prompt_variant_helper( From 3e139f2fe10864e46e9739547e4efe75564087fd 
Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Mon, 18 Nov 2024 13:36:00 +0100 Subject: [PATCH 14/21] feat: add unit tests for SharedCachePrompt --- tests/unit/__init__.py | 0 tests/unit/test_cache.py | 185 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 185 insertions(+) create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_cache.py diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_cache.py b/tests/unit/test_cache.py new file mode 100644 index 0000000..e6603c4 --- /dev/null +++ b/tests/unit/test_cache.py @@ -0,0 +1,185 @@ +import pytest +from threading import Thread +import time +import random + +from literalai.prompt_engineering.prompt import Prompt +from literalai.api import SharedPromptCache + +def default_prompt(id: str = "1", name: str = "test", version: int = 1) -> Prompt: + return Prompt( + api=None, + id=id, + name=name, + version=version, + created_at="", + updated_at="", + type="chat", + url="", + version_desc=None, + template_messages=[], + tools=None, + provider="", + settings={}, + variables=[], + variables_default_values=None + ) + +def test_singleton_instance(): + """Test that SharedPromptCache maintains singleton pattern""" + cache1 = SharedPromptCache() + cache2 = SharedPromptCache() + assert cache1 is cache2 + +def test_get_empty_cache(): + """Test getting from empty cache returns None""" + cache = SharedPromptCache() + cache.clear() # Ensure clean state + + assert cache._prompts == {} + assert cache._name_index == {} + assert cache._name_version_index == {} + +def test_put_and_get_by_id(): + """Test storing and retrieving prompt by ID""" + cache = SharedPromptCache() + cache.clear() + + prompt = default_prompt() + cache.put(prompt) + + retrieved = cache.get(id="1") + assert retrieved is prompt + assert retrieved.id == "1" + assert retrieved.name == "test" + assert retrieved.version == 1 + +def test_put_and_get_by_name(): + """Test storing and retrieving prompt by name""" + cache = SharedPromptCache() + cache.clear() + + prompt = default_prompt() + cache.put(prompt) + + retrieved = cache.get(name="test") + assert retrieved is prompt + assert retrieved.name == "test" + +def test_put_and_get_by_name_version(): + """Test storing and retrieving prompt by name and version""" + cache = SharedPromptCache() + cache.clear() + + prompt = default_prompt() + cache.put(prompt) + + retrieved = cache.get(name="test", version=1) + assert retrieved is prompt + assert retrieved.name == "test" + assert retrieved.version == 1 + +def test_multiple_versions(): + """Test handling multiple versions of the same prompt""" + cache = SharedPromptCache() + cache.clear() + + prompt1 = default_prompt() + prompt2 = default_prompt(id="2", version=2) + + cache.put(prompt1) + cache.put(prompt2) + + # Get specific versions + assert cache.get(name="test", version=1) is prompt1 + assert cache.get(name="test", version=2) is prompt2 + + # Get by name should return latest version + assert cache.get(name="test") is prompt2 # Returns the last indexed version + +def test_clear_cache(): + """Test clearing the cache""" + cache = SharedPromptCache() + prompt = default_prompt() + cache.put(prompt) + + cache.clear() + assert cache._prompts == {} + assert cache._name_index == {} + assert cache._name_version_index == {} + +def test_update_existing_prompt(): + """Test updating an existing prompt""" + cache = SharedPromptCache() + cache.clear() + + prompt1 = default_prompt() + prompt2 = 
default_prompt(id="1", version=2) # Same ID, different version + + cache.put(prompt1) + cache.put(prompt2) + + retrieved = cache.get(id="1") + assert retrieved is prompt2 + assert retrieved.version == 2 + +def test_lookup_priority(): + """Test that lookup priority is id > name-version > name""" + cache = SharedPromptCache() + cache.clear() + + prompt1 = default_prompt() + prompt2 = default_prompt(id="2", name="test", version=2) + + cache.put(prompt1) + cache.put(prompt2) + + # ID should take precedence + assert cache.get(id="1", name="test", version=2) is prompt1 + + # Name-version should take precedence over name + assert cache.get(name="test", version=2) is prompt2 + +def test_thread_safety(): + """Test thread safety of the cache""" + cache = SharedPromptCache() + cache.clear() + + def worker(worker_id: int): + for i in range(100): + prompt = default_prompt( + id=f"{worker_id}-{i}", + name=f"test-{worker_id}", + version=i + ) + cache.put(prompt) + time.sleep(random.uniform(0, 0.001)) + + retrieved = cache.get(id=prompt.id) + assert retrieved is prompt + + threads = [Thread(target=worker, args=(i,)) for i in range(10)] + + for t in threads: + t.start() + for t in threads: + t.join() + + for worker_id in range(10): + for i in range(100): + prompt_id = f"{worker_id}-{i}" + assert cache.get(id=prompt_id) is not None + +def test_error_handling(): + """Test error handling for invalid inputs""" + cache = SharedPromptCache() + cache.clear() + + assert cache.get() is None + assert cache.get(id=None, name=None, version=None) is None + + with pytest.raises(TypeError): + cache.get(version="invalid") # type: ignore + + with pytest.raises(TypeError): + cache.put("not a prompt") # type: ignore \ No newline at end of file From 3730581a865895103011d20d4608d42d808f879a Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Mon, 18 Nov 2024 14:25:18 +0100 Subject: [PATCH 15/21] feat: adds tests and updates run-test.sh --- literalai/api/__init__.py | 21 ++++++++++----------- literalai/api/prompt_helpers.py | 1 + run-test.sh | 2 +- tests/e2e/test_e2e.py | 14 ++++++++++++++ tests/unit/test_cache.py | 16 ++++++---------- 5 files changed, 32 insertions(+), 22 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 6a6fcee..0451cf5 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -1,5 +1,4 @@ import logging -import time import os from threading import Lock import uuid @@ -147,20 +146,21 @@ class SharedPromptCache: """ Thread-safe singleton cache for storing prompts with memory leak prevention. Only one instance will exist regardless of how many times it's instantiated. - Implements LRU eviction policy when cache reaches maximum size. """ _instance = None _lock = Lock() + _prompts: dict[str, Prompt] + _name_index: dict[str, str] + _name_version_index: dict[tuple[str, int], str] - def __new__(cls, max_size: int = 1000): + def __new__(cls): with cls._lock: if cls._instance is None: cls._instance = super().__new__(cls) - cls._instance._max_size = max_size - cls._instance._prompts: dict[str, Prompt] = {} - cls._instance._name_index: dict[str, str] = {} - cls._instance._name_version_index: dict[tuple[str, int], str] = {} + cls._instance._prompts = {} + cls._instance._name_index = {} + cls._instance._name_version_index = {} return cls._instance def get( @@ -171,7 +171,6 @@ def get( ) -> Optional[Prompt]: """ Retrieves a prompt using the most specific criteria provided. - Updates access time for LRU tracking. 
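+        For example, a hypothetical call get(id="p1", name="greet") ignores the
+        name and looks the prompt up by "p1" directly.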
Lookup priority: id, name-version, name """ if id and not isinstance(id, str): @@ -184,9 +183,9 @@ def get( if id: prompt_id = id elif name and version: - prompt_id = self._name_version_index.get((name, version)) + prompt_id = self._name_version_index.get((name, version)) or "" elif name: - prompt_id = self._name_index.get(name) + prompt_id = self._name_index.get(name) or "" else: return None @@ -196,7 +195,7 @@ def get( def put(self, prompt: Prompt): """ - Stores a prompt in the cache, managing size limits with LRU eviction. + Stores a prompt in the cache. """ if not isinstance(prompt, Prompt): raise TypeError("Expected a Prompt object") diff --git a/literalai/api/prompt_helpers.py b/literalai/api/prompt_helpers.py index 68e713a..528cdd2 100644 --- a/literalai/api/prompt_helpers.py +++ b/literalai/api/prompt_helpers.py @@ -5,6 +5,7 @@ if TYPE_CHECKING: from literalai.api import LiteralAPI + from literalai.api import SharedPromptCache from literalai.api import gql diff --git a/run-test.sh b/run-test.sh index 2a51465..b256ce3 100755 --- a/run-test.sh +++ b/run-test.sh @@ -1 +1 @@ -LITERAL_API_URL=http://localhost:3000 LITERAL_API_KEY=my-initial-api-key pytest -m e2e -s -v +LITERAL_API_URL=http://localhost:3000 LITERAL_API_KEY=my-initial-api-key pytest -m e2e -s -v tests/e2e/ tests/unit/ diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 3b58199..426670b 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -662,6 +662,20 @@ async def test_prompt(self, async_client: AsyncLiteralClient): assert messages[0]["content"] == expected + @pytest.mark.timeout(5) + async def test_prompt_cache(self, async_client: AsyncLiteralClient): + prompt = await async_client.api.get_prompt(name="Default", version=0) + assert prompt is not None + + original_key = async_client.api.api_key + async_client.api.api_key = "invalid-api-key" + + cached_prompt = await async_client.api.get_prompt(name="Default", version=0) + assert cached_prompt is not None + assert cached_prompt.id == prompt.id + + async_client.api.api_key = original_key + @pytest.mark.timeout(5) async def test_prompt_ab_testing(self, client: LiteralClient): prompt_name = "Python SDK E2E Tests" diff --git a/tests/unit/test_cache.py b/tests/unit/test_cache.py index e6603c4..9e32ba0 100644 --- a/tests/unit/test_cache.py +++ b/tests/unit/test_cache.py @@ -4,17 +4,17 @@ import random from literalai.prompt_engineering.prompt import Prompt -from literalai.api import SharedPromptCache +from literalai.api import SharedPromptCache, LiteralAPI def default_prompt(id: str = "1", name: str = "test", version: int = 1) -> Prompt: return Prompt( - api=None, + api=LiteralAPI(), id=id, name=name, version=version, created_at="", updated_at="", - type="chat", + type="chat", # type: ignore url="", version_desc=None, template_messages=[], @@ -34,7 +34,7 @@ def test_singleton_instance(): def test_get_empty_cache(): """Test getting from empty cache returns None""" cache = SharedPromptCache() - cache.clear() # Ensure clean state + cache.clear() assert cache._prompts == {} assert cache._name_index == {} @@ -90,12 +90,10 @@ def test_multiple_versions(): cache.put(prompt1) cache.put(prompt2) - # Get specific versions assert cache.get(name="test", version=1) is prompt1 assert cache.get(name="test", version=2) is prompt2 - # Get by name should return latest version - assert cache.get(name="test") is prompt2 # Returns the last indexed version + assert cache.get(name="test") is prompt2 def test_clear_cache(): """Test clearing the cache""" @@ -114,7 +112,7 @@ def 
test_update_existing_prompt(): cache.clear() prompt1 = default_prompt() - prompt2 = default_prompt(id="1", version=2) # Same ID, different version + prompt2 = default_prompt(id="1", version=2) cache.put(prompt1) cache.put(prompt2) @@ -134,10 +132,8 @@ def test_lookup_priority(): cache.put(prompt1) cache.put(prompt2) - # ID should take precedence assert cache.get(id="1", name="test", version=2) is prompt1 - # Name-version should take precedence over name assert cache.get(name="test", version=2) is prompt2 def test_thread_safety(): From 53187517e373028c233f089c058898c4dd917a4a Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Tue, 19 Nov 2024 17:55:03 +0100 Subject: [PATCH 16/21] refactor: finishes the simplication --- literalai/api/__init__.py | 76 +++++++++------------------------ literalai/api/prompt_helpers.py | 15 +++++-- 2 files changed, 32 insertions(+), 59 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 0451cf5..96cd650 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -1,6 +1,5 @@ import logging import os -from threading import Lock import uuid from typing import ( Any, @@ -142,77 +141,40 @@ def handle_bytes(item): return handle_bytes(variables) -class SharedPromptCache: +class SharedCache: """ Thread-safe singleton cache for storing prompts with memory leak prevention. Only one instance will exist regardless of how many times it's instantiated. """ _instance = None - _lock = Lock() - _prompts: dict[str, Prompt] - _name_index: dict[str, str] - _name_version_index: dict[tuple[str, int], str] + _cache: dict[str, Any] def __new__(cls): - with cls._lock: - if cls._instance is None: - cls._instance = super().__new__(cls) - - cls._instance._prompts = {} - cls._instance._name_index = {} - cls._instance._name_version_index = {} + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.cache = {} return cls._instance - def get( - self, - id: Optional[str] = None, - name: Optional[str] = None, - version: Optional[int] = None - ) -> Optional[Prompt]: + def get_cache(self) -> dict[str, Any]: + return self._cache + + def get(self, key: str) -> Optional[Any]: """ - Retrieves a prompt using the most specific criteria provided. - Lookup priority: id, name-version, name + Retrieves a value from the cache using the provided key. """ - if id and not isinstance(id, str): - raise TypeError("Expected a string for id") - if name and not isinstance(name, str): - raise TypeError("Expected a string for name") - if version and not isinstance(version, int): - raise TypeError("Expected an integer for version") + return self._cache.get(key) - if id: - prompt_id = id - elif name and version: - prompt_id = self._name_version_index.get((name, version)) or "" - elif name: - prompt_id = self._name_index.get(name) or "" - else: - return None - - if prompt_id and prompt_id in self._prompts: - return self._prompts.get(prompt_id) - return None - - def put(self, prompt: Prompt): + def put(self, key: str, value: Any): """ - Stores a prompt in the cache. + Stores a value in the cache. """ - if not isinstance(prompt, Prompt): - raise TypeError("Expected a Prompt object") - - with self._lock: - self._prompts[prompt.id] = prompt - self._name_index[prompt.name] = prompt.id - self._name_version_index[(prompt.name, prompt.version)] = prompt.id + self._cache[key] = value def clear(self) -> None: """ - Clears all cached promopts and indices. + Clears all cached values. 
""" - with self._lock: - self._prompts.clear() - self._name_index.clear() - self._name_version_index.clear() + self._cache.clear() class BaseLiteralAPI: @@ -239,7 +201,7 @@ def __init__( self.graphql_endpoint = self.url + "/api/graphql" self.rest_endpoint = self.url + "/api" - self.prompt_cache = SharedPromptCache() + self.cache = SharedCache() @property def headers(self): @@ -1445,7 +1407,9 @@ def get_prompt( elif name: prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout) - self.prompt_cache.put(prompt) + self.cache.put(prompt.id, prompt) + self.cache.put(prompt.name, prompt) + self.cache.put(f"{prompt.name}-{prompt.version}", prompt) return prompt except Exception as e: diff --git a/literalai/api/prompt_helpers.py b/literalai/api/prompt_helpers.py index 528cdd2..e45282e 100644 --- a/literalai/api/prompt_helpers.py +++ b/literalai/api/prompt_helpers.py @@ -57,6 +57,17 @@ def process_response(response): return gql.CREATE_PROMPT_VERSION, description, variables, process_response +def get_prompt_cache_key(id: Optional[str], name: Optional[str], version: Optional[int]) -> str: + if id: + return id + elif name and version: + return f"{name}-{version}" + elif name: + return name + else: + raise ValueError("Either the `id` or the `name` must be provided.") + + def get_prompt_helper( api: "LiteralAPI", id: Optional[str] = None, @@ -65,14 +76,12 @@ def get_prompt_helper( prompt_cache: Optional["SharedPromptCache"] = None, ) -> tuple[str, str, dict, Callable, int, Optional[Prompt]]: """Helper function for getting prompts with caching logic""" - if not (id or name): - raise ValueError("Either the `id` or the `name` must be provided.") cached_prompt = None timeout = 10 if prompt_cache: - cached_prompt = prompt_cache.get(id, name, version) + cached_prompt = prompt_cache.get(get_prompt_cache_key(id, name, version)) timeout = 1 if cached_prompt else timeout variables = {"id": id, "name": name, "version": version} From 85c72d108b83f14d247202a9a414d3c45c13f232 Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Wed, 20 Nov 2024 10:50:58 +0100 Subject: [PATCH 17/21] fix: test and implementation --- literalai/api/__init__.py | 22 +++-- literalai/api/prompt_helpers.py | 12 +-- tests/unit/test_cache.py | 144 ++++++-------------------------- 3 files changed, 46 insertions(+), 132 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 96cd650..5d81ea4 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -143,16 +143,15 @@ def handle_bytes(item): class SharedCache: """ - Thread-safe singleton cache for storing prompts with memory leak prevention. + Singleton cache for storing data. Only one instance will exist regardless of how many times it's instantiated. """ _instance = None - _cache: dict[str, Any] def __new__(cls): if cls._instance is None: cls._instance = super().__new__(cls) - cls._instance.cache = {} + cls._instance._cache = {} return cls._instance def get_cache(self) -> dict[str, Any]: @@ -162,14 +161,23 @@ def get(self, key: str) -> Optional[Any]: """ Retrieves a value from the cache using the provided key. """ + if not isinstance(key, str): + raise TypeError("Key must be a string") return self._cache.get(key) def put(self, key: str, value: Any): """ Stores a value in the cache. 
""" + if not isinstance(key, str): + raise TypeError("Key must be a string") self._cache[key] = value + def put_prompt(self, prompt: Prompt): + self.put(prompt.id, prompt) + self.put(prompt.name, prompt) + self.put(f"{prompt.name}-{prompt.version}", prompt) + def clear(self) -> None: """ Clears all cached values. @@ -1398,7 +1406,7 @@ def get_prompt( raise ValueError("Either the `id` or the `name` must be provided.") get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper( - api=self,id=id, name=name, version=version, prompt_cache=self.prompt_cache + api=self,id=id, name=name, version=version, cache=self.cache ) try: @@ -1407,9 +1415,6 @@ def get_prompt( elif name: prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout) - self.cache.put(prompt.id, prompt) - self.cache.put(prompt.name, prompt) - self.cache.put(f"{prompt.name}-{prompt.version}", prompt) return prompt except Exception as e: @@ -2645,7 +2650,7 @@ async def get_prompt( sync_api = LiteralAPI(self.api_key, self.url) get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper( - api=sync_api, id=id, name=name, version=version, prompt_cache=self.prompt_cache + api=sync_api, id=id, name=name, version=version, cache=self.cache ) try: @@ -2658,7 +2663,6 @@ async def get_prompt( get_prompt_query, description, variables, process_response, timeout ) - self.prompt_cache.put(prompt) return prompt except Exception as e: diff --git a/literalai/api/prompt_helpers.py b/literalai/api/prompt_helpers.py index e45282e..c35c691 100644 --- a/literalai/api/prompt_helpers.py +++ b/literalai/api/prompt_helpers.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: from literalai.api import LiteralAPI - from literalai.api import SharedPromptCache + from literalai.api import SharedCache from literalai.api import gql @@ -73,15 +73,15 @@ def get_prompt_helper( id: Optional[str] = None, name: Optional[str] = None, version: Optional[int] = 0, - prompt_cache: Optional["SharedPromptCache"] = None, + cache: Optional["SharedCache"] = None, ) -> tuple[str, str, dict, Callable, int, Optional[Prompt]]: """Helper function for getting prompts with caching logic""" cached_prompt = None timeout = 10 - if prompt_cache: - cached_prompt = prompt_cache.get(get_prompt_cache_key(id, name, version)) + if cache: + cached_prompt = cache.get(get_prompt_cache_key(id, name, version)) timeout = 1 if cached_prompt else timeout variables = {"id": id, "name": name, "version": version} @@ -89,8 +89,8 @@ def get_prompt_helper( def process_response(response): prompt_version = response["data"]["promptVersion"] prompt = Prompt.from_dict(api, prompt_version) if prompt_version else None - if prompt_cache: - prompt_cache.put(prompt) + if cache and prompt: + cache.put_prompt(prompt) return prompt description = "get prompt" diff --git a/tests/unit/test_cache.py b/tests/unit/test_cache.py index 9e32ba0..c18ffb4 100644 --- a/tests/unit/test_cache.py +++ b/tests/unit/test_cache.py @@ -1,10 +1,7 @@ import pytest -from threading import Thread -import time -import random from literalai.prompt_engineering.prompt import Prompt -from literalai.api import SharedPromptCache, LiteralAPI +from literalai.api import SharedCache, LiteralAPI def default_prompt(id: str = "1", name: str = "test", version: int = 1) -> Prompt: return Prompt( @@ -26,156 +23,69 @@ def default_prompt(id: str = "1", name: str = "test", version: int = 1) -> Promp ) def test_singleton_instance(): - """Test that SharedPromptCache 
maintains singleton pattern""" - cache1 = SharedPromptCache() - cache2 = SharedPromptCache() + """Test that SharedCache maintains singleton pattern""" + cache1 = SharedCache() + cache2 = SharedCache() assert cache1 is cache2 def test_get_empty_cache(): """Test getting from empty cache returns None""" - cache = SharedPromptCache() + cache = SharedCache() cache.clear() - assert cache._prompts == {} - assert cache._name_index == {} - assert cache._name_version_index == {} + assert cache.get_cache() == {} -def test_put_and_get_by_id(): - """Test storing and retrieving prompt by ID""" - cache = SharedPromptCache() +def test_put_and_get_prompt_by_id_by_name_version_by_name(): + """Test storing and retrieving prompt by ID by name-version by name""" + cache = SharedCache() cache.clear() prompt = default_prompt() - cache.put(prompt) + cache.put_prompt(prompt) - retrieved = cache.get(id="1") - assert retrieved is prompt - assert retrieved.id == "1" - assert retrieved.name == "test" - assert retrieved.version == 1 - -def test_put_and_get_by_name(): - """Test storing and retrieving prompt by name""" - cache = SharedPromptCache() - cache.clear() + retrieved_by_id = cache.get(id="1") + assert retrieved_by_id is prompt - prompt = default_prompt() - cache.put(prompt) + retrieved_by_name_version = cache.get(name="test", version=1) + assert retrieved_by_name_version is prompt - retrieved = cache.get(name="test") - assert retrieved is prompt - assert retrieved.name == "test" - -def test_put_and_get_by_name_version(): - """Test storing and retrieving prompt by name and version""" - cache = SharedPromptCache() - cache.clear() - - prompt = default_prompt() - cache.put(prompt) - - retrieved = cache.get(name="test", version=1) - assert retrieved is prompt - assert retrieved.name == "test" - assert retrieved.version == 1 - -def test_multiple_versions(): - """Test handling multiple versions of the same prompt""" - cache = SharedPromptCache() - cache.clear() - - prompt1 = default_prompt() - prompt2 = default_prompt(id="2", version=2) - - cache.put(prompt1) - cache.put(prompt2) - - assert cache.get(name="test", version=1) is prompt1 - assert cache.get(name="test", version=2) is prompt2 - - assert cache.get(name="test") is prompt2 + retrieved_by_name = cache.get(name="test") + assert retrieved_by_name is prompt def test_clear_cache(): """Test clearing the cache""" - cache = SharedPromptCache() + cache = SharedCache() prompt = default_prompt() - cache.put(prompt) + cache.put_prompt(prompt) cache.clear() - assert cache._prompts == {} - assert cache._name_index == {} - assert cache._name_version_index == {} + assert cache.get_cache() == {} def test_update_existing_prompt(): """Test updating an existing prompt""" - cache = SharedPromptCache() + cache = SharedCache() cache.clear() prompt1 = default_prompt() prompt2 = default_prompt(id="1", version=2) - cache.put(prompt1) - cache.put(prompt2) + cache.put_prompt(prompt1) + cache.put_prompt(prompt2) retrieved = cache.get(id="1") assert retrieved is prompt2 assert retrieved.version == 2 -def test_lookup_priority(): - """Test that lookup priority is id > name-version > name""" - cache = SharedPromptCache() - cache.clear() - - prompt1 = default_prompt() - prompt2 = default_prompt(id="2", name="test", version=2) - - cache.put(prompt1) - cache.put(prompt2) - - assert cache.get(id="1", name="test", version=2) is prompt1 - - assert cache.get(name="test", version=2) is prompt2 - -def test_thread_safety(): - """Test thread safety of the cache""" - cache = SharedPromptCache() - 
cache.clear()
-
-    def worker(worker_id: int):
-        for i in range(100):
-            prompt = default_prompt(
-                id=f"{worker_id}-{i}",
-                name=f"test-{worker_id}",
-                version=i
-            )
-            cache.put(prompt)
-            time.sleep(random.uniform(0, 0.001))
-
-            retrieved = cache.get(id=prompt.id)
-            assert retrieved is prompt
-
-    threads = [Thread(target=worker, args=(i,)) for i in range(10)]
-
-    for t in threads:
-        t.start()
-    for t in threads:
-        t.join()
-
-    for worker_id in range(10):
-        for i in range(100):
-            prompt_id = f"{worker_id}-{i}"
-            assert cache.get(id=prompt_id) is not None
-
 def test_error_handling():
     """Test error handling for invalid inputs"""
-    cache = SharedPromptCache()
+    cache = SharedCache()
     cache.clear()
 
-    assert cache.get() is None
-    assert cache.get(id=None, name=None, version=None) is None
+    assert cache.get_cache() == {}
+    assert cache.get(key="") is None
 
     with pytest.raises(TypeError):
-        cache.get(version="invalid")  # type: ignore
+        cache.get(5)  # type: ignore
 
     with pytest.raises(TypeError):
-        cache.put("not a prompt")  # type: ignore
+        cache.put(5, "test")  # type: ignore
\ No newline at end of file

From 6dfce9c0f8a077bd0f63460cd3a826e2703da745 Mon Sep 17 00:00:00 2001
From: Matthieu Olenga
Date: Wed, 20 Nov 2024 10:58:30 +0100
Subject: [PATCH 18/21] fix: add typing for SharedCache._cache

---
 literalai/api/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py
index 5d81ea4..0b228e2 100644
--- a/literalai/api/__init__.py
+++ b/literalai/api/__init__.py
@@ -147,6 +147,7 @@ class SharedCache:
     Only one instance will exist regardless of how many times it's instantiated.
     """
     _instance = None
+    _cache: dict[str, Any]
 
     def __new__(cls):
         if cls._instance is None:

From 06c5047b1867a4cfdaf92d83ee59a57ef1b3e677 Mon Sep 17 00:00:00 2001
From: Matthieu Olenga
Date: Thu, 21 Nov 2024 16:58:58 +0100
Subject: [PATCH 19/21] feat: align with literalai-typescript changes

---
 literalai/api/__init__.py         | 46 +--------------------------------
 literalai/api/prompt_helpers.py   |  4 ++-
 literalai/cache/__init__.py       |  0
 literalai/cache/prompt_helpers.py |  8 ++++++
 literalai/cache/shared_cache.py   | 42 ++++++++++++++++++++++++++++
 tests/unit/test_cache.py          |  8 ++++--
 6 files changed, 59 insertions(+), 49 deletions(-)
 create mode 100644 literalai/cache/__init__.py
 create mode 100644 literalai/cache/prompt_helpers.py
 create mode 100644 literalai/cache/shared_cache.py

diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py
index 0b228e2..13d015a 100644
--- a/literalai/api/__init__.py
+++ b/literalai/api/__init__.py
@@ -103,6 +103,7 @@
 
 import httpx
 
+from literalai.cache.shared_cache import SharedCache
 from literalai.my_types import Environment, PaginatedResponse
 from literalai.observability.generation import (
     ChatGeneration,
@@ -141,51 +142,6 @@ def handle_bytes(item):
 
     return handle_bytes(variables)
 
-class SharedCache:
-    """
-    Singleton cache for storing data.
-    Only one instance will exist regardless of how many times it's instantiated.
-    """
-    _instance = None
-    _cache: dict[str, Any]
-
-    def __new__(cls):
-        if cls._instance is None:
-            cls._instance = super().__new__(cls)
-            cls._instance._cache = {}
-        return cls._instance
-
-    def get_cache(self) -> dict[str, Any]:
-        return self._cache
-
-    def get(self, key: str) -> Optional[Any]:
-        """
-        Retrieves a value from the cache using the provided key.
- """ - if not isinstance(key, str): - raise TypeError("Key must be a string") - return self._cache.get(key) - - def put(self, key: str, value: Any): - """ - Stores a value in the cache. - """ - if not isinstance(key, str): - raise TypeError("Key must be a string") - self._cache[key] = value - - def put_prompt(self, prompt: Prompt): - self.put(prompt.id, prompt) - self.put(prompt.name, prompt) - self.put(f"{prompt.name}-{prompt.version}", prompt) - - def clear(self) -> None: - """ - Clears all cached values. - """ - self._cache.clear() - - class BaseLiteralAPI: def __init__( self, diff --git a/literalai/api/prompt_helpers.py b/literalai/api/prompt_helpers.py index c35c691..15fabaa 100644 --- a/literalai/api/prompt_helpers.py +++ b/literalai/api/prompt_helpers.py @@ -3,6 +3,8 @@ from literalai.observability.generation import GenerationMessage from literalai.prompt_engineering.prompt import Prompt, ProviderSettings +from literalai.cache.prompt_helpers import put_prompt + if TYPE_CHECKING: from literalai.api import LiteralAPI from literalai.api import SharedCache @@ -90,7 +92,7 @@ def process_response(response): prompt_version = response["data"]["promptVersion"] prompt = Prompt.from_dict(api, prompt_version) if prompt_version else None if cache and prompt: - cache.put_prompt(prompt) + put_prompt(cache, prompt) return prompt description = "get prompt" diff --git a/literalai/cache/__init__.py b/literalai/cache/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/literalai/cache/prompt_helpers.py b/literalai/cache/prompt_helpers.py new file mode 100644 index 0000000..56646f9 --- /dev/null +++ b/literalai/cache/prompt_helpers.py @@ -0,0 +1,8 @@ +from literalai.prompt_engineering.prompt import Prompt +from literalai.cache.shared_cache import SharedCache + + +def put_prompt(cache: SharedCache, prompt: Prompt): + cache.put(prompt.id, prompt) + cache.put(prompt.name, prompt) + cache.put(f"{prompt.name}-{prompt.version}", prompt) diff --git a/literalai/cache/shared_cache.py b/literalai/cache/shared_cache.py new file mode 100644 index 0000000..b356f32 --- /dev/null +++ b/literalai/cache/shared_cache.py @@ -0,0 +1,42 @@ +from typing import Any, Optional + + +class SharedCache: + """ + Singleton cache for storing data. + Only one instance will exist regardless of how many times it's instantiated. + """ + _instance = None + _cache: dict[str, Any] + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._cache = {} + return cls._instance + + def get_cache(self) -> dict[str, Any]: + return self._cache + + def get(self, key: str) -> Optional[Any]: + """ + Retrieves a value from the cache using the provided key. + """ + if not isinstance(key, str): + raise TypeError("Key must be a string") + return self._cache.get(key) + + def put(self, key: str, value: Any): + """ + Stores a value in the cache. + """ + if not isinstance(key, str): + raise TypeError("Key must be a string") + self._cache[key] = value + + def clear(self) -> None: + """ + Clears all cached values. 
+ """ + self._cache.clear() + diff --git a/tests/unit/test_cache.py b/tests/unit/test_cache.py index c18ffb4..ccf751c 100644 --- a/tests/unit/test_cache.py +++ b/tests/unit/test_cache.py @@ -1,7 +1,9 @@ import pytest from literalai.prompt_engineering.prompt import Prompt -from literalai.api import SharedCache, LiteralAPI +from literalai.api import LiteralAPI +from literalai.cache.shared_cache import SharedCache +from literalai.cache.prompt_helpers import put_prompt def default_prompt(id: str = "1", name: str = "test", version: int = 1) -> Prompt: return Prompt( @@ -41,7 +43,7 @@ def test_put_and_get_prompt_by_id_by_name_version_by_name(): cache.clear() prompt = default_prompt() - cache.put_prompt(prompt) + put_prompt(cache, prompt) retrieved_by_id = cache.get(id="1") assert retrieved_by_id is prompt @@ -56,7 +58,7 @@ def test_clear_cache(): """Test clearing the cache""" cache = SharedCache() prompt = default_prompt() - cache.put_prompt(prompt) + put_prompt(cache, prompt) cache.clear() assert cache.get_cache() == {} From cf98d744c7aca046668d974b5a88786930fb6ff3 Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Thu, 28 Nov 2024 11:34:48 +0100 Subject: [PATCH 20/21] fix: ci --- literalai/api/asynchronous.py | 16 ++++++---------- literalai/api/helpers/prompt_helpers.py | 2 +- literalai/api/synchronous.py | 11 +++++------ 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/literalai/api/asynchronous.py b/literalai/api/asynchronous.py index 5fccc73..3ef6871 100644 --- a/literalai/api/asynchronous.py +++ b/literalai/api/asynchronous.py @@ -106,9 +106,6 @@ from literalai.observability.thread import Thread from literalai.prompt_engineering.prompt import Prompt, ProviderSettings -if TYPE_CHECKING: - from typing import Tuple # noqa: F401 - import httpx from literalai.my_types import PaginatedResponse, User @@ -145,20 +142,20 @@ class AsyncLiteralAPI(BaseLiteralAPI): R = TypeVar("R") async def make_gql_call( - self, description: str, query: str, variables: Dict[str, Any] + self, description: str, query: str, variables: Dict[str, Any], timeout: Optional[int] = 10 ) -> Dict: def raise_error(error): logger.error(f"Failed to {description}: {error}") raise Exception(error) - variables = prepare_variables(variables) + variables = _prepare_variables(variables) async with httpx.AsyncClient(follow_redirects=True) as client: response = await client.post( self.graphql_endpoint, json={"query": query, "variables": variables}, headers=self.headers, - timeout=10, + timeout=timeout, ) try: @@ -179,13 +176,12 @@ def raise_error(error): if json.get("data"): if isinstance(json["data"], dict): - for _, value in json["data"].items(): + for value in json["data"].values(): if value and value.get("ok") is False: raise_error( f"""Failed to {description}: { value.get('message')}""" ) - return json async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict: @@ -211,15 +207,15 @@ async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict: f"""Failed to parse JSON response: { e}, content: {response.content!r}""" ) - async def gql_helper( self, query: str, description: str, variables: Dict, process_response: Callable[..., R], + timeout: Optional[int] = 10, ) -> R: - response = await self.make_gql_call(description, query, variables) + response = await self.make_gql_call(description, query, variables, timeout) return process_response(response) ################################################################################## diff --git a/literalai/api/helpers/prompt_helpers.py 
b/literalai/api/helpers/prompt_helpers.py index a3a49d7..fe5cb23 100644 --- a/literalai/api/helpers/prompt_helpers.py +++ b/literalai/api/helpers/prompt_helpers.py @@ -7,7 +7,7 @@ if TYPE_CHECKING: from literalai.api import LiteralAPI - from literalai.api import SharedCache + from literalai.cache.shared_cache import SharedCache from literalai.api.helpers import gql diff --git a/literalai/api/synchronous.py b/literalai/api/synchronous.py index ba9cce7..7ec2156 100644 --- a/literalai/api/synchronous.py +++ b/literalai/api/synchronous.py @@ -144,19 +144,19 @@ class LiteralAPI(BaseLiteralAPI): R = TypeVar("R") def make_gql_call( - self, description: str, query: str, variables: Dict[str, Any] - ) -> Dict: + self, description: str, query: str, variables: dict[str, Any], timeout: Optional[int] = 10 + ) -> dict: def raise_error(error): logger.error(f"Failed to {description}: {error}") raise Exception(error) - variables = prepare_variables(variables) + variables = _prepare_variables(variables) with httpx.Client(follow_redirects=True) as client: response = client.post( self.graphql_endpoint, json={"query": query, "variables": variables}, headers=self.headers, - timeout=10, + timeout=timeout, ) try: @@ -177,7 +177,7 @@ def raise_error(error): if json.get("data"): if isinstance(json["data"], dict): - for _, value in json["data"].items(): + for value in json["data"].values(): if value and value.get("ok") is False: raise_error( f"""Failed to {description}: { @@ -186,7 +186,6 @@ def raise_error(error): return json - def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict: with httpx.Client(follow_redirects=True) as client: response = client.post( From c5faa02c5b9e2a5a99290702488593638134eb86 Mon Sep 17 00:00:00 2001 From: Matthieu Olenga Date: Thu, 28 Nov 2024 11:41:57 +0100 Subject: [PATCH 21/21] fix: more ci fixes --- literalai/api/asynchronous.py | 5 ++--- literalai/api/synchronous.py | 8 ++------ 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/literalai/api/asynchronous.py b/literalai/api/asynchronous.py index 3ef6871..ddb8538 100644 --- a/literalai/api/asynchronous.py +++ b/literalai/api/asynchronous.py @@ -3,7 +3,6 @@ from typing_extensions import deprecated from typing import ( - TYPE_CHECKING, Any, Callable, Dict, @@ -148,7 +147,7 @@ def raise_error(error): logger.error(f"Failed to {description}: {error}") raise Exception(error) - variables = _prepare_variables(variables) + variables = prepare_variables(variables) async with httpx.AsyncClient(follow_redirects=True) as client: response = await client.post( @@ -443,7 +442,7 @@ async def upload_file( # Prepare form data form_data = ( {} - ) # type: Dict[str, Union[Tuple[Union[str, None], Any], Tuple[Union[str, None], Any, Any]]] + ) # type: Dict[str, Union[tuple[Union[str, None], Any], tuple[Union[str, None], Any, Any]]] for field_name, field_value in fields.items(): form_data[field_name] = (None, field_value) diff --git a/literalai/api/synchronous.py b/literalai/api/synchronous.py index 7ec2156..43455ee 100644 --- a/literalai/api/synchronous.py +++ b/literalai/api/synchronous.py @@ -3,7 +3,6 @@ from typing_extensions import deprecated from typing import ( - TYPE_CHECKING, Any, Callable, Dict, @@ -105,9 +104,6 @@ from literalai.observability.thread import Thread from literalai.prompt_engineering.prompt import Prompt, ProviderSettings -if TYPE_CHECKING: - from typing import Tuple # noqa: F401 - import httpx from literalai.my_types import PaginatedResponse, User @@ -150,7 +146,7 @@ def raise_error(error): 
logger.error(f"Failed to {description}: {error}") raise Exception(error) - variables = _prepare_variables(variables) + variables = prepare_variables(variables) with httpx.Client(follow_redirects=True) as client: response = client.post( self.graphql_endpoint, @@ -441,7 +437,7 @@ def upload_file( # Prepare form data form_data = ( {} - ) # type: Dict[str, Union[Tuple[Union[str, None], Any], Tuple[Union[str, None], Any, Any]]] + ) # type: Dict[str, Union[tuple[Union[str, None], Any], tuple[Union[str, None], Any, Any]]] for field_name, field_value in fields.items(): form_data[field_name] = (None, field_value)