diff --git a/literalai/api/asynchronous.py b/literalai/api/asynchronous.py
index 8d832dd..ddb8538 100644
--- a/literalai/api/asynchronous.py
+++ b/literalai/api/asynchronous.py
@@ -3,7 +3,6 @@
 from typing_extensions import deprecated

 from typing import (
-    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -106,9 +105,6 @@
 from literalai.observability.thread import Thread
 from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

-if TYPE_CHECKING:
-    from typing import Tuple  # noqa: F401
-
 import httpx

 from literalai.my_types import PaginatedResponse, User
@@ -145,7 +141,7 @@ class AsyncLiteralAPI(BaseLiteralAPI):
     R = TypeVar("R")

     async def make_gql_call(
-        self, description: str, query: str, variables: Dict[str, Any]
+        self, description: str, query: str, variables: Dict[str, Any], timeout: Optional[int] = 10
     ) -> Dict:
         def raise_error(error):
             logger.error(f"Failed to {description}: {error}")
@@ -158,7 +154,7 @@ def raise_error(error):
                 self.graphql_endpoint,
                 json={"query": query, "variables": variables},
                 headers=self.headers,
-                timeout=10,
+                timeout=timeout,
             )

             try:
@@ -179,13 +175,12 @@ def raise_error(error):

             if json.get("data"):
                 if isinstance(json["data"], dict):
-                    for _, value in json["data"].items():
+                    for value in json["data"].values():
                         if value and value.get("ok") is False:
                             raise_error(
                                 f"""Failed to {description}: {
                                     value.get('message')}"""
                             )
-
         return json

     async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
@@ -211,15 +206,15 @@ async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
                 f"""Failed to parse JSON response: {
                     e}, content: {response.content!r}"""
             )
-
     async def gql_helper(
         self,
         query: str,
         description: str,
         variables: Dict,
         process_response: Callable[..., R],
+        timeout: Optional[int] = 10,
     ) -> R:
-        response = await self.make_gql_call(description, query, variables)
+        response = await self.make_gql_call(description, query, variables, timeout)
         return process_response(response)

     ##################################################################################
@@ -447,7 +442,7 @@ async def upload_file(
         # Prepare form data
         form_data = (
             {}
-        )  # type: Dict[str, Union[Tuple[Union[str, None], Any], Tuple[Union[str, None], Any, Any]]]
+        )  # type: Dict[str, Union[tuple[Union[str, None], Any], tuple[Union[str, None], Any, Any]]]
         for field_name, field_value in fields.items():
             form_data[field_name] = (None, field_value)

@@ -838,16 +833,32 @@ async def get_prompt(
         id: Optional[str] = None,
         name: Optional[str] = None,
         version: Optional[int] = None,
-    ) -> "Prompt":
+    ) -> Prompt:
+        if not (id or name):
+            raise ValueError("At least one of `id` or `name` must be provided.")
+
         sync_api = LiteralAPI(self.api_key, self.url)
-        if id:
-            return await self.gql_helper(*get_prompt_helper(sync_api, id=id))
-        elif name:
-            return await self.gql_helper(
-                *get_prompt_helper(sync_api, name=name, version=version)
-            )
-        else:
-            raise ValueError("Either the `id` or the `name` must be provided.")
+        get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper(
+            api=sync_api, id=id, name=name, version=version, cache=self.cache
+        )
+
+        try:
+            if id:
+                prompt = await self.gql_helper(
+                    get_prompt_query, description, variables, process_response, timeout
+                )
+            elif name:
+                prompt = await self.gql_helper(
+                    get_prompt_query, description, variables, process_response, timeout
+                )
+
+            return prompt
+
+        except Exception as e:
+            if cached_prompt:
+                logger.warning("Failed to get prompt from API, returning cached prompt")
+                return cached_prompt
+            raise e

     async def update_prompt_ab_testing(
         self, name: str, rollouts: List["PromptRollout"]
diff --git a/literalai/api/base.py b/literalai/api/base.py
index b011e05..2b5f887 100644
--- a/literalai/api/base.py
+++ b/literalai/api/base.py
@@ -13,6 +13,7 @@
 from literalai.my_types import Environment

+from literalai.cache.shared_cache import SharedCache
 from literalai.evaluation.dataset import DatasetType
 from literalai.evaluation.dataset_experiment import (
     DatasetExperimentItem,
@@ -95,6 +96,8 @@ def __init__(
         self.graphql_endpoint = self.url + "/api/graphql"
         self.rest_endpoint = self.url + "/api"

+        self.cache = SharedCache()
+
     @property
     def headers(self):
         from literalai.version import __version__
@@ -1011,9 +1014,9 @@ def get_prompt(
         """
         Gets a prompt either by:
         - `id`
-        - or `name` and (optional) `version`
+        - `name` and (optional) `version`

-        Either the `id` or the `name` must be provided.
+        At least one of `id` or `name` must be provided.
         If both are provided, the `id` is used.

         Args:
diff --git a/literalai/api/helpers/prompt_helpers.py b/literalai/api/helpers/prompt_helpers.py
index 14f2613..fe5cb23 100644
--- a/literalai/api/helpers/prompt_helpers.py
+++ b/literalai/api/helpers/prompt_helpers.py
@@ -1,10 +1,13 @@
-from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict
+from typing import TYPE_CHECKING, Callable, Optional, TypedDict

 from literalai.observability.generation import GenerationMessage
 from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

+from literalai.cache.prompt_helpers import put_prompt
+
 if TYPE_CHECKING:
     from literalai.api import LiteralAPI
+    from literalai.cache.shared_cache import SharedCache

 from literalai.api.helpers import gql

@@ -36,9 +39,9 @@ def process_response(response):
 def create_prompt_helper(
     api: "LiteralAPI",
     lineage_id: str,
-    template_messages: List[GenerationMessage],
+    template_messages: list[GenerationMessage],
     settings: Optional[ProviderSettings] = None,
-    tools: Optional[List[Dict]] = None,
+    tools: Optional[list[dict]] = None,
 ):
     variables = {
         "lineageId": lineage_id,
@@ -56,28 +59,52 @@ def process_response(response):
     return gql.CREATE_PROMPT_VERSION, description, variables, process_response


+def get_prompt_cache_key(id: Optional[str], name: Optional[str], version: Optional[int]) -> str:
+    if id:
+        return id
+    elif name and version:
+        return f"{name}-{version}"
+    elif name:
+        return name
+    else:
+        raise ValueError("Either the `id` or the `name` must be provided.")
+
+
 def get_prompt_helper(
     api: "LiteralAPI",
     id: Optional[str] = None,
     name: Optional[str] = None,
     version: Optional[int] = 0,
-):
+    cache: Optional["SharedCache"] = None,
+) -> tuple[str, str, dict, Callable, int, Optional[Prompt]]:
+    """Helper function for getting prompts with caching logic"""
+
+    cached_prompt = None
+    timeout = 10
+
+    if cache:
+        cached_prompt = cache.get(get_prompt_cache_key(id, name, version))
+        timeout = 1 if cached_prompt else timeout
+
     variables = {"id": id, "name": name, "version": version}

     def process_response(response):
-        prompt = response["data"]["promptVersion"]
-        return Prompt.from_dict(api, prompt) if prompt else None
+        prompt_version = response["data"]["promptVersion"]
+        prompt = Prompt.from_dict(api, prompt_version) if prompt_version else None
+        if cache and prompt:
+            put_prompt(cache, prompt)
+        return prompt

     description = "get prompt"

-    return gql.GET_PROMPT_VERSION, description, variables, process_response
+    return gql.GET_PROMPT_VERSION, description, variables, process_response, timeout, cached_prompt


 def create_prompt_variant_helper(
     from_lineage_id: Optional[str] = None,
-    template_messages: List[GenerationMessage] = [],
+    template_messages: list[GenerationMessage] = [],
     settings: Optional[ProviderSettings] = None,
-    tools: Optional[List[Dict]] = None,
+    tools: Optional[list[dict]] = None,
 ):
     variables = {
         "fromLineageId": from_lineage_id,
@@ -105,7 +132,7 @@ def get_prompt_ab_testing_helper(
 ):
     variables = {"lineageName": name}

-    def process_response(response) -> List[PromptRollout]:
+    def process_response(response) -> list[PromptRollout]:
         response_data = response["data"]["promptLineageRollout"]
         return list(map(lambda x: x["node"], response_data["edges"]))

@@ -114,10 +141,10 @@ def process_response(response) -> list[PromptRollout]:
     return gql.GET_PROMPT_AB_TESTING, description, variables, process_response


-def update_prompt_ab_testing_helper(name: str, rollouts: List[PromptRollout]):
+def update_prompt_ab_testing_helper(name: str, rollouts: list[PromptRollout]):
     variables = {"name": name, "rollouts": rollouts}

-    def process_response(response) -> Dict:
+    def process_response(response) -> dict:
         return response["data"]["updatePromptLineageRollout"]

     description = "update prompt A/B testing"
diff --git a/literalai/api/synchronous.py b/literalai/api/synchronous.py
index 3266624..43455ee 100644
--- a/literalai/api/synchronous.py
+++ b/literalai/api/synchronous.py
@@ -3,7 +3,6 @@
 from typing_extensions import deprecated

 from typing import (
-    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -105,9 +104,6 @@
 from literalai.observability.thread import Thread
 from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

-if TYPE_CHECKING:
-    from typing import Tuple  # noqa: F401
-
 import httpx

 from literalai.my_types import PaginatedResponse, User
@@ -144,8 +140,8 @@ class LiteralAPI(BaseLiteralAPI):
     R = TypeVar("R")

     def make_gql_call(
-        self, description: str, query: str, variables: Dict[str, Any]
-    ) -> Dict:
+        self, description: str, query: str, variables: dict[str, Any], timeout: Optional[int] = 10
+    ) -> dict:
         def raise_error(error):
             logger.error(f"Failed to {description}: {error}")
             raise Exception(error)
@@ -156,7 +152,7 @@ def raise_error(error):
                 self.graphql_endpoint,
                 json={"query": query, "variables": variables},
                 headers=self.headers,
-                timeout=10,
+                timeout=timeout,
             )

             try:
@@ -177,7 +173,7 @@ def raise_error(error):

             if json.get("data"):
                 if isinstance(json["data"], dict):
-                    for _, value in json["data"].items():
+                    for value in json["data"].values():
                         if value and value.get("ok") is False:
                             raise_error(
                                 f"""Failed to {description}: {
@@ -186,7 +182,6 @@ def raise_error(error):

         return json

-
     def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
         with httpx.Client(follow_redirects=True) as client:
             response = client.post(
@@ -217,8 +212,9 @@ def gql_helper(
         description: str,
         variables: Dict,
         process_response: Callable[..., R],
+        timeout: Optional[int] = 10,
     ) -> R:
-        response = self.make_gql_call(description, query, variables)
+        response = self.make_gql_call(description, query, variables, timeout)
         return process_response(response)

     ##################################################################################
@@ -441,7 +437,7 @@ def upload_file(
         # Prepare form data
         form_data = (
             {}
-        )  # type: Dict[str, Union[Tuple[Union[str, None], Any], Tuple[Union[str, None], Any, Any]]]
+        )  # type: Dict[str, Union[tuple[Union[str, None], Any], tuple[Union[str, None], Any, Any]]]
         for field_name, field_value in fields.items():
             form_data[field_name] = (None, field_value)

@@ -805,12 +801,27 @@ def get_prompt(
         name: Optional[str] = None,
         version: Optional[int] = None,
     ) -> "Prompt":
-        if id:
-            return self.gql_helper(*get_prompt_helper(self, id=id))
-        elif name:
-            return self.gql_helper(*get_prompt_helper(self, name=name, version=version))
-        else:
-            raise ValueError("Either the `id` or the `name` must be provided.")
+        if not (id or name):
+            raise ValueError("At least one of `id` or `name` must be provided.")
+
+        get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper(
+            api=self, id=id, name=name, version=version, cache=self.cache
+        )
+
+        try:
+            if id:
+                prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout)
+            elif name:
+                prompt = self.gql_helper(get_prompt_query, description, variables, process_response, timeout)
+
+            return prompt
+
+        except Exception as e:
+            if cached_prompt:
+                logger.warning("Failed to get prompt from API, returning cached prompt")
+                return cached_prompt
+
+            raise e

     def create_prompt_variant(
         self,
diff --git a/literalai/cache/__init__.py b/literalai/cache/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/literalai/cache/prompt_helpers.py b/literalai/cache/prompt_helpers.py
new file mode 100644
index 0000000..56646f9
--- /dev/null
+++ b/literalai/cache/prompt_helpers.py
@@ -0,0 +1,8 @@
+from literalai.prompt_engineering.prompt import Prompt
+from literalai.cache.shared_cache import SharedCache
+
+
+def put_prompt(cache: SharedCache, prompt: Prompt):
+    cache.put(prompt.id, prompt)
+    cache.put(prompt.name, prompt)
+    cache.put(f"{prompt.name}-{prompt.version}", prompt)
diff --git a/literalai/cache/shared_cache.py b/literalai/cache/shared_cache.py
new file mode 100644
index 0000000..b356f32
--- /dev/null
+++ b/literalai/cache/shared_cache.py
@@ -0,0 +1,42 @@
+from typing import Any, Optional
+
+
+class SharedCache:
+    """
+    Singleton cache for storing data.
+    Only one instance will exist regardless of how many times it's instantiated.
+    """
+    _instance = None
+    _cache: dict[str, Any]
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+            cls._instance._cache = {}
+        return cls._instance
+
+    def get_cache(self) -> dict[str, Any]:
+        return self._cache
+
+    def get(self, key: str) -> Optional[Any]:
+        """
+        Retrieves a value from the cache using the provided key.
+        """
+        if not isinstance(key, str):
+            raise TypeError("Key must be a string")
+        return self._cache.get(key)
+
+    def put(self, key: str, value: Any):
+        """
+        Stores a value in the cache.
+        """
+        if not isinstance(key, str):
+            raise TypeError("Key must be a string")
+        self._cache[key] = value
+
+    def clear(self) -> None:
+        """
+        Clears all cached values.
+ """ + self._cache.clear() + diff --git a/run-test.sh b/run-test.sh index 2a51465..b256ce3 100755 --- a/run-test.sh +++ b/run-test.sh @@ -1 +1 @@ -LITERAL_API_URL=http://localhost:3000 LITERAL_API_KEY=my-initial-api-key pytest -m e2e -s -v +LITERAL_API_URL=http://localhost:3000 LITERAL_API_KEY=my-initial-api-key pytest -m e2e -s -v tests/e2e/ tests/unit/ diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 3b58199..426670b 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -662,6 +662,20 @@ async def test_prompt(self, async_client: AsyncLiteralClient): assert messages[0]["content"] == expected + @pytest.mark.timeout(5) + async def test_prompt_cache(self, async_client: AsyncLiteralClient): + prompt = await async_client.api.get_prompt(name="Default", version=0) + assert prompt is not None + + original_key = async_client.api.api_key + async_client.api.api_key = "invalid-api-key" + + cached_prompt = await async_client.api.get_prompt(name="Default", version=0) + assert cached_prompt is not None + assert cached_prompt.id == prompt.id + + async_client.api.api_key = original_key + @pytest.mark.timeout(5) async def test_prompt_ab_testing(self, client: LiteralClient): prompt_name = "Python SDK E2E Tests" diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/test_cache.py b/tests/unit/test_cache.py new file mode 100644 index 0000000..ccf751c --- /dev/null +++ b/tests/unit/test_cache.py @@ -0,0 +1,93 @@ +import pytest + +from literalai.prompt_engineering.prompt import Prompt +from literalai.api import LiteralAPI +from literalai.cache.shared_cache import SharedCache +from literalai.cache.prompt_helpers import put_prompt + +def default_prompt(id: str = "1", name: str = "test", version: int = 1) -> Prompt: + return Prompt( + api=LiteralAPI(), + id=id, + name=name, + version=version, + created_at="", + updated_at="", + type="chat", # type: ignore + url="", + version_desc=None, + template_messages=[], + tools=None, + provider="", + settings={}, + variables=[], + variables_default_values=None + ) + +def test_singleton_instance(): + """Test that SharedCache maintains singleton pattern""" + cache1 = SharedCache() + cache2 = SharedCache() + assert cache1 is cache2 + +def test_get_empty_cache(): + """Test getting from empty cache returns None""" + cache = SharedCache() + cache.clear() + + assert cache.get_cache() == {} + +def test_put_and_get_prompt_by_id_by_name_version_by_name(): + """Test storing and retrieving prompt by ID by name-version by name""" + cache = SharedCache() + cache.clear() + + prompt = default_prompt() + put_prompt(cache, prompt) + + retrieved_by_id = cache.get(id="1") + assert retrieved_by_id is prompt + + retrieved_by_name_version = cache.get(name="test", version=1) + assert retrieved_by_name_version is prompt + + retrieved_by_name = cache.get(name="test") + assert retrieved_by_name is prompt + +def test_clear_cache(): + """Test clearing the cache""" + cache = SharedCache() + prompt = default_prompt() + put_prompt(cache, prompt) + + cache.clear() + assert cache.get_cache() == {} + +def test_update_existing_prompt(): + """Test updating an existing prompt""" + cache = SharedCache() + cache.clear() + + prompt1 = default_prompt() + prompt2 = default_prompt(id="1", version=2) + + cache.put_prompt(prompt1) + cache.put_prompt(prompt2) + + retrieved = cache.get(id="1") + assert retrieved is prompt2 + assert retrieved.version == 2 + +def test_error_handling(): + """Test error handling for invalid 
inputs""" + cache = SharedCache() + cache.clear() + + assert cache.get_cache() == {} + assert cache.get(key="") is None + + with pytest.raises(TypeError): + cache.get(5) # type: ignore + + with pytest.raises(TypeError): + cache.put(5, "test") # type: ignore \ No newline at end of file