feat: add caching if prompt request fails #148

Merged: 22 commits, Nov 29, 2024
Changes from 11 commits

Commits (22)
4ddda3e
feat: create the dict cache and the method to go with it
Matthieu-OD Nov 14, 2024
fd9f462
feat: get_prompt add caching
Matthieu-OD Nov 14, 2024
8774a00
feat: implement caching on get_prompt
Matthieu-OD Nov 14, 2024
5d8b5f7
fix: ci
Matthieu-OD Nov 14, 2024
e7589c6
feat: add timeout if prompt cached
Matthieu-OD Nov 14, 2024
723f7fd
feat: improve caching
Matthieu-OD Nov 14, 2024
32b971f
feat: improve logging
Matthieu-OD Nov 14, 2024
5bcdce4
fix: ci errors
Matthieu-OD Nov 14, 2024
2476ebc
feat: improve the prompt cache class
Matthieu-OD Nov 15, 2024
0aec701
refactor: remove useless code
Matthieu-OD Nov 15, 2024
32b4e48
feat: implement the new SharedCachePrompt class
Matthieu-OD Nov 15, 2024
f5d460b
refactor: improve typing and move some logic
Matthieu-OD Nov 15, 2024
49fd140
feat: adds memory management to the SharedCachePrompt class
Matthieu-OD Nov 18, 2024
3e139f2
feat: add unit tests for SharedCachePrompt
Matthieu-OD Nov 18, 2024
3730581
feat: adds tests and updates run-test.sh
Matthieu-OD Nov 18, 2024
5318751
refactor: finishes the simplification
Matthieu-OD Nov 19, 2024
85c72d1
fix: test and implementation
Matthieu-OD Nov 20, 2024
6dfce9c
fix: add typing for sharedcache typing
Matthieu-OD Nov 20, 2024
06c5047
feat: align with literalai-typescript changes
Matthieu-OD Nov 21, 2024
e3c7ea0
Merge branch 'main' into matt/eng-2115-add-client-caching-for-prompts
Matthieu-OD Nov 28, 2024
cf98d74
fix: ci
Matthieu-OD Nov 28, 2024
c5faa02
fix: more ci fixes
Matthieu-OD Nov 28, 2024
152 changes: 116 additions & 36 deletions literalai/api/__init__.py
@@ -1,8 +1,8 @@
from threading import Lock
import logging
import os
import uuid
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -102,9 +102,6 @@
)
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

if TYPE_CHECKING:
from typing import Tuple # noqa: F401

import httpx

from literalai.my_types import Environment, PaginatedResponse
@@ -145,6 +142,59 @@ def handle_bytes(item):
return handle_bytes(variables)


class SharedPromptCache:
"""
Thread-safe singleton cache for storing prompts.
Only one instance will exist regardless of how many times it's instantiated.
"""
_instance = None
_lock = Lock()

def __new__(cls):
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)

cls._instance._prompts: dict[str, Prompt] = {}
cls._instance._name_index: dict[str, str] = {}
cls._instance._name_version_index: dict[tuple[str, int], str] = {}
return cls._instance

def get(
self,
id: Optional[str] = None,
name: Optional[str] = None,
version: Optional[int] = None
) -> Optional[Prompt]:
"""
Retrieves a prompt using the most specific criteria provided.
Lookup priority: id, name-version, name
"""
prompt_id = None
if id:
prompt_id = id
elif name and version:
prompt_id = self._name_version_index.get((name, version))
elif name:
prompt_id = self._name_index.get(name)

return self._prompts.get(prompt_id) if prompt_id else None

def put(self, prompt: Prompt):
with self._lock:
self._prompts[prompt.id] = prompt
self._name_index[prompt.name] = prompt.id
self._name_version_index[(prompt.name, prompt.version)] = prompt.id

def clear(self) -> None:
"""
Clears all cached prompts and indices.
"""
with self._lock:
self._prompts.clear()
self._name_index.clear()
self._name_version_index.clear()
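
A minimal usage sketch of the singleton semantics above (illustrative only, not part of this diff; `prompt` stands in for any fetched Prompt object):

    cache_a = SharedPromptCache()
    cache_b = SharedPromptCache()
    assert cache_a is cache_b  # __new__ always hands back the one shared instance

    cache_a.put(prompt)  # indexed by id, by name, and by (name, version)
    assert cache_b.get(name=prompt.name) is prompt  # visible through every handle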


class BaseLiteralAPI:
def __init__(
self,
@@ -169,6 +219,8 @@ def __init__(
self.graphql_endpoint = self.url + "/api/graphql"
self.rest_endpoint = self.url + "/api"

self.prompt_cache = SharedPromptCache()

@property
def headers(self):
from literalai.version import __version__
@@ -186,6 +238,7 @@ def headers(self):
return h



class LiteralAPI(BaseLiteralAPI):
"""
```python
@@ -200,8 +253,8 @@
R = TypeVar("R")

def make_gql_call(
self, description: str, query: str, variables: Dict[str, Any]
) -> Dict:
self, description: str, query: str, variables: dict[str, Any], timeout: Optional[int] = 10
) -> dict:
def raise_error(error):
logger.error(f"Failed to {description}: {error}")
raise Exception(error)
@@ -212,7 +265,7 @@ def raise_error(error):
self.graphql_endpoint,
json={"query": query, "variables": variables},
headers=self.headers,
timeout=10,
timeout=timeout,
[Review comment]
Contributor: I don't see the modification for the async version of make_gql_call.

@Matthieu-OD (Contributor, author), Nov 15, 2024: It is lines 1519 and 1532 in the same file.

)

try:
@@ -233,7 +286,7 @@ def raise_error(error):

if json.get("data"):
if isinstance(json["data"], dict):
for key, value in json["data"].items():
for value in json["data"].values():
if value and value.get("ok") is False:
raise_error(
f"""Failed to {description}: {
@@ -242,10 +295,6 @@

return json

# This should not be reached, exceptions should be thrown beforehands
# Added because of mypy
raise Exception("Unknown error")

def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
with httpx.Client(follow_redirects=True) as client:
response = client.post(
@@ -276,8 +325,9 @@ def gql_helper(
description: str,
variables: Dict,
process_response: Callable[..., R],
timeout: Optional[int] = None,
) -> R:
response = self.make_gql_call(description, query, variables)
response = self.make_gql_call(description, query, variables, timeout)
return process_response(response)

# User API
@@ -684,7 +734,7 @@ def upload_file(
# Prepare form data
form_data = (
{}
) # type: Dict[str, Union[Tuple[Union[str, None], Any], Tuple[Union[str, None], Any, Any]]]
) # type: Dict[str, Union[tuple[Union[str, None], Any], tuple[Union[str, None], Any, Any]]]
for field_name, field_value in fields.items():
form_data[field_name] = (None, field_value)

@@ -1362,13 +1412,28 @@ def get_prompt(
Returns:
Prompt: The prompt with the given identifier or name.
"""
if id:
return self.gql_helper(*get_prompt_helper(self, id=id))
elif name:
return self.gql_helper(*get_prompt_helper(self, name=name, version=version))
else:
if not (id or name):
raise ValueError("Either the `id` or the `name` must be provided.")

cached_prompt = self.prompt_cache.get(id, name, version)
timeout = 1 if cached_prompt else None

try:
if id:
prompt = self.gql_helper(*get_prompt_helper(self, id=id, timeout=timeout))
elif name:
prompt = self.gql_helper(*get_prompt_helper(self, name=name, version=version, timeout=timeout))

self.prompt_cache.put(prompt)
return prompt

except Exception as e:
if cached_prompt:
logger.warning("Failed to get prompt from API, returning cached prompt")
return cached_prompt

raise e
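
To illustrate the fallback path this introduces (a hypothetical usage sketch; the key, URL, and prompt name are placeholders):

    api = LiteralAPI(api_key="...", url="...")

    # First call uses the default timeout and populates the shared cache.
    prompt = api.get_prompt(name="my-prompt")

    # Subsequent calls retry the API with a 1-second timeout; if the request
    # fails or times out, the cached copy is returned instead of raising.
    prompt = api.get_prompt(name="my-prompt")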

def create_prompt_variant(
self,
name: str,
@@ -1451,7 +1516,7 @@ class AsyncLiteralAPI(BaseLiteralAPI):
R = TypeVar("R")

async def make_gql_call(
self, description: str, query: str, variables: Dict[str, Any]
self, description: str, query: str, variables: Dict[str, Any], timeout: Optional[int] = 10
) -> Dict:
def raise_error(error):
logger.error(f"Failed to {description}: {error}")
@@ -1464,7 +1529,7 @@ def raise_error(error):
self.graphql_endpoint,
json={"query": query, "variables": variables},
headers=self.headers,
timeout=10,
timeout=timeout,
)

try:
@@ -1485,7 +1550,7 @@ def raise_error(error):

if json.get("data"):
if isinstance(json["data"], dict):
for key, value in json["data"].items():
for value in json["data"].values():
if value and value.get("ok") is False:
raise_error(
f"""Failed to {description}: {
@@ -1494,10 +1559,6 @@

return json

# This should not be reached, exceptions should be thrown beforehands
# Added because of mypy
raise Exception("Unkown error")

async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
async with httpx.AsyncClient(follow_redirects=True) as client:
response = await client.post(
@@ -1528,8 +1589,9 @@ async def gql_helper(
description: str,
variables: Dict,
process_response: Callable[..., R],
timeout: Optional[int] = 10,
) -> R:
response = await self.make_gql_call(description, query, variables)
response = await self.make_gql_call(description, query, variables, timeout)
return process_response(response)

async def get_users(
@@ -1963,7 +2025,7 @@ async def upload_file(
# Prepare form data
form_data = (
{}
) # type: Dict[str, Union[Tuple[Union[str, None], Any], Tuple[Union[str, None], Any, Any]]]
) # type: dict[str, Union[tuple[Union[str, None], Any], tuple[Union[str, None], Any, Any]]]
for field_name, field_value in fields.items():
form_data[field_name] = (None, field_value)

@@ -2593,16 +2655,34 @@ async def get_prompt(
name: Optional[str] = None,
version: Optional[int] = None,
) -> Prompt:
sync_api = LiteralAPI(self.api_key, self.url)
if id:
return await self.gql_helper(*get_prompt_helper(sync_api, id=id))
elif name:
return await self.gql_helper(
*get_prompt_helper(sync_api, name=name, version=version)
)
else:
if not (id or name):
raise ValueError("Either the `id` or the `name` must be provided.")

sync_api = LiteralAPI(self.api_key, self.url)
cached_prompt = self.prompt_cache.get(id, name, version)
timeout = 1 if cached_prompt else None
[Review comment]
Contributor: you could move the cache logic into get_prompt_helper to avoid duplicating it for the sync/async versions.
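
One way that suggestion could look (a hypothetical sketch, not code from this PR; the `cache` parameter and the extra `cached_prompt` return value are assumptions):

    def get_prompt_helper(
        api, id=None, name=None, version=0, cache=None
    ):
        # Resolve the cached copy and timeout once, so the sync and async
        # get_prompt implementations don't each duplicate this logic.
        cached_prompt = cache.get(id, name, version) if cache else None
        timeout = 1 if cached_prompt else None
        variables = {"id": id, "name": name, "version": version}

        def process_response(response):
            prompt = response["data"]["promptVersion"]  # placeholder parsing, elided here
            if cache:
                cache.put(prompt)
            return prompt

        return gql.GET_PROMPT_VERSION, "get prompt", variables, process_response, timeout, cached_prompt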


try:
if id:
prompt = await self.gql_helper(
*get_prompt_helper(sync_api, id=id, timeout=timeout)
)
elif name:
prompt = await self.gql_helper(
*get_prompt_helper(
sync_api, name=name, version=version, timeout=timeout
)
)

self.prompt_cache.put(prompt)
return prompt

except Exception as e:
if cached_prompt:
logger.warning("Failed to get prompt from API, returning cached prompt")
return cached_prompt
raise e

get_prompt.__doc__ = LiteralAPI.get_prompt.__doc__

async def update_prompt_ab_testing(
3 changes: 2 additions & 1 deletion literalai/api/prompt_helpers.py
@@ -61,6 +61,7 @@ def get_prompt_helper(
id: Optional[str] = None,
name: Optional[str] = None,
version: Optional[int] = 0,
timeout: Optional[int] = None,
):
variables = {"id": id, "name": name, "version": version}

@@ -70,7 +71,7 @@ def process_response(response):

description = "get prompt"

return gql.GET_PROMPT_VERSION, description, variables, process_response
return gql.GET_PROMPT_VERSION, description, variables, process_response, timeout
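
For context on the call sites above, the helper's tuple is star-unpacked straight into gql_helper's positional parameters, which is why appending timeout to the tuple threads it through without touching the GraphQL query itself. A sketch, assuming the names used in this diff:

    # (query, description, variables, process_response, timeout)
    args = get_prompt_helper(api, name="my-prompt", timeout=1)
    prompt = api.gql_helper(*args)

    # equivalent to the call sites in literalai/api/__init__.py:
    prompt = api.gql_helper(*get_prompt_helper(api, name="my-prompt", timeout=1))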


def create_prompt_variant_helper(