Skip to content

Commit fd10fe2

Browse files
author
Vincent Moens
committed
[Feature] History API
ghstack-source-id: 5b9723f Pull Request resolved: #2890
1 parent 9bc85f4 commit fd10fe2

File tree

6 files changed

+419
-4
lines changed

6 files changed

+419
-4
lines changed

docs/source/reference/data.rst

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,10 +1107,10 @@ and the tree can be expanded for each of these. The following figure shows how t
11071107
Tree
11081108

11091109

1110-
Reinforcement Learning From Human Feedback (RLHF)
1111-
-------------------------------------------------
1110+
Large language models and Reinforcement Learning From Human Feedback (RLHF)
1111+
---------------------------------------------------------------------------
11121112

1113-
Data is of utmost importance in Reinforcement Learning from Human Feedback (RLHF).
1113+
Data is of utmost importance in LLM post-training (e.g., GRPO or Reinforcement Learning from Human Feedback (RLHF)).
11141114
Given that these techniques are commonly employed in the realm of language,
11151115
which is scarcely addressed in other subdomains of RL within the library,
11161116
we offer specific utilities to facilitate interaction with external libraries
@@ -1124,6 +1124,7 @@ efficient sampling.
11241124
:toctree: generated/
11251125
:template: rl_template.rst
11261126

1127+
History
11271128
PairwiseDataset
11281129
PromptData
11291130
PromptTensorDictTokenizer

test/test_rb.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
is_tensor_collection,
4343
is_tensorclass,
4444
LazyStackedTensorDict,
45+
set_list_to_stack,
4546
tensorclass,
4647
TensorDict,
4748
TensorDictBase,
@@ -54,6 +55,7 @@
5455
from torchrl.collectors.utils import split_trajectories
5556
from torchrl.data import (
5657
FlatStorageCheckpointer,
58+
History,
5759
MultiStep,
5860
NestedStorageCheckpointer,
5961
PrioritizedReplayBuffer,
@@ -127,6 +129,7 @@
127129
_has_gym = importlib.util.find_spec("gym") is not None
128130
_has_snapshot = importlib.util.find_spec("torchsnapshot") is not None
129131
_os_is_windows = sys.platform == "win32"
132+
_has_transformers = importlib.util.find_spec("transformers") is not None
130133
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
131134

132135
torch_2_3 = version.parse(
@@ -3916,6 +3919,185 @@ def test_multi_env(self, storage_type, checkpointer, tmpdir, frames_per_batch):
39163919
assert rb._writer._cursor == rb_test._writer._cursor
39173920

39183921

3922+
class TestHistory:
3923+
@pytest.fixture(scope="class", autouse=True)
3924+
def set_context(self):
3925+
with set_list_to_stack(True):
3926+
yield
3927+
3928+
def test_history_construct(self):
3929+
hst0 = History(role="user", content="a message")
3930+
assert not hst0.shape
3931+
hst1 = History(role="user", content="another message")
3932+
with pytest.raises(RuntimeError, match="unsqueeze"):
3933+
hst0.append(hst1)
3934+
hst0 = hst0.unsqueeze(0)
3935+
3936+
# In an env.step, we typically have one more piece of history to add to the stack
3937+
assert not hst1.shape
3938+
assert not hst1.batch_size
3939+
assert not hst1.batch_dims
3940+
# test out-place
3941+
hst0_copy = hst0.copy()
3942+
hst0b = hst0.append(hst1, inplace=False)
3943+
assert hst0b is not hst0
3944+
assert (hst0 == hst0_copy).all()
3945+
assert (hst0b[:-1] == hst0).all()
3946+
3947+
# test in-place
3948+
hst0b = hst0.append(hst1)
3949+
assert hst0b is hst0
3950+
assert hst0b.shape == (2,)
3951+
3952+
assert hst0b.content == ["a message", "another message"]
3953+
hst2 = History(
3954+
role=["assistant", "user"],
3955+
content=["i'm the assistant", "i'm the user"],
3956+
batch_size=2,
3957+
)
3958+
assert hst2[0].role == "assistant"
3959+
assert hst2[0].content == "i'm the assistant"
3960+
assert hst2[1].role == "user"
3961+
assert hst2[1].content == "i'm the user"
3962+
with pytest.raises(RuntimeError, match="The new history to extend"):
3963+
hst0.extend(hst1)
3964+
3965+
# test out-place
3966+
hst0_copy = hst0.copy()
3967+
hst0b = hst0.extend(hst2, inplace=False)
3968+
assert hst0b is not hst0
3969+
assert (hst0 == hst0_copy).all()
3970+
assert (hst0b[:-2] == hst0).all()
3971+
3972+
# test in-place
3973+
hst0b = hst0.extend(hst2)
3974+
3975+
assert hst0b is hst0
3976+
assert hst0.__dict__["_tensordict"].shape == (4,)
3977+
assert hst0.shape == (4,)
3978+
assert hst0.role == ["user", "user", "assistant", "user"]
3979+
assert hst0.content == [
3980+
"a message",
3981+
"another message",
3982+
"i'm the assistant",
3983+
"i'm the user",
3984+
]
3985+
3986+
def test_history_construct_ndim(self):
3987+
hst0 = History(role="user", content="a message").unsqueeze(0).unsqueeze(0)
3988+
hst1 = History(role="user", content="another message").unsqueeze(0)
3989+
3990+
# test out-place
3991+
hst0_copy = hst0.copy()
3992+
hst0b = hst0.append(hst1, inplace=False, dim=1)
3993+
assert hst0b is not hst0
3994+
assert (hst0 == hst0_copy).all()
3995+
assert (hst0b[:, :-1] == hst0).all()
3996+
3997+
# test in-place
3998+
hst0b = hst0.append(hst1, dim=1)
3999+
assert hst0b is hst0
4000+
assert hst0b.shape == (
4001+
1,
4002+
2,
4003+
)
4004+
4005+
assert hst0b.content == [["a message", "another message"]]
4006+
hst2 = History(
4007+
role=["assistant", "user"],
4008+
content=["i'm the assistant", "i'm the user"],
4009+
batch_size=2,
4010+
).unsqueeze(0)
4011+
4012+
# test out-place
4013+
hst0_copy = hst0.copy()
4014+
hst0b = hst0.extend(hst2, inplace=False, dim=1)
4015+
assert hst0b is not hst0
4016+
assert (hst0 == hst0_copy).all()
4017+
assert (hst0b[:, :-2] == hst0).all()
4018+
4019+
# test in-place
4020+
hst0b = hst0.extend(hst2, dim=1)
4021+
4022+
assert hst0b is hst0
4023+
assert hst0.__dict__["_tensordict"].shape == (
4024+
1,
4025+
4,
4026+
)
4027+
assert hst0.shape == (
4028+
1,
4029+
4,
4030+
)
4031+
assert hst0.role == [["user", "user", "assistant", "user"]]
4032+
assert hst0.content == [
4033+
[
4034+
"a message",
4035+
"another message",
4036+
"i'm the assistant",
4037+
"i'm the user",
4038+
]
4039+
]
4040+
4041+
@pytest.fixture(scope="class")
4042+
def mock_history(self):
4043+
history0 = History(
4044+
role="system",
4045+
content="""CONTENT
4046+
This is the setup""",
4047+
)
4048+
history1 = History(
4049+
role="user",
4050+
content="""CONTENT
4051+
This is the first user prompt""",
4052+
)
4053+
history2 = History(
4054+
role="assistant",
4055+
content="""CONTENT
4056+
This is the second prompt, the first for the assistant.""",
4057+
)
4058+
history = torch.stack([history0, history1, history2])
4059+
return history
4060+
4061+
@pytest.fixture(scope="class")
4062+
def tokenizer(self):
4063+
from transformers import AutoTokenizer
4064+
4065+
tokenizer = AutoTokenizer.from_pretrained("GPT2")
4066+
yield tokenizer
4067+
4068+
@pytest.mark.skipif(not _has_transformers, reason="requires transformers library")
4069+
def test_history_template(self, mock_history, tokenizer):
4070+
history = mock_history
4071+
data_str = history.apply_chat_template(
4072+
tokenizer=tokenizer, add_generation_prompt=False
4073+
)
4074+
assert isinstance(data_str, str)
4075+
data_token = history.apply_chat_template(
4076+
tokenizer=tokenizer, tokenize=True, add_generation_prompt=False
4077+
)
4078+
assert isinstance(data_token, torch.Tensor)
4079+
4080+
# test add_generation_prompt
4081+
data_str = history.apply_chat_template(
4082+
tokenizer=tokenizer, add_generation_prompt=True
4083+
)
4084+
assert isinstance(data_str, str)
4085+
assert data_str.endswith("<|im_start|>assistant\n"), data_str
4086+
4087+
@pytest.mark.skipif(not _has_transformers, reason="requires transformers library")
4088+
def test_history_template_recover(self, mock_history, tokenizer):
4089+
history = mock_history
4090+
data_str = history.apply_chat_template(tokenizer=tokenizer)
4091+
# Test inverse
4092+
recovered = history._inv_chatml(data_str)
4093+
assert recovered.role == history.role
4094+
assert recovered.content == history.content
4095+
data_token = history.apply_chat_template(
4096+
tokenizer=tokenizer, tokenize=True, add_generation_prompt=False
4097+
)
4098+
recovered = history._inv_chatml(tokenizer.batch_decode(data_token)[0])
4099+
4100+
39194101
if __name__ == "__main__":
39204102
args, unknown = argparse.ArgumentParser().parse_known_args()
39214103
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/data/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
ConstantKLController,
99
create_infinite_iterator,
1010
get_dataloader,
11+
History,
1112
LLMData,
1213
LLMInput,
1314
LLMOutput,
@@ -108,6 +109,7 @@
108109

109110
__all__ = [
110111
"AdaptiveKLController",
112+
"History",
111113
"Binary",
112114
"BinaryDiscreteTensorSpec",
113115
"BinaryToDecimal",

torchrl/data/llm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from .chat import History
67
from .dataset import (
78
create_infinite_iterator,
89
get_dataloader,
@@ -35,4 +36,5 @@
3536
"TokenizedDatasetLoader",
3637
"create_infinite_iterator",
3738
"get_dataloader",
39+
"History",
3840
]

0 commit comments

Comments
 (0)