diff --git a/examples/langchain_example.py b/examples/langchain_example.py index f5b0372..511d51a 100644 --- a/examples/langchain_example.py +++ b/examples/langchain_example.py @@ -2,7 +2,7 @@ from dotenv import load_dotenv from langchain.schema import HumanMessage -from langchain_community.chat_models import ChatOpenAI +from langchain_community.chat_models import ChatOpenAI # type: ignore from literalai import LiteralClient diff --git a/literalai/api/__init__.py b/literalai/api/__init__.py index c1c7e25..286734e 100644 --- a/literalai/api/__init__.py +++ b/literalai/api/__init__.py @@ -54,6 +54,7 @@ create_prompt_helper, create_prompt_lineage_helper, get_prompt_helper, + promote_prompt_helper, ) from .score_helpers import ( ScoreUpdate, @@ -1216,9 +1217,10 @@ def add_generation_to_dataset( # Prompt API - def create_prompt_lineage(self, name: str, description: Optional[str] = None): + def get_or_create_prompt_lineage(self, name: str, description: Optional[str] = None): """ Creates a prompt lineage with the specified name and optional description. + If the prompt lineage with that name already exists, it is returned. Args: name (str): The name of the prompt lineage. @@ -1229,6 +1231,10 @@ def create_prompt_lineage(self, name: str, description: Optional[str] = None): """ return self.gql_helper(*create_prompt_lineage_helper(name, description)) + @deprecated('Please use "get_or_create_prompt_lineage" instead.') + def create_prompt_lineage(self, name: str, description: Optional[str] = None): + return self.get_or_create_prompt_lineage(name, description) + def get_or_create_prompt( self, name: str, @@ -1249,7 +1255,7 @@ def get_or_create_prompt( Returns: Prompt: The prompt that was retrieved or created. """ - lineage = self.create_prompt_lineage(name) + lineage = self.get_or_create_prompt_lineage(name) lineage_id = lineage["id"] return self.gql_helper( *create_prompt_helper(self, lineage_id, template_messages, settings, tools) @@ -1268,7 +1274,7 @@ def get_prompt( self, id: Optional[str] = None, name: Optional[str] = None, - version: Optional[int] = 0, + version: Optional[int] = None, ) -> Prompt: """ Gets a prompt either by: @@ -1292,6 +1298,23 @@ def get_prompt( return self.gql_helper(*get_prompt_helper(self, name=name, version=version)) else: raise ValueError("Either the `id` or the `name` must be provided.") + + def promote_prompt(self, name: str, version: int) -> str: + """ + Promotes the prompt with name to target version. + + Args: + name (str): The name of the prompt lineage. + version (int): The version number to promote. + + Returns: + str: The champion prompt ID. + """ + lineage = self.get_or_create_prompt_lineage(name) + lineage_id = lineage["id"] + + return self.gql_helper(*promote_prompt_helper(lineage_id, version)) + # Misc API @@ -2383,18 +2406,14 @@ async def add_generation_to_dataset( # Prompt API - async def create_prompt_lineage(self, name: str, description: Optional[str] = None): - """ - Asynchronously creates a prompt lineage. + async def get_or_create_prompt_lineage(self, name: str, description: Optional[str] = None): + return await self.gql_helper(*create_prompt_lineage_helper(name, description)) - Args: - name (str): The name of the prompt lineage. - description (Optional[str]): An optional description of the prompt lineage. + get_or_create_prompt_lineage.__doc__ = LiteralAPI.get_or_create_prompt_lineage.__doc__ - Returns: - The result of the GraphQL helper function for creating a prompt lineage. - """ - return await self.gql_helper(*create_prompt_lineage_helper(name, description)) + @deprecated('Please use "get_or_create_prompt_lineage" instead.') + async def create_prompt_lineage(self, name: str, description: Optional[str] = None): + return await self.get_or_create_prompt_lineage(name, description) async def get_or_create_prompt( self, @@ -2403,7 +2422,7 @@ async def get_or_create_prompt( settings: Optional[ProviderSettings] = None, tools: Optional[List[Dict]] = None, ) -> Prompt: - lineage = await self.create_prompt_lineage(name) + lineage = await self.get_or_create_prompt_lineage(name) lineage_id = lineage["id"] sync_api = LiteralAPI(self.api_key, self.url) @@ -2428,7 +2447,7 @@ async def get_prompt( self, id: Optional[str] = None, name: Optional[str] = None, - version: Optional[int] = 0, + version: Optional[int] = None, ) -> Prompt: sync_api = LiteralAPI(self.api_key, self.url) if id: @@ -2442,6 +2461,14 @@ async def get_prompt( get_prompt.__doc__ = LiteralAPI.get_prompt.__doc__ + async def promote_prompt(self, name: str, version: int) -> str: + lineage = await self.get_or_create_prompt_lineage(name) + lineage_id = lineage["id"] + + return await self.gql_helper(*promote_prompt_helper(lineage_id, version)) + + promote_prompt.__doc__ = LiteralAPI.promote_prompt.__doc__ + # Misc API async def get_my_project_id(self): @@ -2449,3 +2476,4 @@ async def get_my_project_id(self): return response["projectId"] get_my_project_id.__doc__ = LiteralAPI.get_my_project_id.__doc__ + diff --git a/literalai/api/gql.py b/literalai/api/gql.py index 4ffd5e5..ea0b677 100644 --- a/literalai/api/gql.py +++ b/literalai/api/gql.py @@ -1035,6 +1035,19 @@ } """ +PROMOTE_PROMPT_VERSION = """mutation promotePromptVersion( + $lineageId: String! + $version: Int! + ) { + promotePromptVersion( + lineageId: $lineageId + version: $version + ) { + id + championId + } + }""" + def serialize_step(event, id): result = {} diff --git a/literalai/api/prompt_helpers.py b/literalai/api/prompt_helpers.py index 3a71d7e..f614328 100644 --- a/literalai/api/prompt_helpers.py +++ b/literalai/api/prompt_helpers.py @@ -59,3 +59,18 @@ def process_response(response): description = "get prompt" return gql.GET_PROMPT_VERSION, description, variables, process_response + + +def promote_prompt_helper( + lineage_id: str, + version: int, +): + variables = {"lineageId": lineage_id, "version": version} + + def process_response(response) -> str: + prompt = response["data"]["promotePromptVersion"] + return prompt["championId"] if prompt else None + + description = "promote prompt version" + + return gql.PROMOTE_PROMPT_VERSION, description, variables, process_response diff --git a/literalai/prompt.py b/literalai/prompt.py index fe61763..2b95103 100644 --- a/literalai/prompt.py +++ b/literalai/prompt.py @@ -120,6 +120,13 @@ def from_dict(cls, api: "LiteralAPI", prompt_dict: PromptDict) -> "Prompt": variables_default_values=prompt_dict.get("variablesDefaultValues"), ) + def promote(self) -> "Prompt": + """ + Promotes this prompt to champion. + """ + self.api.promote_prompt(self.name, self.version) + return self + def format_messages(self, **kwargs: Any) -> List[Any]: """ Formats the prompt's template messages with the given variables. diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 750a9be..c81742b 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -560,13 +560,13 @@ async def test_dataset_sync( async def test_prompt( self, client: LiteralClient, async_client: AsyncLiteralClient ): - prompt = await async_client.api.get_prompt(name="Default") + prompt = await async_client.api.get_prompt(name="Default", version=0) assert prompt is not None assert prompt.name == "Default" assert prompt.version == 0 assert prompt.provider == "openai" - prompt = await async_client.api.get_prompt(id=prompt.id) + prompt = await async_client.api.get_prompt(id=prompt.id, version=0) assert prompt is not None messages = prompt.format_messages() @@ -593,6 +593,15 @@ async def test_prompt( assert messages[0]["content"] == expected + @pytest.mark.timeout(5) + async def test_champion_prompt(self, client: LiteralClient): + new_prompt = client.api.get_or_create_prompt(name="Default", template_messages=[{"role": "user", "content": "Hello"}]) + new_prompt.promote() + + prompt = client.api.get_prompt(name="Default") + assert prompt is not None + assert prompt.version == new_prompt.version + @pytest.mark.timeout(5) async def test_gracefulness(self, broken_client: LiteralClient): with broken_client.thread(name="Conversation"): @@ -609,7 +618,7 @@ async def test_thread_to_dict(self, client: LiteralClient): @pytest.mark.timeout(5) async def test_prompt_unique(self, client: LiteralClient): - prompt = client.api.get_prompt(name="Default") + prompt = client.api.get_prompt(name="Default", version=0) new_prompt = client.api.get_or_create_prompt( name=prompt.name,