diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py
index cecdccb..b771092 100644
--- a/literalai/api/__init__.py
+++ b/literalai/api/__init__.py
@@ -44,8 +44,10 @@
     PromptRollout,
     create_prompt_helper,
     create_prompt_lineage_helper,
+    create_prompt_variant_helper,
     get_prompt_ab_testing_helper,
     get_prompt_helper,
+    get_prompt_lineage_helper,
     update_prompt_ab_testing_helper,
 )
 from literalai.api.score_helpers import (
@@ -144,7 +146,6 @@ def handle_bytes(item):
 
 
 class BaseLiteralAPI:
-
     def __init__(
         self,
         api_key: Optional[str] = None,
@@ -201,7 +202,6 @@ class LiteralAPI(BaseLiteralAPI):
     def make_gql_call(
         self, description: str, query: str, variables: Dict[str, Any]
     ) -> Dict:
-
         def raise_error(error):
             logger.error(f"Failed to {description}: {error}")
             raise Exception(error)
@@ -1141,7 +1141,7 @@ def create_experiment(
         self,
         name: str,
         dataset_id: Optional[str] = None,
-        prompt_id: Optional[str] = None,
+        prompt_variant_id: Optional[str] = None,
         params: Optional[Dict] = None,
     ) -> "DatasetExperiment":
         """
@@ -1150,7 +1150,7 @@ def create_experiment(
         Args:
             name (str): The name of the experiment.
             dataset_id (Optional[str]): The unique identifier of the dataset.
-            prompt_id (Optional[str]): The identifier of the prompt associated with the experiment.
+            prompt_variant_id (Optional[str]): The identifier of the prompt variant to associate with the experiment.
             params (Optional[Dict]): Additional parameters for the experiment.
 
         Returns:
@@ -1161,7 +1161,7 @@ def create_experiment(
                 api=self,
                 name=name,
                 dataset_id=dataset_id,
-                prompt_id=prompt_id,
+                prompt_variant_id=prompt_variant_id,
                 params=params,
             )
         )
@@ -1369,6 +1369,34 @@ def get_prompt(
         else:
             raise ValueError("Either the `id` or the `name` must be provided.")
 
+    def create_prompt_variant(
+        self,
+        name: str,
+        template_messages: List[GenerationMessage],
+        settings: Optional[ProviderSettings] = None,
+        tools: Optional[List[Dict]] = None,
+    ) -> Optional[str]:
+        """
+        Creates a prompt variant for an experiment.
+        This variant is not an official version until manually saved.
+
+        Args:
+            name (str): The name of the prompt whose lineage the variant is attached to, if it exists.
+            template_messages (List[GenerationMessage]): A list of template messages for the prompt.
+            settings (Optional[ProviderSettings]): Optional settings for the prompt.
+            tools (Optional[List[Dict]]): Optional tool options for the model.
+
+        Returns:
+            prompt_variant_id: The prompt variant ID to link to the experiment.
+        """
+        lineage = self.gql_helper(*get_prompt_lineage_helper(name))
+        lineage_id = lineage["id"] if lineage else None
+        return self.gql_helper(
+            *create_prompt_variant_helper(
+                lineage_id, template_messages, settings, tools
+            )
+        )
+
     def get_prompt_ab_testing(self, name: str) -> List[PromptRollout]:
         """
         Get the A/B testing configuration for a prompt lineage.
@@ -2351,7 +2379,7 @@ async def create_experiment(
         self,
         name: str,
         dataset_id: Optional[str] = None,
-        prompt_id: Optional[str] = None,
+        prompt_variant_id: Optional[str] = None,
         params: Optional[Dict] = None,
     ) -> "DatasetExperiment":
         sync_api = LiteralAPI(self.api_key, self.url)
@@ -2361,7 +2389,7 @@ async def create_experiment(
                 api=sync_api,
                 name=name,
                 dataset_id=dataset_id,
-                prompt_id=prompt_id,
+                prompt_variant_id=prompt_variant_id,
                 params=params,
             )
         )
@@ -2529,6 +2557,36 @@ async def create_prompt(
     ):
         return await self.get_or_create_prompt(name, template_messages, settings)
 
+    async def create_prompt_variant(
+        self,
+        name: str,
+        template_messages: List[GenerationMessage],
+        settings: Optional[ProviderSettings] = None,
+        tools: Optional[List[Dict]] = None,
+    ) -> Optional[str]:
+        """
+        Creates a prompt variant for an experiment.
+        This variant is not an official version until manually saved.
+
+        Args:
+            name (str): The name of the prompt whose lineage the variant is attached to, if it exists.
+            template_messages (List[GenerationMessage]): A list of template messages for the prompt.
+            settings (Optional[ProviderSettings]): Optional settings for the prompt.
+            tools (Optional[List[Dict]]): Optional tool options for the model.
+
+        Returns:
+            prompt_variant_id: The prompt variant ID to link to the experiment.
+        """
+        lineage = await self.gql_helper(*get_prompt_lineage_helper(name))
+        lineage_id = lineage["id"] if lineage else None
+        return await self.gql_helper(
+            *create_prompt_variant_helper(
+                lineage_id, template_messages, settings, tools
+            )
+        )
+
+    create_prompt_variant.__doc__ = LiteralAPI.create_prompt_variant.__doc__
+
     async def get_prompt(
         self,
         id: Optional[str] = None,
diff --git a/literalai/api/dataset_helpers.py b/literalai/api/dataset_helpers.py
index 2f01ead..5d2ff46 100644
--- a/literalai/api/dataset_helpers.py
+++ b/literalai/api/dataset_helpers.py
@@ -1,16 +1,17 @@
 from typing import TYPE_CHECKING, Dict, Optional
 
 from literalai.api import gql
-
 from literalai.evaluation.dataset import Dataset, DatasetType
-from literalai.evaluation.dataset_experiment import DatasetExperiment, DatasetExperimentItem
+from literalai.evaluation.dataset_experiment import (
+    DatasetExperiment,
+    DatasetExperimentItem,
+)
 from literalai.evaluation.dataset_item import DatasetItem
 
 if TYPE_CHECKING:
     from literalai.api import LiteralAPI
 
-
 def create_dataset_helper(
     api: "LiteralAPI",
     name: str,
@@ -98,13 +99,13 @@ def create_experiment_helper(
     api: "LiteralAPI",
     name: str,
     dataset_id: Optional[str] = None,
-    prompt_id: Optional[str] = None,
+    prompt_variant_id: Optional[str] = None,
    params: Optional[Dict] = None,
 ):
     variables = {
         "datasetId": dataset_id,
         "name": name,
-        "promptId": prompt_id,
+        "promptExperimentId": prompt_variant_id,
         "params": params,
     }
 
diff --git a/literalai/api/gql.py b/literalai/api/gql.py
index 1aab01e..4a88040 100644
--- a/literalai/api/gql.py
+++ b/literalai/api/gql.py
@@ -833,18 +833,19 @@
 mutation CreateDatasetExperiment(
     $name: String!
     $datasetId: String
-    $promptId: String
+    $promptExperimentId: String
     $params: Json
 ) {
     createDatasetExperiment(
         name: $name
         datasetId: $datasetId
-        promptId: $promptId
+        promptExperimentId: $promptExperimentId
         params: $params
     ) {
         id
         name
         datasetId
+        promptExperimentId
         params
     }
 }
@@ -991,6 +992,16 @@
     }
 }"""
 
+GET_PROMPT_LINEAGE = """query promptLineage(
+    $name: String!
+) {
+    promptLineage(
+        name: $name
+    ) {
+        id
+    }
+}"""
+
 CREATE_PROMPT_VERSION = """mutation createPromptVersion(
     $lineageId: String!
    $versionDesc: String
@@ -1021,6 +1032,38 @@
     }
 }"""
 
+CREATE_PROMPT_VARIANT = """mutation createPromptExperiment(
+    $fromLineageId: String
+    $fromVersion: Int
+    $scoreTemplateId: String
+    $templateMessages: Json
+    $settings: Json
+    $tools: Json
+    $variables: Json
+) {
+    createPromptExperiment(
+        fromLineageId: $fromLineageId
+        fromVersion: $fromVersion
+        scoreTemplateId: $scoreTemplateId
+        templateMessages: $templateMessages
+        settings: $settings
+        tools: $tools
+        variables: $variables
+    ) {
+        id
+        fromLineageId
+        fromVersion
+        scoreTemplateId
+        projectId
+        projectUserId
+        tools
+        settings
+        variables
+        templateMessages
+    }
+}
+"""
+
 GET_PROMPT_VERSION = """
 query GetPrompt($id: String, $name: String, $version: Int) {
     promptVersion(id: $id, name: $name, version: $version) {
diff --git a/literalai/api/prompt_helpers.py b/literalai/api/prompt_helpers.py
index e754406..3377f77 100644
--- a/literalai/api/prompt_helpers.py
+++ b/literalai/api/prompt_helpers.py
@@ -21,6 +21,18 @@ def process_response(response):
     return gql.CREATE_PROMPT_LINEAGE, description, variables, process_response
 
 
+def get_prompt_lineage_helper(name: str):
+    variables = {"name": name}
+
+    def process_response(response):
+        prompt = response["data"]["promptLineage"]
+        return prompt
+
+    description = "get prompt lineage"
+
+    return gql.GET_PROMPT_LINEAGE, description, variables, process_response
+
+
 def create_prompt_helper(
     api: "LiteralAPI",
     lineage_id: str,
@@ -61,6 +73,28 @@ def process_response(response):
     return gql.GET_PROMPT_VERSION, description, variables, process_response
 
 
+def create_prompt_variant_helper(
+    from_lineage_id: Optional[str] = None,
+    template_messages: List[GenerationMessage] = [],
+    settings: Optional[ProviderSettings] = None,
+    tools: Optional[List[Dict]] = None,
+):
+    variables = {
+        "fromLineageId": from_lineage_id,
+        "templateMessages": template_messages,
+        "settings": settings,
+        "tools": tools,
+    }
+
+    def process_response(response):
+        variant = response["data"]["createPromptExperiment"]
+        return variant["id"] if variant else None
+
+    description = "create prompt variant"
+
+    return gql.CREATE_PROMPT_VARIANT, description, variables, process_response
+
+
 class PromptRollout(TypedDict):
     version: int
     rollout: int
diff --git a/literalai/evaluation/dataset.py b/literalai/evaluation/dataset.py
index 72b0c24..6fc7b52 100644
--- a/literalai/evaluation/dataset.py
+++ b/literalai/evaluation/dataset.py
@@ -1,10 +1,10 @@
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional, cast
 
-from literalai.my_types import Utils
-
 from typing_extensions import TypedDict
 
+from literalai.my_types import Utils
+
 if TYPE_CHECKING:
     from literalai.api import LiteralAPI
 
@@ -101,17 +101,23 @@ def create_item(
         return dataset_item
 
     def create_experiment(
-        self, name: str, prompt_id: Optional[str] = None, params: Optional[Dict] = None
+        self,
+        name: str,
+        prompt_variant_id: Optional[str] = None,
+        params: Optional[Dict] = None,
     ) -> DatasetExperiment:
         """
         Creates a new dataset experiment based on this dataset.
 
         :param name: The name of the experiment .
-        :param prompt_id: The Prompt ID used on LLM calls (optional).
+        :param prompt_variant_id: The prompt variant ID to experiment on.
         :param params: The params used on the experiment.
         :return: The created DatasetExperiment instance.
""" experiment = self.api.create_experiment( - name=name, dataset_id=self.id, prompt_id=prompt_id, params=params + name=name, + dataset_id=self.id, + prompt_variant_id=prompt_variant_id, + params=params, ) return experiment diff --git a/literalai/evaluation/dataset_experiment.py b/literalai/evaluation/dataset_experiment.py index 76a67a9..cacd193 100644 --- a/literalai/evaluation/dataset_experiment.py +++ b/literalai/evaluation/dataset_experiment.py @@ -59,7 +59,7 @@ class DatasetExperimentDict(TypedDict, total=False): name: str datasetId: str params: Dict - promptId: Optional[str] + promptExperimentId: Optional[str] items: Optional[List[DatasetExperimentItemDict]] @@ -71,7 +71,7 @@ class DatasetExperiment(Utils): name: str dataset_id: Optional[str] params: Optional[Dict] - prompt_id: Optional[str] = None + prompt_variant_id: Optional[str] = None items: List[DatasetExperimentItem] = field(default_factory=lambda: []) def log(self, item_dict: DatasetExperimentItemDict) -> DatasetExperimentItem: @@ -97,7 +97,7 @@ def to_dict(self): "createdAt": self.created_at, "name": self.name, "datasetId": self.dataset_id, - "promptId": self.prompt_id, + "promptExperimentId": self.prompt_variant_id, "params": self.params, "items": [item.to_dict() for item in self.items], } @@ -116,6 +116,6 @@ def from_dict( name=dataset_experiment.get("name", ""), dataset_id=dataset_experiment.get("datasetId", ""), params=dataset_experiment.get("params"), - prompt_id=dataset_experiment.get("promptId"), + prompt_variant_id=dataset_experiment.get("promptExperimentId"), items=[DatasetExperimentItem.from_dict(item) for item in items], ) diff --git a/literalai/instrumentation/mistralai.py b/literalai/instrumentation/mistralai.py index 8ceccb5..31966d1 100644 --- a/literalai/instrumentation/mistralai.py +++ b/literalai/instrumentation/mistralai.py @@ -271,8 +271,9 @@ def process_delta(new_delta: DeltaMessage, message_completion: GenerationMessage return True elif new_delta.content: if isinstance(message_completion["content"], str): - message_completion["content"] += new_delta.content - return True + if isinstance(new_delta.content, str): + message_completion["content"] += new_delta.content + return True else: return False @@ -423,7 +424,8 @@ async def async_streaming_response( time.time() - context["start"] ) * 1000 token_count += 1 - completion += chunk.data.choices[0].delta.content or "" + if isinstance(chunk.data.choices[0].delta.content, str): + completion += chunk.data.choices[0].delta.content or "" if ( generation diff --git a/literalai/version.py b/literalai/version.py index 39b47d2..3dc1f76 100644 --- a/literalai/version.py +++ b/literalai/version.py @@ -1 +1 @@ -__version__ = "0.0.629" +__version__ = "0.1.0" diff --git a/setup.py b/setup.py index 73ffbc9..37d7aaa 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="literalai", - version="0.0.629", # update version in literalai/version.py + version="0.1.0", # update version in literalai/version.py description="An SDK for observability in Python applications", long_description=open("README.md").read(), long_description_content_type="text/markdown", diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index aede802..3b58199 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -740,6 +740,17 @@ async def test_experiment_params_optional(self, client: LiteralClient): assert experiment.params is None dataset.delete() + @pytest.mark.timeout(5) + async def test_experiment_prompt(self, client: LiteralClient): + prompt_variant_id = 
+        prompt_variant_id = client.api.create_prompt_variant(
+            name="Default", template_messages=[{"role": "user", "content": "hello"}]
+        )
+        experiment = client.api.create_experiment(
+            name="test-experiment", prompt_variant_id=prompt_variant_id
+        )
+        assert experiment.params is None
+        assert experiment.prompt_variant_id == prompt_variant_id
+
     @pytest.mark.timeout(5)
     async def test_experiment_run(self, client: LiteralClient):
         experiment = client.api.create_experiment(name="test-experiment-run")
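
A minimal usage sketch of the workflow this diff introduces, mirroring the new e2e test. The "Default" prompt name, the message content, and the settings values are illustrative assumptions, not part of the diff; create_prompt_variant returns the variant ID, or None if the server could not create it.

    from literalai import LiteralClient

    # Assumes LITERAL_API_KEY (and optionally LITERAL_API_URL) are set in the environment.
    client = LiteralClient()

    # Create a throwaway prompt variant. It is attached to the "Default" prompt
    # lineage if one exists, and does not become an official prompt version
    # unless it is saved manually.
    prompt_variant_id = client.api.create_prompt_variant(
        name="Default",  # illustrative prompt name
        template_messages=[{"role": "user", "content": "hello"}],
        settings={"provider": "openai", "model": "gpt-4o", "temperature": 0},  # illustrative settings
    )

    # Link the variant to a fresh experiment via the renamed
    # prompt_variant_id parameter (formerly prompt_id).
    experiment = client.api.create_experiment(
        name="variant-experiment", prompt_variant_id=prompt_variant_id
    )
    assert experiment.prompt_variant_id == prompt_variant_id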