Skip to content

fix: default prompt version to None #73

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/langchain_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dotenv import load_dotenv
from langchain.schema import HumanMessage
from langchain_community.chat_models import ChatOpenAI
from langchain_community.chat_models import ChatOpenAI # type: ignore

from literalai import LiteralClient

Expand Down
58 changes: 43 additions & 15 deletions literalai/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
create_prompt_helper,
create_prompt_lineage_helper,
get_prompt_helper,
promote_prompt_helper,
)
from .score_helpers import (
ScoreUpdate,
Expand Down Expand Up @@ -1216,9 +1217,10 @@ def add_generation_to_dataset(

# Prompt API

def create_prompt_lineage(self, name: str, description: Optional[str] = None):
def get_or_create_prompt_lineage(self, name: str, description: Optional[str] = None):
"""
Creates a prompt lineage with the specified name and optional description.
If the prompt lineage with that name already exists, it is returned.

Args:
name (str): The name of the prompt lineage.
Expand All @@ -1229,6 +1231,10 @@ def create_prompt_lineage(self, name: str, description: Optional[str] = None):
"""
return self.gql_helper(*create_prompt_lineage_helper(name, description))

@deprecated('Please use "get_or_create_prompt_lineage" instead.')
def create_prompt_lineage(self, name: str, description: Optional[str] = None):
    """Deprecated alias kept for backward compatibility; delegates to
    get_or_create_prompt_lineage with identical arguments."""
    return self.get_or_create_prompt_lineage(name, description)

def get_or_create_prompt(
self,
name: str,
Expand All @@ -1249,7 +1255,7 @@ def get_or_create_prompt(
Returns:
Prompt: The prompt that was retrieved or created.
"""
lineage = self.create_prompt_lineage(name)
lineage = self.get_or_create_prompt_lineage(name)
lineage_id = lineage["id"]
return self.gql_helper(
*create_prompt_helper(self, lineage_id, template_messages, settings, tools)
Expand All @@ -1268,7 +1274,7 @@ def get_prompt(
self,
id: Optional[str] = None,
name: Optional[str] = None,
version: Optional[int] = 0,
version: Optional[int] = None,
) -> Prompt:
"""
Gets a prompt either by:
Expand All @@ -1292,6 +1298,23 @@ def get_prompt(
return self.gql_helper(*get_prompt_helper(self, name=name, version=version))
else:
raise ValueError("Either the `id` or the `name` must be provided.")

def promote_prompt(self, name: str, version: int) -> str:
    """
    Promotes the given version of the named prompt to champion.

    Args:
        name (str): The name of the prompt lineage.
        version (int): The version number to promote.

    Returns:
        str: The champion prompt ID.
    """
    prompt_lineage = self.get_or_create_prompt_lineage(name)
    request = promote_prompt_helper(prompt_lineage["id"], version)
    return self.gql_helper(*request)


# Misc API

Expand Down Expand Up @@ -2383,18 +2406,14 @@ async def add_generation_to_dataset(

# Prompt API

async def create_prompt_lineage(self, name: str, description: Optional[str] = None):
"""
Asynchronously creates a prompt lineage.
async def get_or_create_prompt_lineage(self, name: str, description: Optional[str] = None):
return await self.gql_helper(*create_prompt_lineage_helper(name, description))

Args:
name (str): The name of the prompt lineage.
description (Optional[str]): An optional description of the prompt lineage.
get_or_create_prompt_lineage.__doc__ = LiteralAPI.get_or_create_prompt_lineage.__doc__

Returns:
The result of the GraphQL helper function for creating a prompt lineage.
"""
return await self.gql_helper(*create_prompt_lineage_helper(name, description))
@deprecated('Please use "get_or_create_prompt_lineage" instead.')
async def create_prompt_lineage(self, name: str, description: Optional[str] = None):
return await self.get_or_create_prompt_lineage(name, description)

async def get_or_create_prompt(
self,
Expand All @@ -2403,7 +2422,7 @@ async def get_or_create_prompt(
settings: Optional[ProviderSettings] = None,
tools: Optional[List[Dict]] = None,
) -> Prompt:
lineage = await self.create_prompt_lineage(name)
lineage = await self.get_or_create_prompt_lineage(name)
lineage_id = lineage["id"]

sync_api = LiteralAPI(self.api_key, self.url)
Expand All @@ -2428,7 +2447,7 @@ async def get_prompt(
self,
id: Optional[str] = None,
name: Optional[str] = None,
version: Optional[int] = 0,
version: Optional[int] = None,
) -> Prompt:
sync_api = LiteralAPI(self.api_key, self.url)
if id:
Expand All @@ -2442,10 +2461,19 @@ async def get_prompt(

get_prompt.__doc__ = LiteralAPI.get_prompt.__doc__

async def promote_prompt(self, name: str, version: int) -> str:
    # Async counterpart of LiteralAPI.promote_prompt: resolve (or create)
    # the lineage for `name`, then promote `version` to champion.
    lineage = await self.get_or_create_prompt_lineage(name)
    lineage_id = lineage["id"]

    return await self.gql_helper(*promote_prompt_helper(lineage_id, version))

# Reuse the sync API's docstring so both variants stay documented identically.
promote_prompt.__doc__ = LiteralAPI.promote_prompt.__doc__

# Misc API

async def get_my_project_id(self):
    # Fetch the current project from the REST endpoint and return its id.
    response = await self.make_rest_call("/my-project", {})
    return response["projectId"]

# Reuse the sync API's docstring so both variants stay documented identically.
get_my_project_id.__doc__ = LiteralAPI.get_my_project_id.__doc__

13 changes: 13 additions & 0 deletions literalai/api/gql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,19 @@
}
"""

# GraphQL mutation that promotes the given version within a prompt lineage
# to champion; the server responds with the lineage id and the new championId.
PROMOTE_PROMPT_VERSION = """mutation promotePromptVersion(
$lineageId: String!
$version: Int!
) {
promotePromptVersion(
lineageId: $lineageId
version: $version
) {
id
championId
}
}"""


def serialize_step(event, id):
result = {}
Expand Down
15 changes: 15 additions & 0 deletions literalai/api/prompt_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,18 @@ def process_response(response):
description = "get prompt"

return gql.GET_PROMPT_VERSION, description, variables, process_response


def promote_prompt_helper(
    lineage_id: str,
    version: int,
):
    """Build the GraphQL call that promotes a prompt version to champion.

    Args:
        lineage_id (str): ID of the prompt lineage to update.
        version (int): Version number to promote within that lineage.

    Returns:
        The (query, description, variables, process_response) tuple
        consumed by gql_helper.
    """
    variables = {"lineageId": lineage_id, "version": version}

    def process_response(response) -> "Optional[str]":
        # NOTE(review): returns None when the mutation payload is empty,
        # hence Optional; quoted so no typing import is needed at runtime.
        prompt = response["data"]["promotePromptVersion"]
        return prompt["championId"] if prompt else None

    description = "promote prompt version"

    return gql.PROMOTE_PROMPT_VERSION, description, variables, process_response
7 changes: 7 additions & 0 deletions literalai/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ def from_dict(cls, api: "LiteralAPI", prompt_dict: PromptDict) -> "Prompt":
variables_default_values=prompt_dict.get("variablesDefaultValues"),
)

def promote(self) -> "Prompt":
    """
    Promotes this prompt version to champion of its lineage.

    Returns:
        Prompt: This prompt instance, to allow chaining.
    """
    self.api.promote_prompt(self.name, self.version)
    return self

def format_messages(self, **kwargs: Any) -> List[Any]:
"""
Formats the prompt's template messages with the given variables.
Expand Down
15 changes: 12 additions & 3 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,13 +560,13 @@ async def test_dataset_sync(
async def test_prompt(
self, client: LiteralClient, async_client: AsyncLiteralClient
):
prompt = await async_client.api.get_prompt(name="Default")
prompt = await async_client.api.get_prompt(name="Default", version=0)
assert prompt is not None
assert prompt.name == "Default"
assert prompt.version == 0
assert prompt.provider == "openai"

prompt = await async_client.api.get_prompt(id=prompt.id)
prompt = await async_client.api.get_prompt(id=prompt.id, version=0)
assert prompt is not None

messages = prompt.format_messages()
Expand All @@ -593,6 +593,15 @@ async def test_prompt(

assert messages[0]["content"] == expected

@pytest.mark.timeout(5)
async def test_champion_prompt(self, client: LiteralClient):
    # Create a fresh version, promote it, then check that fetching by name
    # alone (no version argument) now resolves to the promoted champion.
    new_prompt = client.api.get_or_create_prompt(name="Default", template_messages=[{"role": "user", "content": "Hello"}])
    new_prompt.promote()

    prompt = client.api.get_prompt(name="Default")
    assert prompt is not None
    assert prompt.version == new_prompt.version

@pytest.mark.timeout(5)
async def test_gracefulness(self, broken_client: LiteralClient):
with broken_client.thread(name="Conversation"):
Expand All @@ -609,7 +618,7 @@ async def test_thread_to_dict(self, client: LiteralClient):

@pytest.mark.timeout(5)
async def test_prompt_unique(self, client: LiteralClient):
prompt = client.api.get_prompt(name="Default")
prompt = client.api.get_prompt(name="Default", version=0)

new_prompt = client.api.get_or_create_prompt(
name=prompt.name,
Expand Down
Loading