Skip to content

feat: prompt variant #142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 65 additions & 7 deletions literalai/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@
PromptRollout,
create_prompt_helper,
create_prompt_lineage_helper,
create_prompt_variant_helper,
get_prompt_ab_testing_helper,
get_prompt_helper,
get_prompt_lineage_helper,
update_prompt_ab_testing_helper,
)
from literalai.api.score_helpers import (
Expand Down Expand Up @@ -144,7 +146,6 @@ def handle_bytes(item):


class BaseLiteralAPI:

def __init__(
self,
api_key: Optional[str] = None,
Expand Down Expand Up @@ -201,7 +202,6 @@ class LiteralAPI(BaseLiteralAPI):
def make_gql_call(
self, description: str, query: str, variables: Dict[str, Any]
) -> Dict:

def raise_error(error):
logger.error(f"Failed to {description}: {error}")
raise Exception(error)
Expand Down Expand Up @@ -1141,7 +1141,7 @@ def create_experiment(
self,
name: str,
dataset_id: Optional[str] = None,
prompt_id: Optional[str] = None,
prompt_variant_id: Optional[str] = None,
params: Optional[Dict] = None,
) -> "DatasetExperiment":
"""
Expand All @@ -1150,7 +1150,7 @@ def create_experiment(
Args:
name (str): The name of the experiment.
dataset_id (Optional[str]): The unique identifier of the dataset.
prompt_id (Optional[str]): The identifier of the prompt associated with the experiment.
prompt_variant_id (Optional[str]): The identifier of the prompt variant to associate to the experiment.
params (Optional[Dict]): Additional parameters for the experiment.

Returns:
Expand All @@ -1161,7 +1161,7 @@ def create_experiment(
api=self,
name=name,
dataset_id=dataset_id,
prompt_id=prompt_id,
prompt_variant_id=prompt_variant_id,
params=params,
)
)
Expand Down Expand Up @@ -1369,6 +1369,34 @@ def get_prompt(
else:
raise ValueError("Either the `id` or the `name` must be provided.")

def create_prompt_variant(
    self,
    name: str,
    template_messages: List[GenerationMessage],
    settings: Optional[ProviderSettings] = None,
    tools: Optional[List[Dict]] = None,
) -> Optional[str]:
    """
    Creates a draft prompt variation to be used in an experiment.
    The variation only becomes an official prompt version if saved manually.

    Args:
        name (str): The name of the prompt to retrieve or create.
        template_messages (List[GenerationMessage]): A list of template messages for the prompt.
        settings (Optional[Dict]): Optional settings for the prompt.
        tools (Optional[List[Dict]]): Optional tool options for the model

    Returns:
        prompt_variant_id: The prompt variant id to link with the experiment.
    """
    # Resolve the lineage backing this prompt name; a missing lineage is
    # allowed and simply yields a variant with no parent lineage.
    existing_lineage = self.gql_helper(*get_prompt_lineage_helper(name))
    if existing_lineage:
        lineage_id = existing_lineage["id"]
    else:
        lineage_id = None
    variant_call = create_prompt_variant_helper(
        lineage_id, template_messages, settings, tools
    )
    return self.gql_helper(*variant_call)

def get_prompt_ab_testing(self, name: str) -> List[PromptRollout]:
"""
Get the A/B testing configuration for a prompt lineage.
Expand Down Expand Up @@ -2351,7 +2379,7 @@ async def create_experiment(
self,
name: str,
dataset_id: Optional[str] = None,
prompt_id: Optional[str] = None,
prompt_variant_id: Optional[str] = None,
params: Optional[Dict] = None,
) -> "DatasetExperiment":
sync_api = LiteralAPI(self.api_key, self.url)
Expand All @@ -2361,7 +2389,7 @@ async def create_experiment(
api=sync_api,
name=name,
dataset_id=dataset_id,
prompt_id=prompt_id,
prompt_variant_id=prompt_variant_id,
params=params,
)
)
Expand Down Expand Up @@ -2529,6 +2557,36 @@ async def create_prompt(
):
return await self.get_or_create_prompt(name, template_messages, settings)

async def create_prompt_variant(
    self,
    name: str,
    template_messages: List[GenerationMessage],
    settings: Optional[ProviderSettings] = None,
    tools: Optional[List[Dict]] = None,
) -> Optional[str]:
    """
    Creates a draft prompt variation to be used in an experiment.
    The variation only becomes an official prompt version if saved manually.

    Args:
        name (str): The name of the prompt to retrieve or create.
        template_messages (List[GenerationMessage]): A list of template messages for the prompt.
        settings (Optional[Dict]): Optional settings for the prompt.
        tools (Optional[List[Dict]]): Optional tool options for the model

    Returns:
        prompt_variant_id: The prompt variant id to link with the experiment.
    """
    # Resolve the lineage backing this prompt name; a missing lineage is
    # allowed and simply yields a variant with no parent lineage.
    found_lineage = await self.gql_helper(*get_prompt_lineage_helper(name))
    if found_lineage:
        lineage_id = found_lineage["id"]
    else:
        lineage_id = None
    helper_args = create_prompt_variant_helper(
        lineage_id, template_messages, settings, tools
    )
    return await self.gql_helper(*helper_args)

# Keep the async docstring in lockstep with the sync API's canonical one.
create_prompt_variant.__doc__ = LiteralAPI.create_prompt_variant.__doc__

async def get_prompt(
self,
id: Optional[str] = None,
Expand Down
11 changes: 6 additions & 5 deletions literalai/api/dataset_helpers.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from typing import TYPE_CHECKING, Dict, Optional

from literalai.api import gql

from literalai.evaluation.dataset import Dataset, DatasetType
from literalai.evaluation.dataset_experiment import DatasetExperiment, DatasetExperimentItem
from literalai.evaluation.dataset_experiment import (
DatasetExperiment,
DatasetExperimentItem,
)
from literalai.evaluation.dataset_item import DatasetItem

if TYPE_CHECKING:
from literalai.api import LiteralAPI



def create_dataset_helper(
api: "LiteralAPI",
name: str,
Expand Down Expand Up @@ -98,13 +99,13 @@ def create_experiment_helper(
api: "LiteralAPI",
name: str,
dataset_id: Optional[str] = None,
prompt_id: Optional[str] = None,
prompt_variant_id: Optional[str] = None,
params: Optional[Dict] = None,
):
variables = {
"datasetId": dataset_id,
"name": name,
"promptId": prompt_id,
"promptExperimentId": prompt_variant_id,
"params": params,
}

Expand Down
47 changes: 45 additions & 2 deletions literalai/api/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,18 +833,19 @@
mutation CreateDatasetExperiment(
$name: String!
$datasetId: String
$promptId: String
$promptExperimentId: String
$params: Json
) {
createDatasetExperiment(
name: $name
datasetId: $datasetId
promptId: $promptId
promptExperimentId: $promptExperimentId
params: $params
) {
id
name
datasetId
promptExperimentId
params
}
}
Expand Down Expand Up @@ -991,6 +992,16 @@
}
}"""

# GraphQL query fetching only the id of a prompt lineage by name.
# Consumed by get_prompt_lineage_helper to resolve the lineage that a
# prompt variant should branch from.
GET_PROMPT_LINEAGE = """query promptLineage(
$name: String!
) {
promptLineage(
name: $name
) {
id
}
}"""

CREATE_PROMPT_VERSION = """mutation createPromptVersion(
$lineageId: String!
$versionDesc: String
Expand Down Expand Up @@ -1021,6 +1032,38 @@
}
}"""

# GraphQL mutation creating a draft "prompt experiment" (variant),
# optionally branched from an existing lineage/version. Consumed by
# create_prompt_variant_helper, which only populates fromLineageId,
# templateMessages, settings and tools; the remaining variables
# (fromVersion, scoreTemplateId, variables) are accepted by the server
# but left unset by the SDK.
CREATE_PROMPT_VARIANT = """mutation createPromptExperiment(
$fromLineageId: String
$fromVersion: Int
$scoreTemplateId: String
$templateMessages: Json
$settings: Json
$tools: Json
$variables: Json
) {
createPromptExperiment(
fromLineageId: $fromLineageId
fromVersion: $fromVersion
scoreTemplateId: $scoreTemplateId
templateMessages: $templateMessages
settings: $settings
tools: $tools
variables: $variables
) {
id
fromLineageId
fromVersion
scoreTemplateId
projectId
projectUserId
tools
settings
variables
templateMessages
}
}
"""

GET_PROMPT_VERSION = """
query GetPrompt($id: String, $name: String, $version: Int) {
promptVersion(id: $id, name: $name, version: $version) {
Expand Down
34 changes: 34 additions & 0 deletions literalai/api/prompt_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ def process_response(response):
return gql.CREATE_PROMPT_LINEAGE, description, variables, process_response


def get_prompt_lineage_helper(name: str):
    """
    Build the GQL call tuple to fetch a prompt lineage by name.

    Returns the (query, description, variables, process_response) tuple
    consumed by `gql_helper`; the processed response is the raw
    `promptLineage` payload (or None-equivalent when absent).
    """
    description = "get prompt lineage"
    variables = {"name": name}

    def process_response(response):
        return response["data"]["promptLineage"]

    return gql.GET_PROMPT_LINEAGE, description, variables, process_response


def create_prompt_helper(
api: "LiteralAPI",
lineage_id: str,
Expand Down Expand Up @@ -61,6 +73,28 @@ def process_response(response):
return gql.GET_PROMPT_VERSION, description, variables, process_response


def create_prompt_variant_helper(
    from_lineage_id: Optional[str] = None,
    template_messages: Optional[List[GenerationMessage]] = None,
    settings: Optional[ProviderSettings] = None,
    tools: Optional[List[Dict]] = None,
):
    """
    Build the GQL call tuple to create a prompt variant (a draft
    "prompt experiment" that is not an official version until saved).

    Args:
        from_lineage_id (Optional[str]): Lineage to branch the variant from,
            or None to create it without a parent lineage.
        template_messages (Optional[List[GenerationMessage]]): Template
            messages for the variant; defaults to an empty list.
        settings (Optional[ProviderSettings]): Optional provider settings.
        tools (Optional[List[Dict]]): Optional tool options for the model.

    Returns:
        Tuple of (query, description, variables, process_response) consumed
        by `gql_helper`; the processed response is the variant id, or None
        when the server returns no payload.
    """
    # Previous signature used a mutable default (`= []`); `None` avoids
    # sharing a single list object across calls while sending the same
    # empty-list value to the API when the argument is omitted.
    variables = {
        "fromLineageId": from_lineage_id,
        "templateMessages": [] if template_messages is None else template_messages,
        "settings": settings,
        "tools": tools,
    }

    def process_response(response):
        # Guard against a null createPromptExperiment payload.
        variant = response["data"]["createPromptExperiment"]
        return variant["id"] if variant else None

    description = "create prompt variant"

    return gql.CREATE_PROMPT_VARIANT, description, variables, process_response


class PromptRollout(TypedDict):
version: int
rollout: int
Expand Down
16 changes: 11 additions & 5 deletions literalai/evaluation/dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, cast

from literalai.my_types import Utils

from typing_extensions import TypedDict

from literalai.my_types import Utils

if TYPE_CHECKING:
from literalai.api import LiteralAPI

Expand Down Expand Up @@ -101,17 +101,23 @@ def create_item(
return dataset_item

def create_experiment(
    self,
    name: str,
    prompt_variant_id: Optional[str] = None,
    params: Optional[Dict] = None,
) -> DatasetExperiment:
    """
    Creates a new experiment attached to this dataset.

    :param name: The name of the experiment .
    :param prompt_variant_id: The Prompt variant ID to experiment on.
    :param params: The params used on the experiment.
    :return: The created DatasetExperiment instance.
    """
    # Delegate to the API client, pinning the experiment to this dataset.
    return self.api.create_experiment(
        name=name,
        dataset_id=self.id,
        prompt_variant_id=prompt_variant_id,
        params=params,
    )

Expand Down
8 changes: 4 additions & 4 deletions literalai/evaluation/dataset_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class DatasetExperimentDict(TypedDict, total=False):
name: str
datasetId: str
params: Dict
promptId: Optional[str]
promptExperimentId: Optional[str]
items: Optional[List[DatasetExperimentItemDict]]


Expand All @@ -71,7 +71,7 @@ class DatasetExperiment(Utils):
name: str
dataset_id: Optional[str]
params: Optional[Dict]
prompt_id: Optional[str] = None
prompt_variant_id: Optional[str] = None
items: List[DatasetExperimentItem] = field(default_factory=lambda: [])

def log(self, item_dict: DatasetExperimentItemDict) -> DatasetExperimentItem:
Expand All @@ -97,7 +97,7 @@ def to_dict(self):
"createdAt": self.created_at,
"name": self.name,
"datasetId": self.dataset_id,
"promptId": self.prompt_id,
"promptExperimentId": self.prompt_variant_id,
"params": self.params,
"items": [item.to_dict() for item in self.items],
}
Expand All @@ -116,6 +116,6 @@ def from_dict(
name=dataset_experiment.get("name", ""),
dataset_id=dataset_experiment.get("datasetId", ""),
params=dataset_experiment.get("params"),
prompt_id=dataset_experiment.get("promptId"),
prompt_variant_id=dataset_experiment.get("promptExperimentId"),
items=[DatasetExperimentItem.from_dict(item) for item in items],
)
Loading
Loading