Skip to content

fix: thread/step concurrency #165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions literalai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from literalai.client import AsyncLiteralClient, LiteralClient
from literalai.evaluation.dataset import Dataset
from literalai.evaluation.dataset_experiment import (
DatasetExperiment,
DatasetExperimentItem,
)
from literalai.evaluation.dataset_item import DatasetItem
from literalai.evaluation.dataset_experiment import DatasetExperiment, DatasetExperimentItem
from literalai.prompt_engineering.prompt import Prompt
from literalai.my_types import * # noqa
from literalai.observability.generation import (
BaseGeneration,
Expand All @@ -13,6 +15,7 @@
from literalai.observability.message import Message
from literalai.observability.step import Attachment, Score, Step
from literalai.observability.thread import Thread
from literalai.prompt_engineering.prompt import Prompt
from literalai.version import __version__

__all__ = [
Expand Down
79 changes: 41 additions & 38 deletions literalai/api/asynchronous.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
import logging
import uuid
from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar, Union, cast

import httpx
from typing_extensions import deprecated
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
TypeVar,
Union,
cast,
)

from literalai.api.base import BaseLiteralAPI, prepare_variables

from literalai.api.helpers.attachment_helpers import (
AttachmentUpload,
create_attachment_helper,
Expand Down Expand Up @@ -91,6 +81,7 @@
DatasetExperimentItem,
)
from literalai.evaluation.dataset_item import DatasetItem
from literalai.my_types import PaginatedResponse, User
from literalai.observability.filter import (
generations_filters,
generations_order_by,
Expand All @@ -102,12 +93,6 @@
threads_order_by,
users_filters,
)
from literalai.observability.thread import Thread
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

import httpx

from literalai.my_types import PaginatedResponse, User
from literalai.observability.generation import (
BaseGeneration,
ChatGeneration,
Expand All @@ -123,6 +108,8 @@
StepDict,
StepType,
)
from literalai.observability.thread import Thread
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

logger = logging.getLogger(__name__)

Expand All @@ -141,7 +128,11 @@ class AsyncLiteralAPI(BaseLiteralAPI):
R = TypeVar("R")

async def make_gql_call(
self, description: str, query: str, variables: Dict[str, Any], timeout: Optional[int] = 10
self,
description: str,
query: str,
variables: Dict[str, Any],
timeout: Optional[int] = 10,
) -> Dict:
def raise_error(error):
logger.error(f"Failed to {description}: {error}")
Expand All @@ -166,8 +157,7 @@ def raise_error(error):
json = response.json()
except ValueError as e:
raise_error(
f"""Failed to parse JSON response: {
e}, content: {response.content!r}"""
f"Failed to parse JSON response: {e}, content: {response.content!r}"
)

if json.get("errors"):
Expand All @@ -178,8 +168,7 @@ def raise_error(error):
for value in json["data"].values():
if value and value.get("ok") is False:
raise_error(
f"""Failed to {description}: {
value.get('message')}"""
f"""Failed to {description}: {value.get("message")}"""
)
return json

Expand All @@ -203,9 +192,9 @@ async def make_rest_call(self, subpath: str, body: Dict[str, Any]) -> Dict:
return response.json()
except ValueError as e:
raise ValueError(
f"""Failed to parse JSON response: {
e}, content: {response.content!r}"""
f"Failed to parse JSON response: {e}, content: {response.content!r}"
)

async def gql_helper(
self,
query: str,
Expand Down Expand Up @@ -235,7 +224,9 @@ async def get_user(
) -> "User":
return await self.gql_helper(*get_user_helper(id, identifier))

async def create_user(self, identifier: str, metadata: Optional[Dict] = None) -> "User":
async def create_user(
self, identifier: str, metadata: Optional[Dict] = None
) -> "User":
return await self.gql_helper(*create_user_helper(identifier, metadata))

async def update_user(
Expand All @@ -245,7 +236,7 @@ async def update_user(

async def delete_user(self, id: str) -> Dict:
return await self.gql_helper(*delete_user_helper(id))

async def get_or_create_user(
self, identifier: str, metadata: Optional[Dict] = None
) -> "User":
Expand Down Expand Up @@ -273,7 +264,7 @@ async def get_threads(
first, after, before, filters, order_by, step_types_to_keep
)
)

async def list_threads(
self,
first: Optional[int] = None,
Expand Down Expand Up @@ -491,7 +482,7 @@ async def create_attachment(
thread_id = active_thread.id

if not step_id:
if active_steps := active_steps_var.get([]):
if active_steps := active_steps_var.get():
step_id = active_steps[-1].id
else:
raise Exception("No step_id provided and no active step found.")
Expand Down Expand Up @@ -532,7 +523,9 @@ async def create_attachment(
response = await self.make_gql_call(description, query, variables)
return process_response(response)

async def update_attachment(self, id: str, update_params: AttachmentUpload) -> "Attachment":
async def update_attachment(
self, id: str, update_params: AttachmentUpload
) -> "Attachment":
return await self.gql_helper(*update_attachment_helper(id, update_params))

async def get_attachment(self, id: str) -> Optional["Attachment"]:
Expand All @@ -545,7 +538,6 @@ async def delete_attachment(self, id: str) -> Dict:
# Step APIs #
##################################################################################


async def create_step(
self,
thread_id: Optional[str] = None,
Expand Down Expand Up @@ -646,7 +638,7 @@ async def get_generations(
return await self.gql_helper(
*get_generations_helper(first, after, before, filters, order_by)
)

async def create_generation(
self, generation: Union["ChatGeneration", "CompletionGeneration"]
) -> Union["ChatGeneration", "CompletionGeneration"]:
Expand All @@ -667,8 +659,10 @@ async def create_dataset(
return await self.gql_helper(
*create_dataset_helper(sync_api, name, description, metadata, type)
)

async def get_dataset(self, id: Optional[str] = None, name: Optional[str] = None) -> "Dataset":

async def get_dataset(
self, id: Optional[str] = None, name: Optional[str] = None
) -> "Dataset":
sync_api = LiteralAPI(self.api_key, self.url)
subpath, _, variables, process_response = get_dataset_helper(
sync_api, id=id, name=name
Expand Down Expand Up @@ -738,7 +732,7 @@ async def create_experiment_item(
result.scores = await self.create_scores(experiment_item.scores)

return result

##################################################################################
# DatasetItem APIs #
##################################################################################
Expand All @@ -753,7 +747,7 @@ async def create_dataset_item(
return await self.gql_helper(
*create_dataset_item_helper(dataset_id, input, expected_output, metadata)
)

async def get_dataset_item(self, id: str) -> "DatasetItem":
return await self.gql_helper(*get_dataset_item_helper(id))

Expand Down Expand Up @@ -784,7 +778,9 @@ async def get_or_create_prompt_lineage(
return await self.gql_helper(*create_prompt_lineage_helper(name, description))

@deprecated('Please use "get_or_create_prompt_lineage" instead.')
async def create_prompt_lineage(self, name: str, description: Optional[str] = None) -> Dict:
async def create_prompt_lineage(
self, name: str, description: Optional[str] = None
) -> Dict:
return await self.get_or_create_prompt_lineage(name, description)

async def get_or_create_prompt(
Expand Down Expand Up @@ -838,7 +834,14 @@ async def get_prompt(
raise ValueError("At least the `id` or the `name` must be provided.")

sync_api = LiteralAPI(self.api_key, self.url)
get_prompt_query, description, variables, process_response, timeout, cached_prompt = get_prompt_helper(
(
get_prompt_query,
description,
variables,
process_response,
timeout,
cached_prompt,
) = get_prompt_helper(
api=sync_api, id=id, name=name, version=version, cache=self.cache
)

Expand Down
56 changes: 14 additions & 42 deletions literalai/api/base.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,16 @@
import os

from abc import ABC, abstractmethod
from typing import (
Any,
Dict,
List,
Optional,
Union,
)
from typing import Any, Dict, List, Optional, Union

from typing_extensions import deprecated

from literalai.my_types import Environment

from literalai.api.helpers.attachment_helpers import AttachmentUpload
from literalai.api.helpers.prompt_helpers import PromptRollout
from literalai.api.helpers.score_helpers import ScoreUpdate
from literalai.cache.shared_cache import SharedCache
from literalai.evaluation.dataset import DatasetType
from literalai.evaluation.dataset_experiment import (
DatasetExperimentItem,
)
from literalai.api.helpers.attachment_helpers import (
AttachmentUpload)
from literalai.api.helpers.score_helpers import (
ScoreUpdate,
)

from literalai.evaluation.dataset_experiment import DatasetExperimentItem
from literalai.my_types import Environment
from literalai.observability.filter import (
generations_filters,
generations_order_by,
Expand All @@ -35,24 +22,14 @@
threads_order_by,
users_filters,
)
from literalai.prompt_engineering.prompt import ProviderSettings


from literalai.api.helpers.prompt_helpers import (
PromptRollout)

from literalai.observability.generation import (
ChatGeneration,
CompletionGeneration,
GenerationMessage,
)
from literalai.observability.step import (
ScoreDict,
ScoreType,
Step,
StepDict,
StepType,
)
from literalai.observability.step import ScoreDict, ScoreType, Step, StepDict, StepType
from literalai.prompt_engineering.prompt import ProviderSettings


def prepare_variables(variables: Dict[str, Any]) -> Dict[str, Any]:
"""
Expand All @@ -72,6 +49,7 @@ def handle_bytes(item):

return handle_bytes(variables)


class BaseLiteralAPI(ABC):
def __init__(
self,
Expand Down Expand Up @@ -676,7 +654,7 @@ def delete_step(
@abstractmethod
def send_steps(self, steps: List[Union[StepDict, "Step"]]):
"""
Sends a list of steps to process.
Sends a list of steps to process.
Step ingestion happens asynchronously if you configured a cache. See [Cache Configuration](https://docs.literalai.com/self-hosting/deployment#4-cache-configuration-optional).

Args:
Expand Down Expand Up @@ -773,9 +751,7 @@ def create_dataset(
pass

@abstractmethod
def get_dataset(
self, id: Optional[str] = None, name: Optional[str] = None
):
def get_dataset(self, id: Optional[str] = None, name: Optional[str] = None):
"""
Retrieves a dataset by its ID or name.

Expand Down Expand Up @@ -846,9 +822,7 @@ def create_experiment(
pass

@abstractmethod
def create_experiment_item(
self, experiment_item: DatasetExperimentItem
):
def create_experiment_item(self, experiment_item: DatasetExperimentItem):
"""
Creates an experiment item within an existing experiment.

Expand Down Expand Up @@ -1065,9 +1039,7 @@ def get_prompt_ab_testing(self, name: str):
pass

@abstractmethod
def update_prompt_ab_testing(
self, name: str, rollouts: List[PromptRollout]
):
def update_prompt_ab_testing(self, name: str, rollouts: List[PromptRollout]):
"""
Update the A/B testing configuration for a prompt lineage.

Expand Down
11 changes: 6 additions & 5 deletions literalai/api/helpers/generation_helpers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Any, Dict, Optional, Union

from literalai.api.helpers import gql
from literalai.my_types import PaginatedResponse
from literalai.observability.filter import generations_filters, generations_order_by
from literalai.my_types import (
PaginatedResponse,
from literalai.observability.generation import (
BaseGeneration,
ChatGeneration,
CompletionGeneration,
)
from literalai.observability.generation import BaseGeneration, CompletionGeneration, ChatGeneration

from literalai.api.helpers import gql


def get_generations_helper(
Expand Down
Loading