2
2
#
3
3
# SPDX-License-Identifier: Apache-2.0
4
4
5
+ import json
5
6
from copy import deepcopy
6
7
from typing import Any , Dict , List , Literal , Optional , Set , Union
7
8
10
11
11
12
from haystack import component , default_from_dict , default_to_dict , logging
12
13
from haystack .dataclasses .chat_message import ChatMessage , ChatRole , TextContent
14
+ from haystack .lazy_imports import LazyImport
13
15
from haystack .utils import Jinja2TimeExtension
16
+ from haystack .utils .jinja2_chat_extension import ChatMessageExtension , templatize_part
14
17
15
18
logger = logging .getLogger (__name__ )
16
19
20
+ with LazyImport ("Run 'pip install \" arrow>=1.3.0\" '" ) as arrow_import :
21
+ import arrow # pylint: disable=unused-import
22
+
23
+ NO_TEXT_ERROR_MESSAGE = "ChatMessages from {role} role must contain text. Received ChatMessage with no text: {message}"
24
+
25
+ FILTER_NOT_ALLOWED_ERROR_MESSAGE = (
26
+ "The templatize_part filter cannot be used with a template containing a list of"
27
+ "ChatMessage objects. Use a string template or remove the templatize_part filter "
28
+ "from the template."
29
+ )
30
+
17
31
18
32
@component
19
33
class ChatPromptBuilder :
20
34
"""
21
- Renders a chat prompt from a template string using Jinja2 syntax.
35
+ Renders a chat prompt from a template using Jinja2 syntax.
36
+
37
+ A template can be a list of `ChatMessage` objects, or a special string, as shown in the usage examples.
22
38
23
39
It constructs prompts using static or dynamic templates, which you can update for each pipeline run.
24
40
@@ -28,15 +44,15 @@ class ChatPromptBuilder:
28
44
29
45
### Usage examples
30
46
31
- #### With static prompt template
47
+ #### Static ChatMessage prompt template
32
48
33
49
```python
34
50
template = [ChatMessage.from_user("Translate to {{ target_language }}. Context: {{ snippet }}; Translation:")]
35
51
builder = ChatPromptBuilder(template=template)
36
52
builder.run(target_language="spanish", snippet="I can't speak spanish.")
37
53
```
38
54
39
- #### Overriding static template at runtime
55
+ #### Overriding static ChatMessage template at runtime
40
56
41
57
```python
42
58
template = [ChatMessage.from_user("Translate to {{ target_language }}. Context: {{ snippet }}; Translation:")]
@@ -48,7 +64,7 @@ class ChatPromptBuilder:
48
64
builder.run(target_language="spanish", snippet="I can't speak spanish.", template=summary_template)
49
65
```
50
66
51
- #### With dynamic prompt template
67
+ #### Dynamic ChatMessage prompt template
52
68
53
69
```python
54
70
from haystack.components.builders import ChatPromptBuilder
@@ -97,19 +113,42 @@ class ChatPromptBuilder:
97
113
'total_tokens': 238}})]}}
98
114
```
99
115
116
+ #### String prompt template
117
+ ```python
118
+ from haystack.components.builders import ChatPromptBuilder
119
+ from haystack.dataclasses.image_content import ImageContent
120
+
121
+ template = \" \" \"
122
+ {% message role="system" %}
123
+ You are a helpful assistant.
124
+ {% endmessage %}
125
+
126
+ {% message role="user" %}
127
+ Hello! I am {{user_name}}. What's the difference between the following images?
128
+ {% for image in images %}
129
+ {{ image | templatize_part }}
130
+ {% endfor %}
131
+ {% endmessage %}
132
+ \" \" \"
133
+
134
+ images = [ImageContent.from_file_path("apple.jpg"), ImageContent.from_file_path("orange.jpg")]
135
+
136
+ builder = ChatPromptBuilder(template=template)
137
+ builder.run(user_name="John", images=images)
138
+ ```
100
139
"""
101
140
102
141
def __init__ (
103
142
self ,
104
- template : Optional [List [ChatMessage ]] = None ,
143
+ template : Optional [Union [ List [ChatMessage ], str ]] = None ,
105
144
required_variables : Optional [Union [List [str ], Literal ["*" ]]] = None ,
106
145
variables : Optional [List [str ]] = None ,
107
146
):
108
147
"""
109
148
Constructs a ChatPromptBuilder component.
110
149
111
150
:param template:
112
- A list of `ChatMessage` objects. The component looks for Jinja2 template syntax and
151
+ A list of `ChatMessage` objects or a string template . The component looks for Jinja2 template syntax and
113
152
renders the prompt with the provided variables. Provide the template in either
114
153
the `init` method` or the `run` method.
115
154
:param required_variables:
@@ -123,26 +162,32 @@ def __init__(
123
162
"""
124
163
self ._variables = variables
125
164
self ._required_variables = required_variables
126
- self .required_variables = required_variables or []
127
165
self .template = template
128
- variables = variables or []
129
- try :
130
- # The Jinja2TimeExtension needs an optional dependency to be installed.
131
- # If it's not available we can do without it and use the ChatPromptBuilder as is.
132
- self ._env = SandboxedEnvironment (extensions = [Jinja2TimeExtension ])
133
- except ImportError :
134
- self ._env = SandboxedEnvironment ()
135
166
167
+ self ._env = SandboxedEnvironment (extensions = [ChatMessageExtension ])
168
+ self ._env .filters ["templatize_part" ] = templatize_part
169
+ if arrow_import .is_successful ():
170
+ self ._env .add_extension (Jinja2TimeExtension )
171
+
172
+ extracted_variables = []
136
173
if template and not variables :
137
- for message in template :
138
- if message .is_from (ChatRole .USER ) or message .is_from (ChatRole .SYSTEM ):
139
- # infer variables from template
140
- if message .text is None :
141
- raise ValueError (f"The provided ChatMessage has no text. ChatMessage: { message } " )
142
- ast = self ._env .parse (message .text )
143
- template_variables = meta .find_undeclared_variables (ast )
144
- variables += list (template_variables )
145
- self .variables = variables
174
+ if isinstance (template , list ):
175
+ for message in template :
176
+ if message .is_from (ChatRole .USER ) or message .is_from (ChatRole .SYSTEM ):
177
+ # infer variables from template
178
+ if message .text is None :
179
+ raise ValueError (NO_TEXT_ERROR_MESSAGE .format (role = message .role .value , message = message ))
180
+ if message .text and "templatize_part" in message .text :
181
+ raise ValueError (FILTER_NOT_ALLOWED_ERROR_MESSAGE )
182
+ ast = self ._env .parse (message .text )
183
+ template_variables = meta .find_undeclared_variables (ast )
184
+ extracted_variables += list (template_variables )
185
+ elif isinstance (template , str ):
186
+ ast = self ._env .parse (template )
187
+ extracted_variables = list (meta .find_undeclared_variables (ast ))
188
+
189
+ self .variables = variables or extracted_variables
190
+ self .required_variables = required_variables or []
146
191
147
192
if len (self .variables ) > 0 and required_variables is None :
148
193
logger .warning (
@@ -163,7 +208,7 @@ def __init__(
163
208
@component .output_types (prompt = List [ChatMessage ])
164
209
def run (
165
210
self ,
166
- template : Optional [List [ChatMessage ]] = None ,
211
+ template : Optional [Union [ List [ChatMessage ], str ]] = None ,
167
212
template_variables : Optional [Dict [str , Any ]] = None ,
168
213
** kwargs ,
169
214
):
@@ -175,7 +220,8 @@ def run(
175
220
To overwrite pipeline kwargs, you can set the `template_variables` parameter.
176
221
177
222
:param template:
178
- An optional list of `ChatMessage` objects to overwrite ChatPromptBuilder's default template.
223
+ An optional list of `ChatMessage` objects or string template to overwrite ChatPromptBuilder's default
224
+ template.
179
225
If `None`, the default template provided at initialization is used.
180
226
:param template_variables:
181
227
An optional dictionary of template variables to overwrite the pipeline variables.
@@ -200,30 +246,56 @@ def run(
200
246
f"Please provide a valid list of ChatMessage instances to render the prompt."
201
247
)
202
248
203
- if not all (isinstance (message , ChatMessage ) for message in template ):
249
+ if isinstance ( template , list ) and not all (isinstance (message , ChatMessage ) for message in template ):
204
250
raise ValueError (
205
251
f"The { self .__class__ .__name__ } expects a list containing only ChatMessage instances. "
206
252
f"The provided list contains other types. Please ensure that all elements in the list "
207
253
f"are ChatMessage instances."
208
254
)
209
255
210
256
processed_messages = []
211
- for message in template :
212
- if message .is_from (ChatRole .USER ) or message .is_from (ChatRole .SYSTEM ):
213
- self ._validate_variables (set (template_variables_combined .keys ()))
214
- if message .text is None :
215
- raise ValueError (f"The provided ChatMessage has no text. ChatMessage: { message } " )
216
- compiled_template = self ._env .from_string (message .text )
217
- rendered_text = compiled_template .render (template_variables_combined )
218
- # deep copy the message to avoid modifying the original message
219
- rendered_message : ChatMessage = deepcopy (message )
220
- rendered_message ._content = [TextContent (text = rendered_text )]
221
- processed_messages .append (rendered_message )
222
- else :
223
- processed_messages .append (message )
257
+ if isinstance (template , list ):
258
+ for message in template :
259
+ if message .is_from (ChatRole .USER ) or message .is_from (ChatRole .SYSTEM ):
260
+ self ._validate_variables (set (template_variables_combined .keys ()))
261
+ if message .text is None :
262
+ raise ValueError (NO_TEXT_ERROR_MESSAGE .format (role = message .role .value , message = message ))
263
+ if message .text and "templatize_part" in message .text :
264
+ raise ValueError (FILTER_NOT_ALLOWED_ERROR_MESSAGE )
265
+ compiled_template = self ._env .from_string (message .text )
266
+ rendered_text = compiled_template .render (template_variables_combined )
267
+ # deep copy the message to avoid modifying the original message
268
+ rendered_message : ChatMessage = deepcopy (message )
269
+ rendered_message ._content = [TextContent (text = rendered_text )]
270
+ processed_messages .append (rendered_message )
271
+ else :
272
+ processed_messages .append (message )
273
+ elif isinstance (template , str ):
274
+ self ._validate_variables (set (template_variables_combined .keys ()))
275
+ processed_messages = self ._render_chat_messages_from_str_template (template , template_variables_combined )
224
276
225
277
return {"prompt" : processed_messages }
226
278
279
+ def _render_chat_messages_from_str_template (
280
+ self , template : str , template_variables : Dict [str , Any ]
281
+ ) -> List [ChatMessage ]:
282
+ """
283
+ Renders a chat message from a string template.
284
+
285
+ This must be used in conjunction with the `ChatMessageExtension` Jinja2 extension
286
+ and the `templatize_part` filter.
287
+ """
288
+ compiled_template = self ._env .from_string (template )
289
+ rendered = compiled_template .render (template_variables )
290
+
291
+ messages = []
292
+ for line in rendered .strip ().split ("\n " ):
293
+ line = line .strip ()
294
+ if line :
295
+ messages .append (ChatMessage .from_dict (json .loads (line )))
296
+
297
+ return messages
298
+
227
299
def _validate_variables (self , provided_variables : Set [str ]):
228
300
"""
229
301
Checks if all the required template variables are provided.
@@ -252,10 +324,11 @@ def to_dict(self) -> Dict[str, Any]:
252
324
:returns:
253
325
Serialized dictionary representation of the component.
254
326
"""
255
- if self .template is not None :
327
+ template : Optional [Union [List [Dict [str , Any ]], str ]] = None
328
+ if isinstance (self .template , list ):
256
329
template = [m .to_dict () for m in self .template ]
257
- else :
258
- template = None
330
+ elif isinstance ( self . template , str ) :
331
+ template = self . template
259
332
260
333
return default_to_dict (
261
334
self , template = template , variables = self ._variables , required_variables = self ._required_variables
@@ -275,6 +348,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatPromptBuilder":
275
348
init_parameters = data ["init_parameters" ]
276
349
template = init_parameters .get ("template" )
277
350
if template :
278
- init_parameters ["template" ] = [ChatMessage .from_dict (d ) for d in template ]
351
+ if isinstance (template , list ):
352
+ init_parameters ["template" ] = [ChatMessage .from_dict (d ) for d in template ]
353
+ elif isinstance (template , str ):
354
+ init_parameters ["template" ] = template
279
355
280
356
return default_from_dict (cls , data )
0 commit comments