
Commit b1d2dc2

Author: Vincent Moens (committed)
[Feature] History
ghstack-source-id: 7128302
Pull-Request-resolved: #2965
1 parent 1b9d2c1 commit b1d2dc2

File tree

6 files changed: +694 −18 lines changed


docs/source/reference/llms.rst

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
.. currentmodule:: torchrl.trainers

LLM interface
=============

.. _ref_llms:

TorchRL offers a set of tools for LLM post-training, as well as some examples for training or setup.

Data structures
---------------

.. currentmodule:: torchrl.data.llm

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    History
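
History is a TensorDict-backed data structure for multi-turn chat data. As a quick orientation, here is a minimal usage sketch distilled from the tests added in this commit (the set_list_to_stack context and the append/extend semantics come straight from test/llm/test_data.py below):

import torch
from tensordict import set_list_to_stack

from torchrl.data import History

with set_list_to_stack(True):
    # Start a 1-d conversation and grow it one message at a time.
    chat = History(role="system", content="You are helpful.").unsqueeze(0)
    chat = chat.append(History(role="user", content="Hi!"))  # in-place by default
    assert chat.shape == (2,)
    assert chat.role == ["system", "user"]

    # Batched construction, then extend the running conversation.
    pair = History(
        role=["assistant", "user"],
        content=["Hello!", "Tell me more."],
        batch_size=2,
    )
    chat = chat.extend(pair)
    assert chat.shape == (4,)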

test/llm/test_data.py

Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import argparse
import importlib.util

import pytest
import torch

from tensordict import set_list_to_stack
from torchrl.data import History

_has_transformers = importlib.util.find_spec("transformers")
_has_vllm = importlib.util.find_spec("vllm")


class TestHistory:
    @pytest.fixture(scope="class", autouse=True)
    def set_context(self):
        with set_list_to_stack(True):
            yield

    def test_history_construct(self):
        hst0 = History(role="user", content="a message")
        assert not hst0.shape
        hst1 = History(role="user", content="another message")
        with pytest.raises(RuntimeError, match="unsqueeze"):
            hst0.append(hst1)
        hst0 = hst0.unsqueeze(0)

        # In an env.step, we typically have one more piece of history to add to the stack
        assert not hst1.shape
        assert not hst1.batch_size
        assert not hst1.batch_dims
        # test out-place
        hst0_copy = hst0.copy()
        hst0b = hst0.append(hst1, inplace=False)
        assert hst0b is not hst0
        assert (hst0 == hst0_copy).all()
        assert (hst0b[:-1] == hst0).all()

        # test in-place
        hst0b = hst0.append(hst1)
        assert hst0b is hst0
        assert hst0b.shape == (2,)

        assert hst0b.content == ["a message", "another message"]
        hst2 = History(
            role=["assistant", "user"],
            content=["i'm the assistant", "i'm the user"],
            batch_size=2,
        )
        assert hst2[0].role == "assistant"
        assert hst2[0].content == "i'm the assistant"
        assert hst2[1].role == "user"
        assert hst2[1].content == "i'm the user"
        with pytest.raises(RuntimeError, match="The new history to extend"):
            hst0.extend(hst1)

        # test out-place
        hst0_copy = hst0.copy()
        hst0b = hst0.extend(hst2, inplace=False)
        assert hst0b is not hst0
        assert (hst0 == hst0_copy).all()
        assert (hst0b[:-2] == hst0).all()

        # test in-place
        hst0b = hst0.extend(hst2)

        assert hst0b is hst0
        assert hst0.__dict__["_tensordict"].shape == (4,)
        assert hst0.shape == (4,)
        assert hst0.role == ["user", "user", "assistant", "user"]
        assert hst0.content == [
            "a message",
            "another message",
            "i'm the assistant",
            "i'm the user",
        ]

    def test_history_construct_ndim(self):
        hst0 = History(role="user", content="a message").unsqueeze(0).unsqueeze(0)
        hst1 = History(role="user", content="another message").unsqueeze(0)

        # test out-place
        hst0_copy = hst0.copy()
        assert isinstance(hst0_copy, History)
        assert hst0.shape == (1, 1)
        hst0b = hst0.append(hst1, inplace=False, dim=1)
        assert hst0b is not hst0
        assert hst0.shape == (1, 1)
        assert (hst0 == hst0_copy).all()
        assert (hst0b[:, :-1] == hst0).all()

        # test in-place
        assert hst0b.shape == (1, 2)
        assert hst0.shape == (1, 1)
        hst0b = hst0.append(hst1, dim=1)
        assert hst0b is hst0
        assert hst0b._tensordict.shape == (1, 2)
        assert hst0b.batch_size == (1, 2)
        assert hst0b.shape == (1, 2)

        assert hst0b.content == [["a message", "another message"]]
        hst2 = History(
            role=["assistant", "user"],
            content=["i'm the assistant", "i'm the user"],
            batch_size=2,
        ).unsqueeze(0)

        # test out-place
        hst0_copy = hst0.copy()
        hst0b = hst0.extend(hst2, inplace=False, dim=1)
        assert hst0b is not hst0
        assert (hst0 == hst0_copy).all()
        assert (hst0b[:, :-2] == hst0).all()

        # test in-place
        hst0b = hst0.extend(hst2, dim=1)

        assert hst0b is hst0
        assert hst0.__dict__["_tensordict"].shape == (
            1,
            4,
        )
        assert hst0.shape == (
            1,
            4,
        )
        assert hst0.role == [["user", "user", "assistant", "user"]]
        assert hst0.content == [
            [
                "a message",
                "another message",
                "i'm the assistant",
                "i'm the user",
            ]
        ]

    @pytest.fixture(scope="class")
    def mock_history(self):
        history0 = History(
            role="system",
            content="""CONTENT
This is the setup""",
        )
        history1 = History(
            role="user",
            content="""CONTENT
This is the first user prompt""",
        )
        history2 = History(
            role="assistant",
            content="""CONTENT
This is the second prompt, the first for the assistant.""",
        )
        history = torch.stack([history0, history1, history2])
        return history

    @pytest.fixture(scope="class")
    def tokenizer(self):
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("GPT2")
        yield tokenizer

    @pytest.mark.skipif(not _has_transformers, reason="requires transformers library")
    def test_history_template(self, mock_history, tokenizer):
        history = mock_history
        data_str = history.apply_chat_template(
            tokenizer=tokenizer, add_generation_prompt=False
        )
        assert isinstance(data_str, str)
        data_token = history.apply_chat_template(
            tokenizer=tokenizer, tokenize=True, add_generation_prompt=False
        )
        assert isinstance(data_token, torch.Tensor)

        # test add_generation_prompt
        data_str = history.apply_chat_template(
            tokenizer=tokenizer, add_generation_prompt=True
        )
        assert isinstance(data_str, str)
        assert data_str.endswith("<|im_start|>assistant\n"), data_str

    @pytest.mark.skipif(not _has_transformers, reason="requires transformers library")
    def test_history_template_recover(self, mock_history, tokenizer):
        history = mock_history
        data_str = history.apply_chat_template(tokenizer=tokenizer)
        # Test inverse
        recovered = history._inv_chatml(data_str)
        assert recovered.role == history.role
        assert recovered.content == history.content
        data_token = history.apply_chat_template(
            tokenizer=tokenizer, tokenize=True, add_generation_prompt=False
        )
        recovered = history._inv_chatml(tokenizer.batch_decode(data_token)[0])

    def test_history_spec(self):
        history = History(
            role=["system", "user", "assistant", "user"],
            content=[
                "i'm the system",
                "i'm the user",
                "I'm the assistant",
                "I'm the user again",
            ],
        )
        spec = history.default_spec()
        r = spec.zero()
        assert isinstance(r, History)
        assert spec.is_in(r)
        assert spec.is_in(history)


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
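
For reference, the templating path exercised above also works outside the test harness. Below is a minimal sketch, assuming a tokenizer that ships a chat template (the model name Qwen/Qwen2.5-0.5B-Instruct is only an illustrative choice, not the one used in the fixture):

import torch
from tensordict import set_list_to_stack
from transformers import AutoTokenizer  # assumption: transformers is installed

from torchrl.data import History

# Illustrative model choice; any tokenizer with a chat template should work.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

with set_list_to_stack(True):
    chat = torch.stack(
        [
            History(role="system", content="You are helpful."),
            History(role="user", content="Hi!"),
        ]
    )
    # Render the conversation as a prompt string ending with the assistant header...
    prompt = chat.apply_chat_template(tokenizer=tokenizer, add_generation_prompt=True)
    # ...or tokenize it directly into a tensor.
    tokens = chat.apply_chat_template(
        tokenizer=tokenizer, tokenize=True, add_generation_prompt=False
    )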

test/test_cost.py

Lines changed: 3 additions & 18 deletions
@@ -7733,24 +7733,9 @@ def test_cql_deactivate_vmap(

 @pytest.mark.parametrize("delay_actor", (True,))
 @pytest.mark.parametrize("delay_qvalue", (True,))
-@pytest.mark.parametrize(
-    "max_q_backup",
-    [
-        True,
-    ],
-)
-@pytest.mark.parametrize(
-    "deterministic_backup",
-    [
-        True,
-    ],
-)
-@pytest.mark.parametrize(
-    "with_lagrange",
-    [
-        True,
-    ],
-)
+@pytest.mark.parametrize("max_q_backup", [True])
+@pytest.mark.parametrize("deterministic_backup", [True])
+@pytest.mark.parametrize("with_lagrange", [True])
 @pytest.mark.parametrize("device", get_available_devices())
 @pytest.mark.parametrize("td_est", [None])
 def test_cql_qvalfromlist(
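
This hunk is a pure formatting cleanup: each single-element parametrization is collapsed onto one line, so the matrix of generated tests is unchanged. A minimal sketch of the equivalence:

import pytest

# Both decorators below generate exactly one test variant with max_q_backup=True.
@pytest.mark.parametrize("max_q_backup", [True])
def test_compact(max_q_backup):
    assert max_q_backup

@pytest.mark.parametrize(
    "max_q_backup",
    [
        True,
    ],
)
def test_expanded(max_q_backup):
    assert max_q_backup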

torchrl/data/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
     ConstantKLController,
     create_infinite_iterator,
     get_dataloader,
+    History,
     PairwiseDataset,
     PromptData,
     PromptTensorDictTokenizer,
@@ -121,6 +122,7 @@
     "DiscreteTensorSpec",
     "Flat2TED",
     "FlatStorageCheckpointer",
+    "History",
     "H5Combine",
     "H5Split",
     "H5StorageCheckpointer",

torchrl/data/llm/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+from .chat import History
 from .dataset import (
     create_infinite_iterator,
     get_dataloader,
@@ -15,6 +16,7 @@

 __all__ = [
     "AdaptiveKLController",
+    "History",
     "ConstantKLController",
     "PairwiseDataset",
     "PromptData",
