|
42 | 42 | is_tensor_collection,
|
43 | 43 | is_tensorclass,
|
44 | 44 | LazyStackedTensorDict,
|
| 45 | + set_list_to_stack, |
45 | 46 | tensorclass,
|
46 | 47 | TensorDict,
|
47 | 48 | TensorDictBase,
|
|
54 | 55 | from torchrl.collectors.utils import split_trajectories
|
55 | 56 | from torchrl.data import (
|
56 | 57 | FlatStorageCheckpointer,
|
| 58 | + History, |
57 | 59 | MultiStep,
|
58 | 60 | NestedStorageCheckpointer,
|
59 | 61 | PrioritizedReplayBuffer,
|
|
# Feature-availability flags: probe for optional dependencies with
# importlib.util.find_spec (cheap — does not import the package).
# _has_transformers gates the chat-template tests in TestHistory below.
127 | 129 | _has_gym = importlib.util.find_spec("gym") is not None
|
128 | 130 | _has_snapshot = importlib.util.find_spec("torchsnapshot") is not None
|
129 | 131 | _os_is_windows = sys.platform == "win32"
|
| 132 | +_has_transformers = importlib.util.find_spec("transformers") is not None |
130 | 133 | TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
|
131 | 134 |
|
132 | 135 | torch_2_3 = version.parse(
|
@@ -3916,6 +3919,185 @@ def test_multi_env(self, storage_type, checkpointer, tmpdir, frames_per_batch):
|
3916 | 3919 | assert rb._writer._cursor == rb_test._writer._cursor
|
3917 | 3920 |
|
3918 | 3921 |
|
class TestHistory:
    """Tests for the ``History`` tensorclass.

    Covers scalar and batched construction, ``append``/``extend`` in both
    out-of-place and in-place modes (including along a non-leading dim for
    multi-dimensional histories), and round-tripping a conversation through
    a transformers chat template (``apply_chat_template`` / ``_inv_chatml``).
    """

    @pytest.fixture(scope="class", autouse=True)
    def set_context(self):
        # History batches are built by stacking python lists into tensorclass
        # fields; enable list-to-stack for every test in this class.
        with set_list_to_stack(True):
            yield

    def test_history_construct(self):
        hst0 = History(role="user", content="a message")
        assert not hst0.shape
        hst1 = History(role="user", content="another message")
        # Appending to a dim-0 history is invalid: it must be unsqueezed first.
        with pytest.raises(RuntimeError, match="unsqueeze"):
            hst0.append(hst1)
        hst0 = hst0.unsqueeze(0)

        # In an env.step, we typically have one more piece of history to add to the stack
        assert not hst1.shape
        assert not hst1.batch_size
        assert not hst1.batch_dims
        # test out-place: original is untouched, result grows by one entry
        hst0_copy = hst0.copy()
        hst0b = hst0.append(hst1, inplace=False)
        assert hst0b is not hst0
        assert (hst0 == hst0_copy).all()
        assert (hst0b[:-1] == hst0).all()

        # test in-place: same object returned, shape grows
        hst0b = hst0.append(hst1)
        assert hst0b is hst0
        assert hst0b.shape == (2,)

        assert hst0b.content == ["a message", "another message"]
        hst2 = History(
            role=["assistant", "user"],
            content=["i'm the assistant", "i'm the user"],
            batch_size=2,
        )
        assert hst2[0].role == "assistant"
        assert hst2[0].content == "i'm the assistant"
        assert hst2[1].role == "user"
        assert hst2[1].content == "i'm the user"
        # Extending with a dim-0 (unbatched) history is invalid.
        with pytest.raises(RuntimeError, match="The new history to extend"):
            hst0.extend(hst1)

        # test out-place extend
        hst0_copy = hst0.copy()
        hst0b = hst0.extend(hst2, inplace=False)
        assert hst0b is not hst0
        assert (hst0 == hst0_copy).all()
        assert (hst0b[:-2] == hst0).all()

        # test in-place extend
        hst0b = hst0.extend(hst2)

        assert hst0b is hst0
        # The underlying tensordict must grow in lock-step with the tensorclass.
        assert hst0.__dict__["_tensordict"].shape == (4,)
        assert hst0.shape == (4,)
        assert hst0.role == ["user", "user", "assistant", "user"]
        assert hst0.content == [
            "a message",
            "another message",
            "i'm the assistant",
            "i'm the user",
        ]

    def test_history_construct_ndim(self):
        # Same append/extend semantics, but along dim=1 of a (1, N) history.
        hst0 = History(role="user", content="a message").unsqueeze(0).unsqueeze(0)
        hst1 = History(role="user", content="another message").unsqueeze(0)

        # test out-place
        hst0_copy = hst0.copy()
        hst0b = hst0.append(hst1, inplace=False, dim=1)
        assert hst0b is not hst0
        assert (hst0 == hst0_copy).all()
        assert (hst0b[:, :-1] == hst0).all()

        # test in-place
        hst0b = hst0.append(hst1, dim=1)
        assert hst0b is hst0
        assert hst0b.shape == (
            1,
            2,
        )

        assert hst0b.content == [["a message", "another message"]]
        hst2 = History(
            role=["assistant", "user"],
            content=["i'm the assistant", "i'm the user"],
            batch_size=2,
        ).unsqueeze(0)

        # test out-place
        hst0_copy = hst0.copy()
        hst0b = hst0.extend(hst2, inplace=False, dim=1)
        assert hst0b is not hst0
        assert (hst0 == hst0_copy).all()
        assert (hst0b[:, :-2] == hst0).all()

        # test in-place
        hst0b = hst0.extend(hst2, dim=1)

        assert hst0b is hst0
        assert hst0.__dict__["_tensordict"].shape == (
            1,
            4,
        )
        assert hst0.shape == (
            1,
            4,
        )
        assert hst0.role == [["user", "user", "assistant", "user"]]
        assert hst0.content == [
            [
                "a message",
                "another message",
                "i'm the assistant",
                "i'm the user",
            ]
        ]

    @pytest.fixture(scope="class")
    def mock_history(self):
        """A 3-turn system/user/assistant conversation stacked into one History."""
        history0 = History(
            role="system",
            content="""CONTENT
        This is the setup""",
        )
        history1 = History(
            role="user",
            content="""CONTENT
        This is the first user prompt""",
        )
        history2 = History(
            role="assistant",
            content="""CONTENT
        This is the second prompt, the first for the assistant.""",
        )
        history = torch.stack([history0, history1, history2])
        return history

    @pytest.fixture(scope="class")
    def tokenizer(self):
        # Lazily imported so collection works without transformers installed;
        # the tests using this fixture are gated on _has_transformers.
        from transformers import AutoTokenizer

        # NOTE(review): the canonical HF hub id is lowercase "gpt2" — confirm
        # "GPT2" resolves reliably (hub ids are case-sensitive in general).
        tokenizer = AutoTokenizer.from_pretrained("GPT2")
        yield tokenizer

    @pytest.mark.skipif(not _has_transformers, reason="requires transformers library")
    def test_history_template(self, mock_history, tokenizer):
        history = mock_history
        data_str = history.apply_chat_template(
            tokenizer=tokenizer, add_generation_prompt=False
        )
        assert isinstance(data_str, str)
        data_token = history.apply_chat_template(
            tokenizer=tokenizer, tokenize=True, add_generation_prompt=False
        )
        assert isinstance(data_token, torch.Tensor)

        # test add_generation_prompt: the prompt must end with an open
        # assistant turn for the model to complete.
        data_str = history.apply_chat_template(
            tokenizer=tokenizer, add_generation_prompt=True
        )
        assert isinstance(data_str, str)
        assert data_str.endswith("<|im_start|>assistant\n"), data_str

    @pytest.mark.skipif(not _has_transformers, reason="requires transformers library")
    def test_history_template_recover(self, mock_history, tokenizer):
        history = mock_history
        data_str = history.apply_chat_template(tokenizer=tokenizer)
        # Test inverse
        recovered = history._inv_chatml(data_str)
        assert recovered.role == history.role
        assert recovered.content == history.content
        data_token = history.apply_chat_template(
            tokenizer=tokenizer, tokenize=True, add_generation_prompt=False
        )
        recovered = history._inv_chatml(tokenizer.batch_decode(data_token)[0])
        # Fix: this round-trip result was previously computed but never
        # checked. GPT2's byte-level BPE is lossless, so decoding the
        # tokenized template and inverting it should recover the history.
        assert recovered.role == history.role
        assert recovered.content == history.content
| 4100 | + |
# Allow running this test module directly: parse known args, forward the
# rest (e.g. -k filters) to pytest, stop at the first failure.
3919 | 4101 | if __name__ == "__main__":
|
3920 | 4102 | args, unknown = argparse.ArgumentParser().parse_known_args()
|
3921 | 4103 | pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
|
0 commit comments