Skip to content

feat: replace promote with prompt ab testing #105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 59 additions & 43 deletions literalai/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,6 @@

from typing_extensions import deprecated

from literalai.context import active_steps_var, active_thread_var
from literalai.evaluation.dataset import Dataset, DatasetType
from literalai.evaluation.dataset_experiment import (
DatasetExperiment,
DatasetExperimentItem,
)
from literalai.observability.filter import (
generations_filters,
generations_order_by,
scores_filters,
scores_order_by,
steps_filters,
steps_order_by,
threads_filters,
threads_order_by,
users_filters,
)
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

from literalai.api.attachment_helpers import (
AttachmentUpload,
create_attachment_helper,
Expand All @@ -60,10 +41,12 @@
get_generations_helper,
)
from literalai.api.prompt_helpers import (
PromptRollout,
create_prompt_helper,
create_prompt_lineage_helper,
get_prompt_ab_testing_helper,
get_prompt_helper,
promote_prompt_helper,
update_prompt_ab_testing_helper,
)
from literalai.api.score_helpers import (
ScoreUpdate,
Expand Down Expand Up @@ -98,29 +81,44 @@
get_users_helper,
update_user_helper,
)
from literalai.context import active_steps_var, active_thread_var
from literalai.evaluation.dataset import Dataset, DatasetType
from literalai.evaluation.dataset_experiment import (
DatasetExperiment,
DatasetExperimentItem,
)
from literalai.observability.filter import (
generations_filters,
generations_order_by,
scores_filters,
scores_order_by,
steps_filters,
steps_order_by,
threads_filters,
threads_order_by,
users_filters,
)
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings

if TYPE_CHECKING:
from typing import Tuple # noqa: F401

import httpx

from literalai.my_types import (
Environment,
PaginatedResponse,
)
from literalai.my_types import Environment, PaginatedResponse
from literalai.observability.generation import (
GenerationMessage,
CompletionGeneration,
ChatGeneration,
CompletionGeneration,
GenerationMessage,
)
from literalai.observability.step import (
Attachment,
Score,
ScoreDict,
ScoreType,
Step,
StepDict,
StepType,
ScoreType,
ScoreDict,
Score,
Attachment,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -1365,21 +1363,33 @@ def get_prompt(
else:
raise ValueError("Either the `id` or the `name` must be provided.")

def promote_prompt(self, name: str, version: int) -> str:
def get_prompt_ab_testing(self, name: str) -> List[PromptRollout]:
"""
Promotes the prompt with name to target version.
Get the A/B testing configuration for a prompt lineage.

Args:
name (str): The name of the prompt lineage.
version (int): The version number to promote.

Returns:
str: The champion prompt ID.
List[PromptRollout]
"""
lineage = self.get_or_create_prompt_lineage(name)
lineage_id = lineage["id"]
return self.gql_helper(*get_prompt_ab_testing_helper(name=name))

return self.gql_helper(*promote_prompt_helper(lineage_id, version))
def update_prompt_ab_testing(
self, name: str, rollouts: List[PromptRollout]
) -> Dict:
"""
Update the A/B testing configuration for a prompt lineage.

Args:
name (str): The name of the prompt lineage.
rollouts (List[PromptRollout]): The percentage rollout for each prompt version.

Returns:
Dict
"""
return self.gql_helper(
*update_prompt_ab_testing_helper(name=name, rollouts=rollouts)
)

# Misc API

Expand Down Expand Up @@ -2552,13 +2562,19 @@ async def get_prompt(

get_prompt.__doc__ = LiteralAPI.get_prompt.__doc__

async def promote_prompt(self, name: str, version: int) -> str:
lineage = await self.get_or_create_prompt_lineage(name)
lineage_id = lineage["id"]
async def update_prompt_ab_testing(
self, name: str, rollouts: List[PromptRollout]
) -> Dict:
return await self.gql_helper(
*update_prompt_ab_testing_helper(name=name, rollouts=rollouts)
)

update_prompt_ab_testing.__doc__ = LiteralAPI.update_prompt_ab_testing.__doc__

return await self.gql_helper(*promote_prompt_helper(lineage_id, version))
async def get_prompt_ab_testing(self, name: str) -> List[PromptRollout]:
return await self.gql_helper(*get_prompt_ab_testing_helper(name=name))

promote_prompt.__doc__ = LiteralAPI.promote_prompt.__doc__
get_prompt_ab_testing.__doc__ = LiteralAPI.get_prompt_ab_testing.__doc__

# Misc API

Expand Down
35 changes: 27 additions & 8 deletions literalai/api/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,16 +1041,35 @@
}
"""

PROMOTE_PROMPT_VERSION = """mutation promotePromptVersion(
$lineageId: String!
$version: Int!
# GraphQL query: fetch the A/B testing rollout configuration for a prompt
# lineage, looked up by name. Each edge node carries a prompt `version` and
# its `rollout` share; pageInfo cursors are requested for pagination.
GET_PROMPT_AB_TESTING = """query getPromptLineageRollout($projectId: String, $lineageName: String!) {
promptLineageRollout(projectId: $projectId, lineageName: $lineageName) {
pageInfo {
startCursor
endCursor
}
edges {
node {
version
rollout
}
}
}
}
"""

UPDATE_PROMPT_AB_TESTING = """mutation updatePromptLineageRollout(
$projectId: String
$name: String!
$rollouts: [PromptVersionRolloutInput!]!
) {
promotePromptVersion(
lineageId: $lineageId
version: $version
updatePromptLineageRollout(
projectId: $projectId
name: $name
rollouts: $rollouts
) {
id
championId
ok
message
errorCode
}
}"""

Expand Down
35 changes: 25 additions & 10 deletions literalai/api/prompt_helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict

from literalai.observability.generation import GenerationMessage
from literalai.prompt_engineering.prompt import Prompt, ProviderSettings
Expand Down Expand Up @@ -61,16 +61,31 @@ def process_response(response):
return gql.GET_PROMPT_VERSION, description, variables, process_response


def promote_prompt_helper(
lineage_id: str,
version: int,
class PromptRollout(TypedDict):
    """One entry of a prompt lineage's A/B testing configuration.

    Pairs a prompt version with the share of traffic routed to it.
    """

    # Version number of the prompt within its lineage.
    version: int
    # Percentage rollout assigned to this version (e.g. 100 for all traffic).
    rollout: int


def get_prompt_ab_testing_helper(
name: Optional[str] = None,
):
variables = {"lineageId": lineage_id, "version": version}
variables = {"lineageName": name}

def process_response(response) -> List[PromptRollout]:
response_data = response["data"]["promptLineageRollout"]
return list(map(lambda x: x["node"], response_data["edges"]))

description = "get prompt A/B testing"

return gql.GET_PROMPT_AB_TESTING, description, variables, process_response


def update_prompt_ab_testing_helper(name: str, rollouts: List[PromptRollout]):
variables = {"name": name, "rollouts": rollouts}

def process_response(response) -> str:
prompt = response["data"]["promotePromptVersion"]
return prompt["championId"] if prompt else None
def process_response(response) -> Dict:
return response["data"]["updatePromptLineageRollout"]

description = "promote prompt version"
description = "update prompt A/B testing"

return gql.PROMOTE_PROMPT_VERSION, description, variables, process_response
return gql.UPDATE_PROMPT_AB_TESTING, description, variables, process_response
10 changes: 1 addition & 9 deletions literalai/prompt_engineering/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional

from typing_extensions import deprecated, TypedDict

import chevron
from typing_extensions import TypedDict, deprecated

if TYPE_CHECKING:
from literalai.api import LiteralAPI
Expand Down Expand Up @@ -117,13 +116,6 @@ def from_dict(cls, api: "LiteralAPI", prompt_dict: PromptDict) -> "Prompt":
variables_default_values=prompt_dict.get("variablesDefaultValues"),
)

def promote(self) -> "Prompt":
"""
Promotes this prompt to champion.
"""
self.api.promote_prompt(self.name, self.version)
return self

def format_messages(self, **kwargs: Any) -> List[Any]:
"""
Formats the prompt's template messages with the given variables.
Expand Down
47 changes: 37 additions & 10 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import secrets
import time
import uuid
from typing import List

import pytest

from literalai import AsyncLiteralClient, LiteralClient
from literalai.context import active_steps_var
from literalai.observability.generation import ChatGeneration
from literalai.observability.generation import ChatGeneration, GenerationMessage
from literalai.observability.thread import Thread

"""
Expand Down Expand Up @@ -384,7 +385,6 @@ def step_decorated():
async def test_nested_run_steps(
self, client: LiteralClient, async_client: AsyncLiteralClient
):

@async_client.run(name="foo")
def run_decorated():
s = async_client.get_current_step()
Expand Down Expand Up @@ -627,16 +627,43 @@ async def test_prompt(self, async_client: AsyncLiteralClient):
assert messages[0]["content"] == expected

@pytest.mark.timeout(5)
async def test_champion_prompt(self, client: LiteralClient):
new_prompt = client.api.get_or_create_prompt(
name="Python SDK E2E Tests",
template_messages=[{"role": "user", "content": "Hello"}],
async def test_prompt_ab_testing(self, client: LiteralClient):
prompt_name = "Python SDK E2E Tests"

v0: List[GenerationMessage] = [{"role": "user", "content": "Hello"}]
v1: List[GenerationMessage] = [{"role": "user", "content": "Hello 2"}]

prompt_v0 = client.api.get_or_create_prompt(
name=prompt_name,
template_messages=v0,
)
new_prompt.promote()

prompt = client.api.get_prompt(name="Python SDK E2E Tests")
assert prompt is not None
assert prompt.version == new_prompt.version
client.api.update_prompt_ab_testing(
prompt_v0.name, rollouts=[{"version": 0, "rollout": 100}]
)

ab_testing = client.api.get_prompt_ab_testing(name=prompt_v0.name)
assert len(ab_testing) == 1
assert ab_testing[0]["version"] == 0
assert ab_testing[0]["rollout"] == 100

prompt_v1 = client.api.get_or_create_prompt(
name=prompt_name,
template_messages=v1,
)

client.api.update_prompt_ab_testing(
name=prompt_v1.name,
rollouts=[{"version": 0, "rollout": 60}, {"version": 1, "rollout": 40}],
)

ab_testing = client.api.get_prompt_ab_testing(name=prompt_v1.name)

assert len(ab_testing) == 2
assert ab_testing[0]["version"] == 0
assert ab_testing[0]["rollout"] == 60
assert ab_testing[1]["version"] == 1
assert ab_testing[1]["rollout"] == 40

@pytest.mark.timeout(5)
async def test_gracefulness(self, broken_client: LiteralClient):
Expand Down
Loading