Skip to content

fix: pydantic issues with langchain chat prompt and allow additional … #126

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions literalai/prompt_engineering/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional

import chevron
from pydantic import Field
from typing_extensions import TypedDict, deprecated

if TYPE_CHECKING:
Expand Down Expand Up @@ -162,7 +163,7 @@ def format_messages(self, **kwargs: Any) -> List[Any]:
def format(self, variables: Optional[Dict[str, Any]] = None) -> List[Any]:
return self.format_messages(**(variables or {}))

def to_langchain_chat_prompt_template(self):
def to_langchain_chat_prompt_template(self, additional_messages=[]):
try:
version("langchain")
except Exception:
Expand All @@ -185,9 +186,11 @@ def to_langchain_chat_prompt_template(self):
)

class CustomChatPromptTemplate(ChatPromptTemplate):
orig_messages: Optional[List[GenerationMessage]]
default_vars: Optional[Dict] = None
prompt_id: Optional[str]
orig_messages: Optional[List[GenerationMessage]] = Field(
default_factory=list
)
default_vars: Optional[Dict] = Field(default_factory=dict)
prompt_id: Optional[str] = None

def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
variables_with_defaults = {
Expand All @@ -198,10 +201,19 @@ def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
rendered_messages: List[BaseMessage] = []

for index, message in enumerate(self.messages):
template = message.prompt.template # type: ignore
content = html.unescape(
chevron.render(template, variables_with_defaults)
)
content: str = ""
try:
prompt = getattr(message, "prompt") # type: ignore
content = html.unescape(
chevron.render(prompt.template, variables_with_defaults)
)
except AttributeError:
for m in ChatPromptTemplate.from_messages(
[message]
).format_messages():
rendered_messages.append(m)
continue

additonal_kwargs = {}
if self.orig_messages and index < len(self.orig_messages):
additonal_kwargs = {
Expand Down Expand Up @@ -240,8 +252,9 @@ async def aformat_messages(self, **kwargs: Any) -> List[BaseMessage]:

lc_messages = [(m["role"], m["content"]) for m in self.template_messages]

chat_template = CustomChatPromptTemplate.from_messages(lc_messages)
chat_template.input_variables = ["agent_scratchpad"]
chat_template = CustomChatPromptTemplate.from_messages(
lc_messages + additional_messages
)
chat_template.default_vars = self.variables_default_values
chat_template.orig_messages = self.template_messages
chat_template.prompt_id = self.id
Expand Down
Loading