Skip to content

Commit 840101d

Browse files
authored
[Feature] ContentBase (#2985)
1 parent ba0faef commit 840101d

File tree

10 files changed

+167
-20
lines changed

10 files changed

+167
-20
lines changed

test/llm/test_data.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77

88
import argparse
99
import importlib.util
10+
from typing import Mapping
1011

1112
import pytest
1213
import torch
13-
from tensordict import set_list_to_stack
14+
from tensordict import lazy_stack, set_list_to_stack
1415

1516
from torchrl.data import History
17+
from torchrl.data.llm.chat import ContentBase
1618

1719
_has_transformers = importlib.util.find_spec("transformers")
1820
_has_vllm = importlib.util.find_spec("vllm")
@@ -216,6 +218,53 @@ def test_history_spec(self):
216218
assert spec.is_in(r)
217219
assert spec.is_in(history)
218220

221+
def test_content_base(self):
222+
from transformers import AutoProcessor
223+
224+
processor = AutoProcessor.from_pretrained(
225+
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf"
226+
)
227+
228+
content_text = ContentBase(type="text", text="Hello, world!")
229+
content_img = ContentBase(
230+
type="image",
231+
url="https://github.com/pytorch/rl/blob/main/docs/source/_static/img/icon.png?raw=true",
232+
)
233+
content = lazy_stack([content_text, content_img])
234+
history0 = History(
235+
role="assistant",
236+
content=ContentBase(
237+
type="text",
238+
text="You are going to see an image and a hello world message. Ignore both.",
239+
batch_size=1,
240+
),
241+
)
242+
history1 = History(role="user", content=content)
243+
history = lazy_stack([history0, history1])
244+
proc = history.apply_chat_template(
245+
tokenizer=processor,
246+
add_generation_prompt=False,
247+
return_dict=True,
248+
tokenize=False,
249+
)
250+
assert (
251+
proc
252+
== "<|im_start|>assistant \nYou are going to see an image and a hello world message. Ignore both.<|im_end|><|im_start|>user <image>\nHello, world!<|im_end|>"
253+
)
254+
proc = history.apply_chat_template(
255+
tokenizer=processor,
256+
add_generation_prompt=False,
257+
return_dict=True,
258+
tokenize=True,
259+
)
260+
assert isinstance(proc, Mapping)
261+
assert proc["input_ids"].shape == (1, 7294)
262+
assert proc["attention_mask"].shape == (1, 7294)
263+
assert proc["pixel_values"].shape == (1, 37, 3, 384, 384), proc[
264+
"pixel_values"
265+
].shape
266+
assert (proc["image_sizes"] == torch.tensor([[2096, 2324]])).all()
267+
219268

220269
if __name__ == "__main__":
221270
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/data/llm/chat.py

Lines changed: 109 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,98 @@
3434
}
3535

3636

37+
# The 'shadow' flag is needed to keep tensordict from complaining about field names such as 'type' and 'size'.
38+
class ContentBase(TensorClass["nocast", "shadow"]):
39+
"""Base class for all message content types.
40+
41+
Attributes:
42+
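type (str): The type of the content. One of "text", "image", "audio", "video", "file" or "function_call".
43+
text (str, optional): The text content.
44+
url (str, optional): The HTTP URL of the content.
45+
data (str, optional): The Base64-encoded content.
46+
mime_type (str, optional): The MIME type of the content, e.g. "image/jpeg", "audio/mp3" or "application/pdf".
47+
name (str, optional): The original filename or a description of the content.
48+
size (int, optional): The size of the content (file size in bytes).
49+
function_name (str, optional): The name of the function (function-calling content).
50+
function_args (dict, optional): The arguments of the function (function-calling content).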
51+
52+
Examples:
53+
>>> from tensordict import lazy_stack
54+
>>> content1 = ContentBase(type="text", text="Hello, world!")
55+
>>> print(content1)
56+
ContentBase(
57+
text=NonTensorData(data=Hello, world!, batch_size=torch.Size([]), device=None),
58+
type=NonTensorData(data=text, batch_size=torch.Size([]), device=None),
59+
url=None,
60+
data=None,
61+
mime_type=None,
62+
name=None,
63+
size=None,
64+
function_name=None,
65+
function_args=None,
66+
batch_size=torch.Size([]),
67+
device=None,
68+
is_shared=False)
69+
>>> content2 = ContentBase(type="image", url="https://example.com/image.jpg")
70+
>>> print(content2)
71+
ContentBase(
72+
type=NonTensorData(data=image, batch_size=torch.Size([]), device=None),
73+
url=NonTensorData(data=https://example.com/image.jpg, batch_size=torch.Size([]), device=None),
74+
text=None,
75+
data=None,
76+
mime_type=None,
77+
name=None,
78+
size=None,
79+
function_name=None,
80+
function_args=None,
81+
batch_size=torch.Size([]),
82+
device=None,
83+
is_shared=False)
84+
>>> content = lazy_stack([content1, content2])
85+
>>> print(content)
86+
ContentBase(
87+
type=NonTensorStack(
88+
['text', 'image'],
89+
batch_size=torch.Size([2]),
90+
device=None),
91+
url=None,
92+
data=None,
93+
mime_type=None,
94+
name=None,
95+
size=None,
96+
function_name=None,
97+
function_args=None,
98+
text=None,
99+
batch_size=torch.Size([2]),
100+
device=None,
101+
is_shared=False)
102+
>>> # Content is typically wrapped in a History object. Usually, the content has one more
103+
>>> # batch dimension than the History object that holds it.
104+
>>> history = History(role="user", content=content)
105+
106+
"""
107+
108+
type: Literal[
109+
"text", "image", "audio", "video", "file", "function_call"
110+
] # Required: "text", "image", "audio", "video", "file", "function_call"
111+
112+
# Text content
113+
text: str | None = None
114+
115+
# Media/file content (either URL or data)
116+
url: str | None = None # HTTP URL to content
117+
data: str | None = None # Base64 encoded content
118+
119+
# Metadata
120+
mime_type: str | None = None # "image/jpeg", "audio/mp3", "application/pdf"
121+
name: str | None = None # Original filename or description
122+
size: int | None = None # File size in bytes
123+
124+
# Function calling (for AI agents)
125+
function_name: str | None = None
126+
function_args: dict | None = None
127+
128+
37129
class History(TensorClass["nocast"]):
38130
"""A class representing a structured history of messages in a conversation, designed for efficient manipulation and integration with language models.
39131
@@ -98,7 +190,7 @@ class History(TensorClass["nocast"]):
98190
"""
99191

100192
role: str
101-
content: str
193+
content: str | ContentBase
102194

103195
def __post_init__(self):
104196
if not list_to_stack():
@@ -110,27 +202,29 @@ def __post_init__(self):
110202
def apply_chat_template(
111203
self,
112204
*,
113-
tokenizer: transformers.AutoTokenizer, # noqa
205+
tokenizer: transformers.AutoTokenizer | transformers.AutoProcessor, # noqa
114206
add_generation_prompt: bool = True,
115207
chat_template: str | None = None,
116208
continue_final_message: bool = False,
117209
tokenize: bool = False,
118210
padding: bool | str = False,
119211
truncation: bool | str = False,
120212
return_tensors: str | None = "pt",
213+
return_dict: bool = False,
121214
**kwargs,
122215
):
123216
"""Applies a chat template to the history.
124217
125218
Keyword Args:
126-
tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use.
127-
add_generation_prompt (bool, optional): Whether to add a generation prompt. Defaults to True.
219+
tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor): The tokenizer or processor to use.
220+
add_generation_prompt (bool, optional): Whether to add a generation prompt. Defaults to `True`.
128221
chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template.
129-
continue_final_message (bool, optional): Whether to continue the final message. Defaults to False.
130-
tokenize (bool, optional): Whether to tokenize the output. Defaults to False.
131-
padding (bool | str, optional): The padding strategy to use. Defaults to False.
132-
truncation (bool | str, optional): The truncation strategy to use. Defaults to False.
222+
continue_final_message (bool, optional): Whether to continue the final message. Defaults to `False`.
223+
tokenize (bool, optional): Whether to tokenize the output. Defaults to `False`.
224+
padding (bool | str, optional): The padding strategy to use. Defaults to `False`.
225+
truncation (bool | str, optional): The truncation strategy to use. Defaults to `False`.
133226
return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt".
227+
return_dict (bool, optional): Whether to return a dictionary. This value is forwarded to the tokenizer `apply_chat_template` method. Defaults to `False`.
134228
**kwargs: Additional keyword arguments to pass to the tokenizer `apply_chat_template` method.
135229
136230
Returns:
@@ -155,20 +249,24 @@ def apply_chat_template(
155249
truncation=truncation,
156250
return_tensors=return_tensors,
157251
continue_final_message=continue_final_message,
252+
return_dict=return_dict,
158253
**kwargs,
159254
)
160255
for i in range(self.batch_size[0])
161256
]
162-
self_flat = self.view(-1).tolist()
257+
self_flat = self.view(-1)
258+
# tolist_first=True is needed to get a list of dicts of lists of dicts rather than a list of dicts of dicts
259+
self_flat = self_flat.tolist(tolist_first=True)
163260
return tokenizer.apply_chat_template(
164-
self_flat,
261+
conversation=self_flat,
165262
add_generation_prompt=add_generation_prompt,
166263
chat_template=chat_template,
167264
tokenize=tokenize,
168265
padding=padding,
169266
truncation=truncation,
170267
return_tensors=return_tensors,
171268
continue_final_message=continue_final_message,
269+
return_dict=return_dict,
172270
)
173271

174272
@classmethod
@@ -275,7 +373,7 @@ def append(
275373
276374
Args:
277375
history (History): The new history to append.
278-
inplace (bool, optional): Whether to perform the operation in-place. Defaults to True.
376+
inplace (bool, optional): Whether to perform the operation in-place. Defaults to `True`.
279377
dim (int, optional): The dimension to append along. Defaults to -1.
280378
281379
Returns:

torchrl/envs/custom/llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class LLMHashingEnv(EnvBase):
4343
observation_key (NestedKey, optional): The key for the observation in the TensorDict.
4444
Defaults to "observation".
4545
text_output (bool, optional): Whether to include the text output in the observation.
46-
Defaults to True.
46+
Defaults to `True`.
4747
tokenizer (transformers.Tokenizer | None, optional):
4848
A tokenizer function that converts text to tensors.
4949
Only used when `text_output` is `True`.

torchrl/envs/llm/envs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ class LLMHashingEnv(EnvBase):
602602
observation_key (NestedKey, optional): The key for the observation in the TensorDict.
603603
Defaults to "observation".
604604
text_output (bool, optional): Whether to include the text output in the observation.
605-
Defaults to True.
605+
Defaults to `True`.
606606
tokenizer (transformers.Tokenizer | None, optional):
607607
A tokenizer function that converts text to tensors.
608608
Only used when `text_output` is `True`.

torchrl/envs/llm/transforms/dataloading.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ class DataLoadingPrimer(TensorDictPrimer):
234234
... Args:
235235
... batch_size (int, optional): The batch size of the generated tensors. Defaults to 0.
236236
... max_length (int, optional): The maximum length of the generated tensors. Defaults to 10.
237-
... padding (bool, optional): Whether to pad the tensors to the maximum length. Defaults to False.
237+
... padding (bool, optional): Whether to pad the tensors to the maximum length. Defaults to `False`.
238238
... '''
239239
... self.batch_size = batch_size
240240
... self.max_length = max_length

torchrl/envs/transforms/r3m.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ class R3MTransform(Compose):
214214
If the torchvision weights are needed, there are two ways they can be
215215
obtained: :obj:`download=ResNet50_Weights.IMAGENET1K_V1` or :obj:`download="IMAGENET1K_V1"`
216216
where :obj:`ResNet50_Weights` can be imported via :obj:`from torchvision.models import resnet50, ResNet50_Weights`.
217-
Defaults to False.
217+
Defaults to `False`.
218218
download_path (str, optional): path where to download the models.
219219
Default is None (cache path determined by torch.hub utils).
220220
tensor_pixels_keys (list of str, optional): Optionally, one can keep the

torchrl/envs/transforms/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5835,7 +5835,7 @@ class DiscreteActionProjection(Transform):
58355835
action_key (NestedKey, optional): key name of the action. Defaults to "action".
58365836
include_forward (bool, optional): if ``True``, a call to forward will also
58375837
map the action from one domain to the other when the module is called
5838-
by a replay buffer or an nn.Module chain. Defaults to True.
5838+
by a replay buffer or an nn.Module chain. Defaults to `True`.
58395839
58405840
Examples:
58415841
>>> torch.manual_seed(0)

torchrl/envs/transforms/vip.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ class VIPTransform(Compose):
186186
If the torchvision weights are needed, there are two ways they can be
187187
obtained: :obj:`download=ResNet50_Weights.IMAGENET1K_V1` or :obj:`download="IMAGENET1K_V1"`
188188
where :obj:`ResNet50_Weights` can be imported via :obj:`from torchvision.models import resnet50, ResNet50_Weights`.
189-
Defaults to False.
189+
Defaults to `False`.
190190
download_path (str, optional): path where to download the models.
191191
Default is None (cache path determined by torch.hub utils).
192192
tensor_pixels_keys (list of str, optional): Optionally, one can keep the

torchrl/envs/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ def check_env_specs(
704704
return_contiguous=True. This will fail in some cases (e.g. heterogeneous shapes
705705
of inputs/outputs). Defaults to ``None`` (determined by the presence of dynamic specs).
706706
check_dtype (bool, optional): if False, dtype checks will be skipped.
707-
Defaults to True.
707+
Defaults to `True`.
708708
seed (int, optional): for reproducibility, a seed can be set.
709709
The seed will be set in pytorch temporarily, then the RNG state will
710710
be reverted to what it was before. For the env, we set the seed but since

torchrl/trainers/helpers/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ def make_dreamer(
216216
value_key (str, optional): Key to use for the value.
217217
Defaults to "state_value".
218218
use_decoder_in_env (bool, optional): Whether to use the decoder in the model based dreamer env.
219-
Defaults to False.
219+
Defaults to `False`.
220220
obs_norm_state_dict (dict, optional): the state_dict of the ObservationNorm transform used
221221
when proof_environment is missing. Defaults to None.
222222

0 commit comments

Comments
 (0)