Skip to content

Commit b9fa706

Browse files
authored
feat: extend ChatPromptBuilder to support string templates (#9631)
1 parent 7414ef6 commit b9fa706

File tree

5 files changed

+1159
-49
lines changed

5 files changed

+1159
-49
lines changed

haystack/components/builders/chat_prompt_builder.py

Lines changed: 119 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
import json
56
from copy import deepcopy
67
from typing import Any, Dict, List, Literal, Optional, Set, Union
78

@@ -10,15 +11,30 @@
1011

1112
from haystack import component, default_from_dict, default_to_dict, logging
1213
from haystack.dataclasses.chat_message import ChatMessage, ChatRole, TextContent
14+
from haystack.lazy_imports import LazyImport
1315
from haystack.utils import Jinja2TimeExtension
16+
from haystack.utils.jinja2_chat_extension import ChatMessageExtension, templatize_part
1417

1518
logger = logging.getLogger(__name__)
1619

20+
with LazyImport("Run 'pip install \"arrow>=1.3.0\"'") as arrow_import:
21+
import arrow # pylint: disable=unused-import
22+
23+
NO_TEXT_ERROR_MESSAGE = "ChatMessages from {role} role must contain text. Received ChatMessage with no text: {message}"
24+
25+
FILTER_NOT_ALLOWED_ERROR_MESSAGE = (
26+
"The templatize_part filter cannot be used with a template containing a list of "
27+
"ChatMessage objects. Use a string template or remove the templatize_part filter "
28+
"from the template."
29+
)
30+
1731

1832
@component
1933
class ChatPromptBuilder:
2034
"""
21-
Renders a chat prompt from a template string using Jinja2 syntax.
35+
Renders a chat prompt from a template using Jinja2 syntax.
36+
37+
A template can be a list of `ChatMessage` objects, or a special string, as shown in the usage examples.
2238
2339
It constructs prompts using static or dynamic templates, which you can update for each pipeline run.
2440
@@ -28,15 +44,15 @@ class ChatPromptBuilder:
2844
2945
### Usage examples
3046
31-
#### With static prompt template
47+
#### Static ChatMessage prompt template
3248
3349
```python
3450
template = [ChatMessage.from_user("Translate to {{ target_language }}. Context: {{ snippet }}; Translation:")]
3551
builder = ChatPromptBuilder(template=template)
3652
builder.run(target_language="spanish", snippet="I can't speak spanish.")
3753
```
3854
39-
#### Overriding static template at runtime
55+
#### Overriding static ChatMessage template at runtime
4056
4157
```python
4258
template = [ChatMessage.from_user("Translate to {{ target_language }}. Context: {{ snippet }}; Translation:")]
@@ -48,7 +64,7 @@ class ChatPromptBuilder:
4864
builder.run(target_language="spanish", snippet="I can't speak spanish.", template=summary_template)
4965
```
5066
51-
#### With dynamic prompt template
67+
#### Dynamic ChatMessage prompt template
5268
5369
```python
5470
from haystack.components.builders import ChatPromptBuilder
@@ -97,19 +113,42 @@ class ChatPromptBuilder:
97113
'total_tokens': 238}})]}}
98114
```
99115
116+
#### String prompt template
117+
```python
118+
from haystack.components.builders import ChatPromptBuilder
119+
from haystack.dataclasses.image_content import ImageContent
120+
121+
template = \"\"\"
122+
{% message role="system" %}
123+
You are a helpful assistant.
124+
{% endmessage %}
125+
126+
{% message role="user" %}
127+
Hello! I am {{user_name}}. What's the difference between the following images?
128+
{% for image in images %}
129+
{{ image | templatize_part }}
130+
{% endfor %}
131+
{% endmessage %}
132+
\"\"\"
133+
134+
images = [ImageContent.from_file_path("apple.jpg"), ImageContent.from_file_path("orange.jpg")]
135+
136+
builder = ChatPromptBuilder(template=template)
137+
builder.run(user_name="John", images=images)
138+
```
100139
"""
101140

102141
def __init__(
103142
self,
104-
template: Optional[List[ChatMessage]] = None,
143+
template: Optional[Union[List[ChatMessage], str]] = None,
105144
required_variables: Optional[Union[List[str], Literal["*"]]] = None,
106145
variables: Optional[List[str]] = None,
107146
):
108147
"""
109148
Constructs a ChatPromptBuilder component.
110149
111150
:param template:
112-
A list of `ChatMessage` objects. The component looks for Jinja2 template syntax and
151+
A list of `ChatMessage` objects or a string template. The component looks for Jinja2 template syntax and
113152
renders the prompt with the provided variables. Provide the template in either
114153
the `init` method or the `run` method.
115154
:param required_variables:
@@ -123,26 +162,32 @@ def __init__(
123162
"""
124163
self._variables = variables
125164
self._required_variables = required_variables
126-
self.required_variables = required_variables or []
127165
self.template = template
128-
variables = variables or []
129-
try:
130-
# The Jinja2TimeExtension needs an optional dependency to be installed.
131-
# If it's not available we can do without it and use the ChatPromptBuilder as is.
132-
self._env = SandboxedEnvironment(extensions=[Jinja2TimeExtension])
133-
except ImportError:
134-
self._env = SandboxedEnvironment()
135166

167+
self._env = SandboxedEnvironment(extensions=[ChatMessageExtension])
168+
self._env.filters["templatize_part"] = templatize_part
169+
if arrow_import.is_successful():
170+
self._env.add_extension(Jinja2TimeExtension)
171+
172+
extracted_variables = []
136173
if template and not variables:
137-
for message in template:
138-
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
139-
# infer variables from template
140-
if message.text is None:
141-
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}")
142-
ast = self._env.parse(message.text)
143-
template_variables = meta.find_undeclared_variables(ast)
144-
variables += list(template_variables)
145-
self.variables = variables
174+
if isinstance(template, list):
175+
for message in template:
176+
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
177+
# infer variables from template
178+
if message.text is None:
179+
raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=message.role.value, message=message))
180+
if message.text and "templatize_part" in message.text:
181+
raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE)
182+
ast = self._env.parse(message.text)
183+
template_variables = meta.find_undeclared_variables(ast)
184+
extracted_variables += list(template_variables)
185+
elif isinstance(template, str):
186+
ast = self._env.parse(template)
187+
extracted_variables = list(meta.find_undeclared_variables(ast))
188+
189+
self.variables = variables or extracted_variables
190+
self.required_variables = required_variables or []
146191

147192
if len(self.variables) > 0 and required_variables is None:
148193
logger.warning(
@@ -163,7 +208,7 @@ def __init__(
163208
@component.output_types(prompt=List[ChatMessage])
164209
def run(
165210
self,
166-
template: Optional[List[ChatMessage]] = None,
211+
template: Optional[Union[List[ChatMessage], str]] = None,
167212
template_variables: Optional[Dict[str, Any]] = None,
168213
**kwargs,
169214
):
@@ -175,7 +220,8 @@ def run(
175220
To overwrite pipeline kwargs, you can set the `template_variables` parameter.
176221
177222
:param template:
178-
An optional list of `ChatMessage` objects to overwrite ChatPromptBuilder's default template.
223+
An optional list of `ChatMessage` objects or a string template to overwrite ChatPromptBuilder's default
224+
template.
179225
If `None`, the default template provided at initialization is used.
180226
:param template_variables:
181227
An optional dictionary of template variables to overwrite the pipeline variables.
@@ -200,30 +246,56 @@ def run(
200246
f"Please provide a valid list of ChatMessage instances to render the prompt."
201247
)
202248

203-
if not all(isinstance(message, ChatMessage) for message in template):
249+
if isinstance(template, list) and not all(isinstance(message, ChatMessage) for message in template):
204250
raise ValueError(
205251
f"The {self.__class__.__name__} expects a list containing only ChatMessage instances. "
206252
f"The provided list contains other types. Please ensure that all elements in the list "
207253
f"are ChatMessage instances."
208254
)
209255

210256
processed_messages = []
211-
for message in template:
212-
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
213-
self._validate_variables(set(template_variables_combined.keys()))
214-
if message.text is None:
215-
raise ValueError(f"The provided ChatMessage has no text. ChatMessage: {message}")
216-
compiled_template = self._env.from_string(message.text)
217-
rendered_text = compiled_template.render(template_variables_combined)
218-
# deep copy the message to avoid modifying the original message
219-
rendered_message: ChatMessage = deepcopy(message)
220-
rendered_message._content = [TextContent(text=rendered_text)]
221-
processed_messages.append(rendered_message)
222-
else:
223-
processed_messages.append(message)
257+
if isinstance(template, list):
258+
for message in template:
259+
if message.is_from(ChatRole.USER) or message.is_from(ChatRole.SYSTEM):
260+
self._validate_variables(set(template_variables_combined.keys()))
261+
if message.text is None:
262+
raise ValueError(NO_TEXT_ERROR_MESSAGE.format(role=message.role.value, message=message))
263+
if message.text and "templatize_part" in message.text:
264+
raise ValueError(FILTER_NOT_ALLOWED_ERROR_MESSAGE)
265+
compiled_template = self._env.from_string(message.text)
266+
rendered_text = compiled_template.render(template_variables_combined)
267+
# deep copy the message to avoid modifying the original message
268+
rendered_message: ChatMessage = deepcopy(message)
269+
rendered_message._content = [TextContent(text=rendered_text)]
270+
processed_messages.append(rendered_message)
271+
else:
272+
processed_messages.append(message)
273+
elif isinstance(template, str):
274+
self._validate_variables(set(template_variables_combined.keys()))
275+
processed_messages = self._render_chat_messages_from_str_template(template, template_variables_combined)
224276

225277
return {"prompt": processed_messages}
226278

279+
def _render_chat_messages_from_str_template(
280+
self, template: str, template_variables: Dict[str, Any]
281+
) -> List[ChatMessage]:
282+
"""
283+
Renders a list of chat messages from a string template.
284+
285+
This must be used in conjunction with the `ChatMessageExtension` Jinja2 extension
286+
and the `templatize_part` filter.
287+
"""
288+
compiled_template = self._env.from_string(template)
289+
rendered = compiled_template.render(template_variables)
290+
291+
messages = []
292+
for line in rendered.strip().split("\n"):
293+
line = line.strip()
294+
if line:
295+
messages.append(ChatMessage.from_dict(json.loads(line)))
296+
297+
return messages
298+
227299
def _validate_variables(self, provided_variables: Set[str]):
228300
"""
229301
Checks if all the required template variables are provided.
@@ -252,10 +324,11 @@ def to_dict(self) -> Dict[str, Any]:
252324
:returns:
253325
Serialized dictionary representation of the component.
254326
"""
255-
if self.template is not None:
327+
template: Optional[Union[List[Dict[str, Any]], str]] = None
328+
if isinstance(self.template, list):
256329
template = [m.to_dict() for m in self.template]
257-
else:
258-
template = None
330+
elif isinstance(self.template, str):
331+
template = self.template
259332

260333
return default_to_dict(
261334
self, template=template, variables=self._variables, required_variables=self._required_variables
@@ -275,6 +348,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatPromptBuilder":
275348
init_parameters = data["init_parameters"]
276349
template = init_parameters.get("template")
277350
if template:
278-
init_parameters["template"] = [ChatMessage.from_dict(d) for d in template]
351+
if isinstance(template, list):
352+
init_parameters["template"] = [ChatMessage.from_dict(d) for d in template]
353+
elif isinstance(template, str):
354+
init_parameters["template"] = template
279355

280356
return default_from_dict(cls, data)

0 commit comments

Comments
 (0)