From f91f396f96883cd1110c7268cb9c4a48e9eed4fe Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Wed, 24 Jul 2024 11:48:11 +0200 Subject: [PATCH 1/6] feat: make datasetId optional for experiments --- literalai/api/__init__.py | 10 +++++----- literalai/api/dataset_helpers.py | 4 ++-- literalai/api/gql.py | 4 ++-- literalai/dataset_experiment.py | 18 +++++++----------- 4 files changed, 16 insertions(+), 20 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 2696b86..e0f61dd 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -1104,8 +1104,8 @@ def delete_dataset(self, id: str): def create_experiment( self, - dataset_id: str, name: str, + dataset_id: Optional[str] = None, prompt_id: Optional[str] = None, params: Optional[Dict] = None, ) -> "DatasetExperiment": @@ -1113,8 +1113,8 @@ def create_experiment( Creates a new experiment associated with a specific dataset. Args: - dataset_id (str): The unique identifier of the dataset. name (str): The name of the experiment. + dataset_id (Optional[str]): The unique identifier of the dataset. prompt_id (Optional[str]): The identifier of the prompt associated with the experiment. params (Optional[Dict]): Additional parameters for the experiment. @@ -1122,7 +1122,7 @@ def create_experiment( DatasetExperiment: The newly created experiment object. """ return self.gql_helper( - *create_experiment_helper(self, dataset_id, name, prompt_id, params) + *create_experiment_helper(self, name, dataset_id, prompt_id, params) ) def create_experiment_item( @@ -2315,15 +2315,15 @@ async def delete_dataset(self, id: str): async def create_experiment( self, - dataset_id: str, name: str, + dataset_id: Optional[str] = None, prompt_id: Optional[str] = None, params: Optional[Dict] = None, ) -> "DatasetExperiment": sync_api = LiteralAPI(self.api_key, self.url) return await self.gql_helper( - *create_experiment_helper(sync_api, dataset_id, name, prompt_id, params) + *create_experiment_helper(sync_api, name, dataset_id, prompt_id, params) ) create_experiment.__doc__ = LiteralAPI.create_experiment.__doc__ diff --git a/literalai/api/dataset_helpers.py b/literalai/api/dataset_helpers.py index 65f6694..7e03323 100644 --- a/literalai/api/dataset_helpers.py +++ b/literalai/api/dataset_helpers.py @@ -95,8 +95,8 @@ def process_response(response): def create_experiment_helper( api: "LiteralAPI", - dataset_id: str, name: str, + dataset_id: Optional[str] = None, prompt_id: Optional[str] = None, params: Optional[Dict] = None, ): @@ -119,7 +119,7 @@ def process_response(response): def create_experiment_item_helper( dataset_experiment_id: str, - dataset_item_id: str, + dataset_item_id: Optional[str] = None, input: Optional[Dict] = None, output: Optional[Dict] = None, ): diff --git a/literalai/api/gql.py b/literalai/api/gql.py index 5284ed8..e2adcc1 100644 --- a/literalai/api/gql.py +++ b/literalai/api/gql.py @@ -833,7 +833,7 @@ CREATE_EXPERIMENT = """ mutation CreateDatasetExperiment( $name: String! - $datasetId: String! + $datasetId: String $promptId: String $params: Json ) { @@ -854,7 +854,7 @@ CREATE_EXPERIMENT_ITEM = """ mutation CreateDatasetExperimentItem( $datasetExperimentId: String! - $datasetItemId: String! 
+ $datasetItemId: String $input: Json $output: Json ) { diff --git a/literalai/dataset_experiment.py b/literalai/dataset_experiment.py index 6c85751..7254e00 100644 --- a/literalai/dataset_experiment.py +++ b/literalai/dataset_experiment.py @@ -1,6 +1,5 @@ from dataclasses import dataclass, field -from typing import Dict, List, Optional, TypedDict -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict from literalai.my_types import ScoreDict, Utils @@ -11,7 +10,7 @@ class DatasetExperimentItemDict(TypedDict, total=False): id: str datasetExperimentId: str - datasetItemId: str + datasetItemId: Optional[str] scores: List[ScoreDict] input: Optional[Dict] output: Optional[Dict] @@ -21,7 +20,7 @@ class DatasetExperimentItemDict(TypedDict, total=False): class DatasetExperimentItem(Utils): id: str dataset_experiment_id: str - dataset_item_id: str + dataset_item_id: Optional[str] scores: List[ScoreDict] input: Optional[Dict] output: Optional[Dict] @@ -41,7 +40,7 @@ def from_dict(cls, item: DatasetExperimentItemDict) -> "DatasetExperimentItem": return cls( id=item.get("id", ""), dataset_experiment_id=item.get("datasetExperimentId", ""), - dataset_item_id=item.get("datasetItemId", ""), + dataset_item_id=item.get("datasetItemId"), scores=item.get("scores", []), input=item.get("input"), output=item.get("output"), @@ -64,7 +63,7 @@ class DatasetExperiment(Utils): id: str created_at: str name: str - dataset_id: str + dataset_id: Optional[str] params: Optional[Dict] prompt_id: Optional[str] = None items: List[DatasetExperimentItem] = field(default_factory=lambda: []) @@ -73,7 +72,7 @@ def log(self, item_dict: DatasetExperimentItemDict) -> DatasetExperimentItem: dataset_experiment_item = DatasetExperimentItem.from_dict( { "datasetExperimentId": self.id, - "datasetItemId": item_dict.get("datasetItemId", ""), + "datasetItemId": item_dict.get("datasetItemId"), "input": item_dict.get("input", {}), "output": item_dict.get("output", {}), "scores": item_dict.get("scores", []), @@ -110,8 +109,5 @@ def from_dict( dataset_id=dataset_experiment.get("datasetId", ""), params=dataset_experiment.get("params"), prompt_id=dataset_experiment.get("promptId"), - items=[ - DatasetExperimentItem.from_dict(item) - for item in items - ], + items=[DatasetExperimentItem.from_dict(item) for item in items], ) From c2b8003a006b99c2d84b817697dfffad09759bc5 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Thu, 25 Jul 2024 14:06:51 +0200 Subject: [PATCH 2/6] wip --- literalai/api/__init__.py | 34 +++++++-------- literalai/api/gql.py | 6 --- literalai/api/thread_helpers.py | 6 --- literalai/client.py | 51 ++++++++++++++++++++-- literalai/context.py | 4 ++ literalai/dataset_experiment.py | 15 ++++--- literalai/environment.py | 67 ++++++++++++++++++++++++++++ literalai/experiment_run.py | 77 +++++++++++++++++++++++++++++++++ literalai/my_types.py | 1 + 9 files changed, 221 insertions(+), 40 deletions(-) create mode 100644 literalai/environment.py create mode 100644 literalai/experiment_run.py diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 2696b86..7a1adba 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -100,6 +100,7 @@ Attachment, ChatGeneration, CompletionGeneration, + Environment, GenerationMessage, PaginatedResponse, Score, @@ -112,7 +113,12 @@ class BaseLiteralAPI: - def __init__(self, api_key: Optional[str] = None, url: Optional[str] = None): + def __init__( + self, + api_key: Optional[str] = None, + url: Optional[str] = None, + 
environment: Environment = "prod", + ): if url and url[-1] == "/": url = url[:-1] @@ -123,6 +129,7 @@ def __init__(self, api_key: Optional[str] = None, url: Optional[str] = None): self.api_key = api_key self.url = url + self.environment = environment self.graphql_endpoint = self.url + "/api/graphql" self.rest_endpoint = self.url + "/api" @@ -134,6 +141,7 @@ def headers(self): return { "Content-Type": "application/json", "x-api-key": self.api_key, + "x-env": self.environment, "x-client-name": "py-literal-client", "x-client-version": __version__, } @@ -441,7 +449,6 @@ def create_thread( name: Optional[str] = None, metadata: Optional[Dict] = None, participant_id: Optional[str] = None, - environment: Optional[str] = None, tags: Optional[List[str]] = None, ): """ @@ -451,14 +458,13 @@ def create_thread( name (Optional[str]): Name of the thread. metadata (Optional[Dict]): Metadata associated with the thread. participant_id (Optional[str]): Identifier for the participant. - environment (Optional[str]): Environment in which the thread operates. tags (Optional[List[str]]): List of tags associated with the thread. Returns: The newly created thread. """ return self.gql_helper( - *create_thread_helper(name, metadata, participant_id, environment, tags) + *create_thread_helper(name, metadata, participant_id, tags) ) def upsert_thread( @@ -467,7 +473,6 @@ def upsert_thread( name: Optional[str] = None, metadata: Optional[Dict] = None, participant_id: Optional[str] = None, - environment: Optional[str] = None, tags: Optional[List[str]] = None, ): """ @@ -478,14 +483,13 @@ def upsert_thread( name (Optional[str]): Name of the thread. metadata (Optional[Dict]): Metadata associated with the thread. participant_id (Optional[str]): Identifier for the participant. - environment (Optional[str]): Environment in which the thread operates. tags (Optional[List[str]]): List of tags associated with the thread. Returns: The updated or newly created thread. """ return self.gql_helper( - *upsert_thread_helper(id, name, metadata, participant_id, environment, tags) + *upsert_thread_helper(id, name, metadata, participant_id, tags) ) def update_thread( @@ -494,7 +498,6 @@ def update_thread( name: Optional[str] = None, metadata: Optional[Dict] = None, participant_id: Optional[str] = None, - environment: Optional[str] = None, tags: Optional[List[str]] = None, ): """ @@ -505,14 +508,13 @@ def update_thread( name (Optional[str]): New name of the thread. metadata (Optional[Dict]): New metadata for the thread. participant_id (Optional[str]): New identifier for the participant. - environment (Optional[str]): New environment for the thread. tags (Optional[List[str]]): New list of tags for the thread. Returns: The updated thread. """ return self.gql_helper( - *update_thread_helper(id, name, metadata, participant_id, environment, tags) + *update_thread_helper(id, name, metadata, participant_id, tags) ) def delete_thread(self, id: str): @@ -1644,7 +1646,6 @@ async def create_thread( name: Optional[str] = None, metadata: Optional[Dict] = None, participant_id: Optional[str] = None, - environment: Optional[str] = None, tags: Optional[List[str]] = None, ): """ @@ -1654,14 +1655,13 @@ async def create_thread( name (Optional[str]): The name of the thread. metadata (Optional[Dict]): Metadata associated with the thread. participant_id (Optional[str]): Identifier for the participant associated with the thread. - environment (Optional[str]): The environment in which the thread operates. tags (Optional[List[str]]): Tags associated with the thread. 
Returns: The result of the GraphQL helper function for creating a thread. """ return await self.gql_helper( - *create_thread_helper(name, metadata, participant_id, environment, tags) + *create_thread_helper(name, metadata, participant_id, tags) ) async def upsert_thread( @@ -1670,7 +1670,6 @@ async def upsert_thread( name: Optional[str] = None, metadata: Optional[Dict] = None, participant_id: Optional[str] = None, - environment: Optional[str] = None, tags: Optional[List[str]] = None, ): """ @@ -1681,14 +1680,13 @@ async def upsert_thread( name (Optional[str]): The name of the thread. metadata (Optional[Dict]): Metadata associated with the thread. participant_id (Optional[str]): Identifier for the participant associated with the thread. - environment (Optional[str]): The environment in which the thread operates. tags (Optional[List[str]]): Tags associated with the thread. Returns: The result of the GraphQL helper function for upserting a thread. """ return await self.gql_helper( - *upsert_thread_helper(id, name, metadata, participant_id, environment, tags) + *upsert_thread_helper(id, name, metadata, participant_id, tags) ) async def update_thread( @@ -1697,7 +1695,6 @@ async def update_thread( name: Optional[str] = None, metadata: Optional[Dict] = None, participant_id: Optional[str] = None, - environment: Optional[str] = None, tags: Optional[List[str]] = None, ): """ @@ -1708,14 +1705,13 @@ async def update_thread( name (Optional[str]): New name of the thread. metadata (Optional[Dict]): New metadata for the thread. participant_id (Optional[str]): New identifier for the participant. - environment (Optional[str]): New environment for the thread. tags (Optional[List[str]]): New list of tags for the thread. Returns: The result of the GraphQL helper function for updating a thread. 
""" return await self.gql_helper( - *update_thread_helper(id, name, metadata, participant_id, environment, tags) + *update_thread_helper(id, name, metadata, participant_id, tags) ) async def delete_thread(self, id: str): diff --git a/literalai/api/gql.py b/literalai/api/gql.py index 5284ed8..0540922 100644 --- a/literalai/api/gql.py +++ b/literalai/api/gql.py @@ -269,14 +269,12 @@ $name: String, $metadata: Json, $participantId: String, - $environment: String, $tags: [String!], ) { createThread( name: $name metadata: $metadata participantId: $participantId - environment: $environment tags: $tags ) { """ @@ -294,7 +292,6 @@ $name: String, $metadata: Json, $participantId: String, - $environment: String, $tags: [String!], ) { upsertThread( @@ -302,7 +299,6 @@ name: $name metadata: $metadata participantId: $participantId - environment: $environment tags: $tags ) { """ @@ -320,7 +316,6 @@ $name: String, $metadata: Json, $participantId: String, - $environment: String, $tags: [String!], ) { updateThread( @@ -328,7 +323,6 @@ name: $name metadata: $metadata participantId: $participantId - environment: $environment tags: $tags ) { """ diff --git a/literalai/api/thread_helpers.py b/literalai/api/thread_helpers.py index c88fe72..1b89ff4 100644 --- a/literalai/api/thread_helpers.py +++ b/literalai/api/thread_helpers.py @@ -89,14 +89,12 @@ def create_thread_helper( name: Optional[str] = None, metadata: Optional[Dict] = None, participant_id: Optional[str] = None, - environment: Optional[str] = None, tags: Optional[List[str]] = None, ): variables = { "name": name, "metadata": metadata, "participantId": participant_id, - "environment": environment, "tags": tags, } @@ -113,7 +111,6 @@ def upsert_thread_helper( name: Optional[str] = None, metadata: Optional[Dict] = None, participant_id: Optional[str] = None, - environment: Optional[str] = None, tags: Optional[List[str]] = None, ): variables = { @@ -121,7 +118,6 @@ def upsert_thread_helper( "name": name, "metadata": metadata, "participantId": participant_id, - "environment": environment, "tags": tags, } @@ -141,7 +137,6 @@ def update_thread_helper( name: Optional[str] = None, metadata: Optional[Dict] = None, participant_id: Optional[str] = None, - environment: Optional[str] = None, tags: Optional[List[str]] = None, ): variables = { @@ -149,7 +144,6 @@ def update_thread_helper( "name": name, "metadata": metadata, "participantId": participant_id, - "environment": environment, "tags": tags, } diff --git a/literalai/client.py b/literalai/client.py index 6ca3cba..c581f0c 100644 --- a/literalai/client.py +++ b/literalai/client.py @@ -5,11 +5,16 @@ from literalai.callback.langchain_callback import get_langchain_callback from literalai.callback.llama_index_callback import get_llama_index_callback from literalai.context import active_steps_var, active_thread_var +from literalai.environment import EnvContextManager, env_decorator from literalai.event_processor import EventProcessor +from literalai.experiment_run import ( + ExperimentRunContextManager, + experiment_run_decorator, +) from literalai.instrumentation.mistralai import instrument_mistralai from literalai.instrumentation.openai import instrument_openai from literalai.message import Message -from literalai.my_types import Attachment +from literalai.my_types import Attachment, Environment from literalai.step import ( MessageStepType, Step, @@ -29,6 +34,7 @@ def __init__( is_async: bool = False, api_key: Optional[str] = None, url: Optional[str] = None, + environment: Environment = "prod", disabled: bool = False, 
): if not api_key: @@ -38,9 +44,11 @@ def __init__( if not url: url = os.getenv("LITERAL_API_URL", "https://cloud.getliteral.ai") if is_async: - self.api = AsyncLiteralAPI(api_key=api_key, url=url) + self.api = AsyncLiteralAPI( + api_key=api_key, url=url, environment=environment + ) else: - self.api = LiteralAPI(api_key=api_key, url=url) + self.api = LiteralAPI(api_key=api_key, url=url, environment=environment) self.disabled = disabled @@ -185,6 +193,43 @@ def message( return step + def environment( + self, + original_function=None, + env: Environment = "prod", + **kwargs, + ): + if original_function: + return env_decorator( + self, + func=original_function, + env=env, + **kwargs, + ) + else: + return EnvContextManager( + self, + env=env, + **kwargs, + ) + + def experiment_run( + self, + original_function=None, + **kwargs, + ): + if original_function: + return experiment_run_decorator( + self, + func=original_function, + **kwargs, + ) + else: + return ExperimentRunContextManager( + self, + **kwargs, + ) + def start_step( self, name: str = "", diff --git a/literalai/context.py b/literalai/context.py index fe8462f..5506210 100644 --- a/literalai/context.py +++ b/literalai/context.py @@ -7,3 +7,7 @@ active_steps_var = ContextVar[List["Step"]]("active_steps", default=[]) active_thread_var = ContextVar[Optional["Thread"]]("active_thread", default=None) + +active_experiment_run_id_var = ContextVar[Optional[str]]( + "active_experiment_run", default=None +) diff --git a/literalai/dataset_experiment.py b/literalai/dataset_experiment.py index 6c85751..ac81646 100644 --- a/literalai/dataset_experiment.py +++ b/literalai/dataset_experiment.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field -from typing import Dict, List, Optional, TypedDict -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict +from literalai.context import active_experiment_run_id_var from literalai.my_types import ScoreDict, Utils if TYPE_CHECKING: @@ -15,6 +15,7 @@ class DatasetExperimentItemDict(TypedDict, total=False): scores: List[ScoreDict] input: Optional[Dict] output: Optional[Dict] + runExperimentId: Optional[str] @dataclass(repr=False) @@ -25,12 +26,14 @@ class DatasetExperimentItem(Utils): scores: List[ScoreDict] input: Optional[Dict] output: Optional[Dict] + run_experiment_id: Optional[str] def to_dict(self): return { "id": self.id, "datasetExperimentId": self.dataset_experiment_id, "datasetItemId": self.dataset_item_id, + "runExperimentId": self.run_experiment_id, "scores": self.scores, "input": self.input, "output": self.output, @@ -40,6 +43,7 @@ def to_dict(self): def from_dict(cls, item: DatasetExperimentItemDict) -> "DatasetExperimentItem": return cls( id=item.get("id", ""), + run_experiment_id=item.get("runExperimentId"), dataset_experiment_id=item.get("datasetExperimentId", ""), dataset_item_id=item.get("datasetItemId", ""), scores=item.get("scores", []), @@ -70,8 +74,10 @@ class DatasetExperiment(Utils): items: List[DatasetExperimentItem] = field(default_factory=lambda: []) def log(self, item_dict: DatasetExperimentItemDict) -> DatasetExperimentItem: + experiment_run_id = active_experiment_run_id_var.get() dataset_experiment_item = DatasetExperimentItem.from_dict( { + "runExperimentId": experiment_run_id, "datasetExperimentId": self.id, "datasetItemId": item_dict.get("datasetItemId", ""), "input": item_dict.get("input", {}), @@ -110,8 +116,5 @@ def from_dict( dataset_id=dataset_experiment.get("datasetId", ""), params=dataset_experiment.get("params"), 
prompt_id=dataset_experiment.get("promptId"), - items=[ - DatasetExperimentItem.from_dict(item) - for item in items - ], + items=[DatasetExperimentItem.from_dict(item) for item in items], ) diff --git a/literalai/environment.py b/literalai/environment.py new file mode 100644 index 0000000..e28bc6a --- /dev/null +++ b/literalai/environment.py @@ -0,0 +1,67 @@ +import inspect +from functools import wraps +from typing import TYPE_CHECKING, Callable, Optional + +from literalai.my_types import Environment + +if TYPE_CHECKING: + from literalai.client import BaseLiteralClient + + +class EnvContextManager: + def __init__(self, client: "BaseLiteralClient", env: Environment = "prod"): + self.client = client + self.env = env + self.original_env = client.api.environment + + def __call__(self, func): + return env_decorator( + self.client, + func=func, + ctx_manager=self, + ) + + async def __aenter__(self): + self.client.api.environment = self.env + + async def __aexit__(self): + self.client.api.environment = self.original_env + + def __enter__(self): + self.client.api.environment = self.env + + def __exit__(self): + self.client.api.environment = self.original_env + + +def env_decorator( + client: "BaseLiteralClient", + func: Callable, + env: Environment = "prod", + ctx_manager: Optional[EnvContextManager] = None, + **decorator_kwargs, +): + if not ctx_manager: + ctx_manager = EnvContextManager( + client=client, + env=env, + **decorator_kwargs, + ) + + # Handle async decorator + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args, **kwargs): + with ctx_manager: + await func(*args, **kwargs) + + return async_wrapper + else: + # Handle sync decorator + @wraps(func) + def sync_wrapper(*args, **kwargs): + with ctx_manager: + func(*args, **kwargs) + + return sync_wrapper diff --git a/literalai/experiment_run.py b/literalai/experiment_run.py new file mode 100644 index 0000000..bd07c6c --- /dev/null +++ b/literalai/experiment_run.py @@ -0,0 +1,77 @@ +import inspect +import uuid +from functools import wraps +from typing import TYPE_CHECKING, Callable, Optional + +from literalai.context import active_experiment_run_id_var +from literalai.environment import EnvContextManager +from literalai.step import StepContextManager + +if TYPE_CHECKING: + from literalai.client import BaseLiteralClient + + +class ExperimentRunContextManager(EnvContextManager, StepContextManager): + def __init__( + self, + client: "BaseLiteralClient", + ): + self.id = str(uuid.uuid4()) + EnvContextManager.__init__(self, client=client, env="experiment") + StepContextManager.__init__( + self, client=client, name="Experiment Run", type="run", id=self.id + ) + + def __call__(self, func): + return experiment_run_decorator( + self.client, + func=func, + ctx_manager=self, + ) + + async def __aenter__(self): + super().__aenter__() + active_experiment_run_id_var.set(self.id) + + async def __aexit__(self): + super().__aexit__() + active_experiment_run_id_var.set(None) + + def __enter__(self): + super().__enter__() + active_experiment_run_id_var.set(self.id) + + def __exit__(self): + super().__exit__() + active_experiment_run_id_var.set(None) + + +def experiment_run_decorator( + client: "BaseLiteralClient", + func: Callable, + ctx_manager: Optional[ExperimentRunContextManager] = None, + **decorator_kwargs, +): + if not ctx_manager: + ctx_manager = ExperimentRunContextManager( + client=client, + **decorator_kwargs, + ) + + # Handle async decorator + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def 
async_wrapper(*args, **kwargs): + with ctx_manager: + await func(*args, **kwargs) + + return async_wrapper + else: + # Handle sync decorator + @wraps(func) + def sync_wrapper(*args, **kwargs): + with ctx_manager: + func(*args, **kwargs) + + return sync_wrapper diff --git a/literalai/my_types.py b/literalai/my_types.py index 0e7468d..7fa4ddc 100644 --- a/literalai/my_types.py +++ b/literalai/my_types.py @@ -11,6 +11,7 @@ from pydantic.dataclasses import Field, dataclass +Environment = Literal["dev", "staging", "prod", "experiment"] GenerationMessageRole = Literal["user", "assistant", "tool", "function", "system"] ScoreType = Literal["HUMAN", "AI"] From 2a8e022faa6b850ebe17f958baff143dfc268bfb Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Thu, 25 Jul 2024 14:49:17 +0200 Subject: [PATCH 3/6] fix: tests --- literalai/api/__init__.py | 16 ++++++++++++++-- literalai/dataset.py | 8 +++----- literalai/thread.py | 6 ++---- tests/e2e/test_e2e.py | 3 ++- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index 4d648d4..fe7c0f1 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -1124,7 +1124,13 @@ def create_experiment( DatasetExperiment: The newly created experiment object. """ return self.gql_helper( - *create_experiment_helper(self, name, dataset_id, prompt_id, params) + *create_experiment_helper( + api=self, + name=name, + dataset_id=dataset_id, + prompt_id=prompt_id, + params=params, + ) ) def create_experiment_item( @@ -2319,7 +2325,13 @@ async def create_experiment( sync_api = LiteralAPI(self.api_key, self.url) return await self.gql_helper( - *create_experiment_helper(sync_api, name, dataset_id, prompt_id, params) + *create_experiment_helper( + api=sync_api, + name=name, + dataset_id=dataset_id, + prompt_id=prompt_id, + params=params, + ) ) create_experiment.__doc__ = LiteralAPI.create_experiment.__doc__ diff --git a/literalai/dataset.py b/literalai/dataset.py index 96b70a9..656ec95 100644 --- a/literalai/dataset.py +++ b/literalai/dataset.py @@ -112,10 +112,10 @@ def create_experiment( :param name: The name of the experiment . :param prompt_id: The Prompt ID used on LLM calls (optional). :param params: The params used on the experiment. - :return: The created DatasetExperiment instance as a dictionary. + :return: The created DatasetExperiment instance. """ experiment = self.api.create_experiment( - self.id, name, prompt_id, params + name=name, dataset_id=self.id, prompt_id=prompt_id, params=params ) return experiment @@ -129,9 +129,7 @@ def delete_item(self, item_id: str): if self.items is not None: self.items = [item for item in self.items if item.id != item_id] - def add_step( - self, step_id: str, metadata: Optional[Dict] = None - ) -> DatasetItem: + def add_step(self, step_id: str, metadata: Optional[Dict] = None) -> DatasetItem: """ Create a new dataset item based on a step and add it to this dataset. :param step_id: The id of the step to add to the dataset. 
diff --git a/literalai/thread.py b/literalai/thread.py index f3ac44f..4ebbde4 100644 --- a/literalai/thread.py +++ b/literalai/thread.py @@ -34,7 +34,6 @@ class Thread(Utils): participant_id: Optional[str] participant_identifier: Optional[str] = None created_at: Optional[str] # read-only, set by server - needs_upsert: Optional[bool] def __init__( self, @@ -51,7 +50,6 @@ def __init__( self.metadata = metadata self.tags = tags self.participant_id = participant_id - self.needs_upsert = bool(metadata or tags or participant_id or name) def to_dict(self) -> ThreadDict: return { @@ -144,7 +142,7 @@ def __enter__(self) -> "Optional[Thread]": return active_thread_var.get() def __exit__(self, exc_type, exc_val, exc_tb): - if (thread := active_thread_var.get()) and thread.needs_upsert: + if active_thread_var.get(): self.upsert() active_thread_var.set(None) @@ -154,7 +152,7 @@ async def __aenter__(self): return active_thread_var.get() async def __aexit__(self, exc_type, exc_val, exc_tb): - if (thread := active_thread_var.get()) and thread.needs_upsert: + if active_thread_var.get(): self.upsert() active_thread_var.set(None) diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 7f84fdb..1b31a4e 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -261,7 +261,8 @@ async def test_ingestion( stack = active_steps_var.get() assert len(stack) == 1 - assert async_client.event_processor.event_queue._qsize() == 1 + assert async_client.event_processor.event_queue._qsize() == 1 + stack = active_steps_var.get() assert len(stack) == 0 From 24ec09a2edc033b97f11d869f14e3a3a9873d3ad Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Thu, 25 Jul 2024 16:48:22 +0200 Subject: [PATCH 4/6] feat: test experiment run --- literalai/api/__init__.py | 17 +++++++++++++---- literalai/api/dataset_helpers.py | 2 ++ literalai/api/gql.py | 6 ++++++ literalai/client.py | 6 +++++- literalai/dataset_experiment.py | 10 +++++----- literalai/environment.py | 15 ++++++++------- literalai/experiment_run.py | 18 ++++++++++++------ literalai/step.py | 4 ++++ tests/e2e/test_e2e.py | 26 ++++++++++++++++++++++++++ 9 files changed, 81 insertions(+), 23 deletions(-) diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index fe7c0f1..fb45041 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -1,4 +1,5 @@ import logging +import os import uuid from typing import ( TYPE_CHECKING, @@ -117,7 +118,7 @@ def __init__( self, api_key: Optional[str] = None, url: Optional[str] = None, - environment: Environment = "prod", + environment: Optional[Environment] = None, ): if url and url[-1] == "/": url = url[:-1] @@ -129,7 +130,9 @@ def __init__( self.api_key = api_key self.url = url - self.environment = environment + + if environment: + os.environ["LITERAL_ENV"] = environment self.graphql_endpoint = self.url + "/api/graphql" self.rest_endpoint = self.url + "/api" @@ -138,14 +141,18 @@ def __init__( def headers(self): from literalai.version import __version__ - return { + h = { "Content-Type": "application/json", "x-api-key": self.api_key, - "x-env": self.environment, "x-client-name": "py-literal-client", "x-client-version": __version__, } + if env := os.getenv("LITERAL_ENV"): + h["x-env"] = env + + return h + def _prepare_variables(self, variables: Dict[str, Any]) -> Dict[str, Any]: """ Recursively checks and converts bytes objects in variables. 
@@ -1150,6 +1157,7 @@ def create_experiment_item( *create_experiment_item_helper( dataset_experiment_id=experiment_item.dataset_experiment_id, dataset_item_id=experiment_item.dataset_item_id, + experiment_run_id=experiment_item.experiment_run_id, input=experiment_item.input, output=experiment_item.output, ) @@ -2355,6 +2363,7 @@ async def create_experiment_item( *create_experiment_item_helper( dataset_experiment_id=experiment_item.dataset_experiment_id, dataset_item_id=experiment_item.dataset_item_id, + experiment_run_id=experiment_item.experiment_run_id, input=experiment_item.input, output=experiment_item.output, ) diff --git a/literalai/api/dataset_helpers.py b/literalai/api/dataset_helpers.py index 7e03323..154091a 100644 --- a/literalai/api/dataset_helpers.py +++ b/literalai/api/dataset_helpers.py @@ -119,11 +119,13 @@ def process_response(response): def create_experiment_item_helper( dataset_experiment_id: str, + experiment_run_id: Optional[str] = None, dataset_item_id: Optional[str] = None, input: Optional[Dict] = None, output: Optional[Dict] = None, ): variables = { + "experimentRunId": experiment_run_id, "datasetExperimentId": dataset_experiment_id, "datasetItemId": dataset_item_id, "input": input, diff --git a/literalai/api/gql.py b/literalai/api/gql.py index 63c049c..3458e7c 100644 --- a/literalai/api/gql.py +++ b/literalai/api/gql.py @@ -15,6 +15,7 @@ input output metadata + environment scores { id type @@ -848,6 +849,7 @@ CREATE_EXPERIMENT_ITEM = """ mutation CreateDatasetExperimentItem( $datasetExperimentId: String! + $experimentRunId: String $datasetItemId: String $input: Json $output: Json @@ -855,12 +857,16 @@ createDatasetExperimentItem( datasetExperimentId: $datasetExperimentId datasetItemId: $datasetItemId + experimentRunId: $experimentRunId input: $input output: $output ) { id input output + datasetExperimentId + experimentRunId + datasetItemId } } """ diff --git a/literalai/client.py b/literalai/client.py index c581f0c..8ecf11d 100644 --- a/literalai/client.py +++ b/literalai/client.py @@ -34,7 +34,7 @@ def __init__( is_async: bool = False, api_key: Optional[str] = None, url: Optional[str] = None, - environment: Environment = "prod", + environment: Optional[Environment] = None, disabled: bool = False, ): if not api_key: @@ -277,6 +277,7 @@ def __init__( batch_size: int = 5, api_key: Optional[str] = None, url: Optional[str] = None, + environment: Optional[Environment] = None, disabled: bool = False, ): super().__init__( @@ -285,6 +286,7 @@ def __init__( api_key=api_key, url=url, disabled=disabled, + environment=environment, ) def flush(self): @@ -299,6 +301,7 @@ def __init__( batch_size: int = 5, api_key: Optional[str] = None, url: Optional[str] = None, + environment: Optional[Environment] = None, disabled: bool = False, ): super().__init__( @@ -307,6 +310,7 @@ def __init__( api_key=api_key, url=url, disabled=disabled, + environment=environment, ) async def flush(self): diff --git a/literalai/dataset_experiment.py b/literalai/dataset_experiment.py index e4c6ca8..298a846 100644 --- a/literalai/dataset_experiment.py +++ b/literalai/dataset_experiment.py @@ -15,7 +15,7 @@ class DatasetExperimentItemDict(TypedDict, total=False): scores: List[ScoreDict] input: Optional[Dict] output: Optional[Dict] - runExperimentId: Optional[str] + experimentRunId: Optional[str] @dataclass(repr=False) @@ -26,14 +26,14 @@ class DatasetExperimentItem(Utils): scores: List[ScoreDict] input: Optional[Dict] output: Optional[Dict] - run_experiment_id: Optional[str] + experiment_run_id: 
Optional[str] def to_dict(self): return { "id": self.id, "datasetExperimentId": self.dataset_experiment_id, "datasetItemId": self.dataset_item_id, - "runExperimentId": self.run_experiment_id, + "experimentRunId": self.experiment_run_id, "scores": self.scores, "input": self.input, "output": self.output, @@ -43,7 +43,7 @@ def to_dict(self): def from_dict(cls, item: DatasetExperimentItemDict) -> "DatasetExperimentItem": return cls( id=item.get("id", ""), - run_experiment_id=item.get("runExperimentId"), + experiment_run_id=item.get("experimentRunId"), dataset_experiment_id=item.get("datasetExperimentId", ""), dataset_item_id=item.get("datasetItemId"), scores=item.get("scores", []), @@ -77,7 +77,7 @@ def log(self, item_dict: DatasetExperimentItemDict) -> DatasetExperimentItem: experiment_run_id = active_experiment_run_id_var.get() dataset_experiment_item = DatasetExperimentItem.from_dict( { - "runExperimentId": experiment_run_id, + "experimentRunId": experiment_run_id, "datasetExperimentId": self.id, "datasetItemId": item_dict.get("datasetItemId"), "input": item_dict.get("input", {}), diff --git a/literalai/environment.py b/literalai/environment.py index e28bc6a..1f8151a 100644 --- a/literalai/environment.py +++ b/literalai/environment.py @@ -1,4 +1,5 @@ import inspect +import os from functools import wraps from typing import TYPE_CHECKING, Callable, Optional @@ -12,7 +13,7 @@ class EnvContextManager: def __init__(self, client: "BaseLiteralClient", env: Environment = "prod"): self.client = client self.env = env - self.original_env = client.api.environment + self.original_env = os.environ.get("LITERAL_ENV", "") def __call__(self, func): return env_decorator( @@ -22,16 +23,16 @@ def __call__(self, func): ) async def __aenter__(self): - self.client.api.environment = self.env + os.environ["LITERAL_ENV"] = self.env - async def __aexit__(self): - self.client.api.environment = self.original_env + async def __aexit__(self, exc_type, exc_val, exc_tb): + os.environ = self.original_env def __enter__(self): - self.client.api.environment = self.env + os.environ["LITERAL_ENV"] = self.env - def __exit__(self): - self.client.api.environment = self.original_env + def __exit__(self, exc_type, exc_val, exc_tb): + os.environ["LITERAL_ENV"] = self.original_env def env_decorator( diff --git a/literalai/experiment_run.py b/literalai/experiment_run.py index bd07c6c..ce92928 100644 --- a/literalai/experiment_run.py +++ b/literalai/experiment_run.py @@ -30,19 +30,25 @@ def __call__(self, func): ) async def __aenter__(self): - super().__aenter__() active_experiment_run_id_var.set(self.id) + await EnvContextManager.__aenter__(self) + await StepContextManager.__aenter__(self) - async def __aexit__(self): - super().__aexit__() + async def __aexit__(self, exc_type, exc_val, exc_tb): + await StepContextManager.__aexit__(self, exc_type, exc_val, exc_tb) + await self.client.event_processor.aflush() + await EnvContextManager.__aexit__(self, exc_type, exc_val, exc_tb) active_experiment_run_id_var.set(None) def __enter__(self): - super().__enter__() active_experiment_run_id_var.set(self.id) + EnvContextManager.__enter__(self) + StepContextManager.__enter__(self) - def __exit__(self): - super().__exit__() + def __exit__(self, exc_type, exc_val, exc_tb): + StepContextManager.__exit__(self, exc_type, exc_val, exc_tb) + self.client.event_processor.flush() + EnvContextManager.__exit__(self, exc_type, exc_val, exc_tb) active_experiment_run_id_var.set(None) diff --git a/literalai/step.py b/literalai/step.py index 5197df5..69803e6 100644 
--- a/literalai/step.py +++ b/literalai/step.py @@ -26,6 +26,7 @@ BaseGeneration, ChatGeneration, CompletionGeneration, + Environment, Score, ScoreDict, Utils, @@ -44,6 +45,7 @@ class StepDict(TypedDict, total=False): id: Optional[str] name: Optional[str] type: Optional[StepType] + environment: Optional[Environment] threadId: Optional[str] error: Optional[str] input: Optional[Dict] @@ -73,6 +75,7 @@ class Step(Utils): output: Optional[Dict[str, Any]] = None tags: Optional[List[str]] = None thread_id: Optional[str] = None + environment: Optional[Environment] = None generation: Optional[Union[ChatGeneration, CompletionGeneration]] = None scores: Optional[List[Score]] = [] @@ -173,6 +176,7 @@ def from_dict(cls, step_dict: StepDict) -> "Step": step.input = step_dict.get("input", None) step.error = step_dict.get("error", None) step.output = step_dict.get("output", None) + step.environment = step_dict.get("environment", None) step.metadata = step_dict.get("metadata", {}) step.tags = step_dict.get("tags", []) step.parent_id = step_dict.get("parentId", None) diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 1b31a4e..fdba835 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -636,3 +636,29 @@ async def test_experiment_params_optional(self, client: LiteralClient): experiment = dataset.create_experiment(name="test-experiment") assert experiment.params is None dataset.delete() + + @pytest.mark.timeout(5) + async def test_experiment_run(self, client: LiteralClient): + experiment = client.api.create_experiment(name="test-experiment-run") + + @client.step(type="run") + def agent(input): + return {"content": "hello world!"} + + with client.experiment_run(): + input = {"question": "question"} + output = agent(input) + item = experiment.log( + { + "scores": [ + {"name": "context_relevancy", "type": "AI", "value": 0.6} + ], + "input": input, + "output": output, + } + ) + + assert item.experiment_run_id is not None + experiment_run = client.api.get_step(item.experiment_run_id) + assert experiment_run is not None + assert experiment_run.environment == "experiment" From 1b1c4c45981c94883df6d644b36b0721b273d9cb Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Thu, 25 Jul 2024 19:17:12 +0200 Subject: [PATCH 5/6] feat: return step from experiment run context manager --- literalai/experiment_run.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/literalai/experiment_run.py b/literalai/experiment_run.py index ce92928..7172243 100644 --- a/literalai/experiment_run.py +++ b/literalai/experiment_run.py @@ -32,7 +32,8 @@ def __call__(self, func): async def __aenter__(self): active_experiment_run_id_var.set(self.id) await EnvContextManager.__aenter__(self) - await StepContextManager.__aenter__(self) + step = await StepContextManager.__aenter__(self) + return step async def __aexit__(self, exc_type, exc_val, exc_tb): await StepContextManager.__aexit__(self, exc_type, exc_val, exc_tb) @@ -43,7 +44,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): def __enter__(self): active_experiment_run_id_var.set(self.id) EnvContextManager.__enter__(self) - StepContextManager.__enter__(self) + step = StepContextManager.__enter__(self) + return step def __exit__(self, exc_type, exc_val, exc_tb): StepContextManager.__exit__(self, exc_type, exc_val, exc_tb) From b1baefd65a604e9a2c8eb42d1eebe144cf05d8a0 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Thu, 25 Jul 2024 20:28:44 +0200 Subject: [PATCH 6/6] feat: add env test --- tests/e2e/test_e2e.py | 39 
+++++++++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index fdba835..a9f55aa 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -46,6 +46,18 @@ def client(self): yield client client.event_processor.flush_and_stop() + @pytest.fixture(scope="session") + def staging_client(self): + url = os.getenv("LITERAL_API_URL", None) + api_key = os.getenv("LITERAL_API_KEY", None) + assert url is not None and api_key is not None, "Missing environment variables" + + client = LiteralClient( + batch_size=1, url=url, api_key=api_key, environment="staging" + ) + yield client + client.event_processor.flush_and_stop() + @pytest.fixture(scope="session") def async_client(self): url = os.getenv("LITERAL_API_URL", None) @@ -362,9 +374,7 @@ def step_decorated(): step_id = step_decorated() await assert_delete(step_id) - async def test_parallel_requests( - self, client: LiteralClient, async_client: AsyncLiteralClient - ): + async def test_parallel_requests(self, async_client: AsyncLiteralClient): ids = [] @async_client.thread @@ -393,9 +403,7 @@ async def create_test_step(self, async_client: AsyncLiteralClient): ) @pytest.mark.timeout(5) - async def test_dataset( - self, client: LiteralClient, async_client: AsyncLiteralClient - ): + async def test_dataset(self, async_client: AsyncLiteralClient): dataset_name = str(uuid.uuid4()) step = await self.create_test_step(async_client) dataset = await async_client.api.create_dataset( @@ -456,9 +464,7 @@ async def test_dataset( assert deleted_dataset is None @pytest.mark.timeout(5) - async def test_generation_dataset( - self, client: LiteralClient, async_client: AsyncLiteralClient - ): + async def test_generation_dataset(self, async_client: AsyncLiteralClient): chat_generation = ChatGeneration( provider="test", model="test", @@ -554,9 +560,7 @@ async def test_dataset_sync( assert client.api.get_dataset(id=fetched_dataset.id) is None @pytest.mark.timeout(5) - async def test_prompt( - self, client: LiteralClient, async_client: AsyncLiteralClient - ): + async def test_prompt(self, async_client: AsyncLiteralClient): prompt = await async_client.api.get_prompt(name="Default", version=0) assert prompt is not None assert prompt.name == "Default" @@ -662,3 +666,14 @@ def agent(input): experiment_run = client.api.get_step(item.experiment_run_id) assert experiment_run is not None assert experiment_run.environment == "experiment" + + @pytest.mark.timeout(5) + async def test_environment(self, staging_client: LiteralClient): + run_id: str + with staging_client.run(name="foo") as run: + run_id = run.id + staging_client.event_processor.flush() + assert run_id is not None + persisted_run = staging_client.api.get_step(run_id) + assert persisted_run is not None + assert persisted_run.environment == "staging"
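
Usage sketch for the APIs introduced in this series — a minimal, illustrative example only (it assumes the package's top-level `LiteralClient` export and a configured `LITERAL_API_KEY`; the experiment name and agent function are hypothetical, mirroring the e2e tests added in PATCH 4/6 and PATCH 6/6):

from literalai import LiteralClient

# Pass environment="staging" (or "dev"/"prod"/"experiment") to tag every
# persisted step with that env via the x-env header / LITERAL_ENV variable.
client = LiteralClient()

# datasetId is now optional, so an experiment can be created standalone.
experiment = client.api.create_experiment(name="my-experiment")

@client.step(type="run")
def agent(question: str):
    return {"content": "hello world!"}

# Inside an experiment run, steps are sent with the "experiment" environment
# and logged items are linked to the run through experimentRunId.
with client.experiment_run():
    answer = agent("What is Literal AI?")
    experiment.log(
        {
            "input": {"question": "What is Literal AI?"},
            "output": answer,
            "scores": [{"name": "context_relevancy", "type": "AI", "value": 0.6}],
        }
    )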