From 1a24f9c7685baff636bcc284343887c2232fa9e1 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Thu, 26 Sep 2024 14:59:06 +0200 Subject: [PATCH] fix: pydantic issues with langchain chat prompt and allow additional messages --- literalai/prompt_engineering/prompt.py | 33 ++++++++++++++++++-------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/literalai/prompt_engineering/prompt.py b/literalai/prompt_engineering/prompt.py index 1527d1b..d44283a 100644 --- a/literalai/prompt_engineering/prompt.py +++ b/literalai/prompt_engineering/prompt.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional import chevron +from pydantic import Field from typing_extensions import TypedDict, deprecated if TYPE_CHECKING: @@ -162,7 +163,7 @@ def format_messages(self, **kwargs: Any) -> List[Any]: def format(self, variables: Optional[Dict[str, Any]] = None) -> List[Any]: return self.format_messages(**(variables or {})) - def to_langchain_chat_prompt_template(self): + def to_langchain_chat_prompt_template(self, additional_messages=[]): try: version("langchain") except Exception: @@ -185,9 +186,11 @@ def to_langchain_chat_prompt_template(self): ) class CustomChatPromptTemplate(ChatPromptTemplate): - orig_messages: Optional[List[GenerationMessage]] - default_vars: Optional[Dict] = None - prompt_id: Optional[str] + orig_messages: Optional[List[GenerationMessage]] = Field( + default_factory=list + ) + default_vars: Optional[Dict] = Field(default_factory=dict) + prompt_id: Optional[str] = None def format_messages(self, **kwargs: Any) -> List[BaseMessage]: variables_with_defaults = { @@ -198,10 +201,19 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]: rendered_messages: List[BaseMessage] = [] for index, message in enumerate(self.messages): - template = message.prompt.template # type: ignore - content = html.unescape( - chevron.render(template, variables_with_defaults) - ) + content: str = "" + try: + prompt = getattr(message, "prompt") # type: ignore + content = html.unescape( + chevron.render(prompt.template, variables_with_defaults) + ) + except AttributeError: + for m in ChatPromptTemplate.from_messages( + [message] + ).format_messages(): + rendered_messages.append(m) + continue + additonal_kwargs = {} if self.orig_messages and index < len(self.orig_messages): additonal_kwargs = { @@ -240,8 +252,9 @@ async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]: lc_messages = [(m["role"], m["content"]) for m in self.template_messages] - chat_template = CustomChatPromptTemplate.from_messages(lc_messages) - chat_template.input_variables = ["agent_scratchpad"] + chat_template = CustomChatPromptTemplate.from_messages( + lc_messages + additional_messages + ) chat_template.default_vars = self.variables_default_values chat_template.orig_messages = self.template_messages chat_template.prompt_id = self.id