 }


+# We need the 'shadow' flag to keep tensordict from complaining about the 'type'/'size' etc. fields
+class ContentBase(TensorClass["nocast", "shadow"]):
+    """Base class for all message content types.
+
+    Attributes:
+        type (str): The type of the content.
+        text (str, optional): The text content.
+        url (str, optional): The URL content.
+        data (str, optional): The data content.
+        mime_type (str, optional): The MIME type of the content.
+        name (str, optional): The name of the content.
+        size (int, optional): The size of the content.
+        function_name (str, optional): The name of the function.
+        function_args (dict, optional): The arguments of the function.
+
+    Examples:
+        >>> from tensordict import lazy_stack
+        >>> content1 = ContentBase(type="text", text="Hello, world!")
+        >>> print(content1)
+        ContentBase(
+            text=NonTensorData(data=Hello, world!, batch_size=torch.Size([]), device=None),
+            type=NonTensorData(data=text, batch_size=torch.Size([]), device=None),
+            url=None,
+            data=None,
+            mime_type=None,
+            name=None,
+            size=None,
+            function_name=None,
+            function_args=None,
+            batch_size=torch.Size([]),
+            device=None,
+            is_shared=False)
+        >>> content2 = ContentBase(type="image", url="https://example.com/image.jpg")
+        >>> print(content2)
+        ContentBase(
+            type=NonTensorData(data=image, batch_size=torch.Size([]), device=None),
+            url=NonTensorData(data=https://example.com/image.jpg, batch_size=torch.Size([]), device=None),
+            text=None,
+            data=None,
+            mime_type=None,
+            name=None,
+            size=None,
+            function_name=None,
+            function_args=None,
+            batch_size=torch.Size([]),
+            device=None,
+            is_shared=False)
+        >>> content = lazy_stack([content1, content2])
+        >>> print(content)
+        ContentBase(
+            type=NonTensorStack(
+                ['text', 'image'],
+                batch_size=torch.Size([2]),
+                device=None),
+            url=None,
+            data=None,
+            mime_type=None,
+            name=None,
+            size=None,
+            function_name=None,
+            function_args=None,
+            text=None,
+            batch_size=torch.Size([2]),
+            device=None,
+            is_shared=False)
+        >>> # A content is typically used in a History object. Usually, its batch dimension is
+        >>> # one dimension greater than that of the History object.
+        >>> history = History(role="user", content=content)
+
+    """
+
+    type: Literal[
+        "text", "image", "audio", "video", "file", "function_call"
+    ]  # Required: "text", "image", "audio", "video", "file", "function_call"
+
+    # Text content
+    text: str | None = None
+
+    # Media/file content (either URL or data)
+    url: str | None = None  # HTTP URL to content
+    data: str | None = None  # Base64 encoded content
+
+    # Metadata
+    mime_type: str | None = None  # "image/jpeg", "audio/mp3", "application/pdf"
+    name: str | None = None  # Original filename or description
+    size: int | None = None  # File size in bytes
+
+    # Function calling (for AI agents)
+    function_name: str | None = None
+    function_args: dict | None = None
+
+
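A quick hedged sketch of the one content type the docstring examples above do not cover: a "function_call" entry only needs the function fields, and everything else stays None. The function name and arguments below are made up for illustration and are not part of the diff.

call = ContentBase(
    type="function_call",
    function_name="get_weather",      # hypothetical function name
    function_args={"city": "Paris"},  # hypothetical arguments
)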
 class History(TensorClass["nocast"]):
     """A class representing a structured history of messages in a conversation, designed for efficient manipulation and integration with language models.

@@ -98,7 +190,7 @@ class History(TensorClass["nocast"]):
     """

     role: str
-    content: str
+    content: str | ContentBase

     def __post_init__(self):
         if not list_to_stack():
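To illustrate the widened annotation, here is a sketch (assuming History and ContentBase are imported from this module; the URL is a placeholder):

from tensordict import lazy_stack

# Plain-text message: `content` is just a string, as before.
msg_text = History(role="user", content="Describe this image.")
# Multimodal message: `content` is a ContentBase stack with one extra batch
# dimension, as described in the ContentBase docstring.
img = ContentBase(type="image", url="https://example.com/image.jpg")
msg_multimodal = History(role="user", content=lazy_stack([img]))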
@@ -110,27 +202,29 @@ def __post_init__(self):
     def apply_chat_template(
         self,
         *,
-        tokenizer: transformers.AutoTokenizer,  # noqa
+        tokenizer: transformers.AutoTokenizer | transformers.AutoProcessor,  # noqa
         add_generation_prompt: bool = True,
         chat_template: str | None = None,
         continue_final_message: bool = False,
         tokenize: bool = False,
         padding: bool | str = False,
         truncation: bool | str = False,
         return_tensors: str | None = "pt",
+        return_dict: bool = False,
         **kwargs,
     ):
         """Applies a chat template to the history.

         Keyword Args:
-            tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use.
-            add_generation_prompt (bool, optional): Whether to add a generation prompt. Defaults to True.
+            tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor): The tokenizer or processor to use.
+            add_generation_prompt (bool, optional): Whether to add a generation prompt. Defaults to `True`.
             chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template.
-            continue_final_message (bool, optional): Whether to continue the final message. Defaults to False.
-            tokenize (bool, optional): Whether to tokenize the output. Defaults to False.
-            padding (bool | str, optional): The padding strategy to use. Defaults to False.
-            truncation (bool | str, optional): The truncation strategy to use. Defaults to False.
+            continue_final_message (bool, optional): Whether to continue the final message. Defaults to `False`.
+            tokenize (bool, optional): Whether to tokenize the output. Defaults to `False`.
+            padding (bool | str, optional): The padding strategy to use. Defaults to `False`.
+            truncation (bool | str, optional): The truncation strategy to use. Defaults to `False`.
             return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt".
+            return_dict (bool, optional): Whether to return a dictionary of tokenizer outputs. Defaults to `False`.
             **kwargs: Additional keyword arguments to pass to the tokenizer `apply_chat_template` method.

         Returns:
@@ -155,20 +249,24 @@ def apply_chat_template(
                     truncation=truncation,
                     return_tensors=return_tensors,
                     continue_final_message=continue_final_message,
+                    return_dict=return_dict,
                     **kwargs,
                 )
                 for i in range(self.batch_size[0])
             ]
-        self_flat = self.view(-1).tolist()
+        self_flat = self.view(-1)
+        # tolist_first=True is needed so that we get a list of dicts whose nested fields are lists of dicts, rather than a list of dicts of dicts
+        self_flat = self_flat.tolist(tolist_first=True)
         return tokenizer.apply_chat_template(
-            self_flat,
+            conversation=self_flat,
             add_generation_prompt=add_generation_prompt,
             chat_template=chat_template,
             tokenize=tokenize,
             padding=padding,
             truncation=truncation,
             return_tensors=return_tensors,
             continue_final_message=continue_final_message,
+            return_dict=return_dict,
         )

     @classmethod
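End to end, the new keyword flows through like this. A hedged sketch, not from the diff: the model name is a placeholder, and building the History via lazy_stack follows the stacking pattern shown in the ContentBase docstring.

from tensordict import lazy_stack
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model
history = lazy_stack([
    History(role="user", content="Hi!"),
    History(role="assistant", content="Hello!"),
])
# With tokenize=True and the new return_dict=True, the underlying Hugging Face
# apply_chat_template returns a dict (e.g. input_ids, attention_mask) rather
# than the token ids alone.
out = history.apply_chat_template(
    tokenizer=tokenizer,
    add_generation_prompt=False,
    tokenize=True,
    return_dict=True,
)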
@@ -275,7 +373,7 @@ def append(

         Args:
             history (History): The new history to append.
-            inplace (bool, optional): Whether to perform the operation in-place. Defaults to True.
+            inplace (bool, optional): Whether to perform the operation in-place. Defaults to `True`.
             dim (int, optional): The dimension to append along. Defaults to -1.

         Returns:
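Finally, a rough usage sketch for append as documented above (reusing `history` from the earlier sketch; appending a scalar History along the default dim=-1 is an assumption based on the docstring, not something the diff demonstrates):

reply = History(role="assistant", content="Sure!")
history.append(reply)                               # inplace=True by default
new_history = history.append(reply, inplace=False)  # returns a new History instead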