Commit 2195292

Author: Vincent Moens (committed)
Update (base update)
[ghstack-poisoned]
1 parent 413571b commit 2195292

File tree: 19 files changed (+1195, −361 lines)


docs/source/reference/data.rst

Lines changed: 6 additions & 5 deletions
@@ -1144,16 +1144,17 @@ Utils
     :toctree: generated/
     :template: rl_template.rst
 
-    MultiStep
-    consolidate_spec
-    check_no_exclusive_keys
-    contains_lazy_spec
-    Nested2TED
+    DensifyReward
     Flat2TED
     H5Combine
     H5Split
+    MultiStep
+    Nested2TED
     TED2Flat
     TED2Nested
+    check_no_exclusive_keys
+    consolidate_spec
+    contains_lazy_spec
 
 .. currentmodule:: torchrl.envs.transforms.rb_transforms
test/test_env.py

Lines changed: 265 additions & 24 deletions
@@ -4616,11 +4616,13 @@ def __next__(self):
     @pytest.mark.parametrize("batch_size", [0, 4])
     @pytest.mark.parametrize("device", [None, "cpu"])
     def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
-        env = LLMEnv(str2str=str2str, device=device)
+        env = LLMEnv(
+            str2str=str2str, device=device, has_attention=False, no_stack=False
+        )
         if str2str:
             primer = DataLoadingPrimer(
                 dataloader=self.DummyDataLoader(batch_size=batch_size),
-                data_keys=["observation"],
+                data_keys=[LLMEnv._DEFAULT_STR_KEY],
                 example_data="a string!",
             )
         else:
@@ -4630,7 +4632,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
                 dataloader=self.DummyTensorDataLoader(
                     batch_size=batch_size, padding=True
                 ),
-                data_keys=["observation"],
+                data_keys=[LLMEnv._DEFAULT_TOKEN_KEY],
                 data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)],
                 stack_method=stack_method,
             )
@@ -4640,7 +4642,7 @@ def test_llm_env(self, str2str, batched, stack_method, device, batch_size):
         if batched:
             td = env.reset(TensorDict(batch_size=[3]))
             env.check_env_specs(break_when_any_done="both", tensordict=td)
-            r = env.rollout(10, tensordict=TensorDict(batch_size=[3]))
+            env.rollout(10, tensordict=TensorDict(batch_size=[3]))
         else:
            env.check_env_specs(break_when_any_done="both")
@@ -4663,7 +4665,7 @@ def test_llm_from_dataloader(
46634665
if str2str:
46644666
kwargs = {
46654667
"dataloader": self.DummyDataLoader(batch_size=batch_size),
4666-
"data_keys": ["observation"],
4668+
"data_keys": [LLMEnv._DEFAULT_STR_KEY],
46674669
"example_data": "a string!",
46684670
}
46694671
else:
@@ -4673,11 +4675,18 @@ def test_llm_from_dataloader(
46734675
"dataloader": self.DummyTensorDataLoader(
46744676
padding=True, batch_size=batch_size
46754677
),
4676-
"data_keys": ["observation"],
4678+
"data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
46774679
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
46784680
"stack_method": stack_method,
46794681
}
4680-
kwargs.update({"str2str": str2str, "device": device})
4682+
kwargs.update(
4683+
{
4684+
"str2str": str2str,
4685+
"device": device,
4686+
"has_attention": False,
4687+
"no_stack": False,
4688+
}
4689+
)
46814690
env = LLMEnv.from_dataloader(**kwargs)
46824691
assert not env.batch_locked
46834692
if batched:
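In the tokenized branch above, the primer has to stack token sequences of different lengths, which is why the tests declare the spec as Unbounded(shape=(-1,)) and parametrize a stack_method such as "as_padded_tensor". The snippet below is a plain-PyTorch illustration of that constraint, not TorchRL code:

# Plain-PyTorch sketch (not from the commit): variable-length token sequences
# cannot be stacked directly; padding them to a common length is what an
# "as_padded_tensor"-style stack method does.
import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.arange(3), torch.arange(5)]      # two prompts of different lengths
# torch.stack(seqs) would fail because the shapes differ
padded = pad_sequence(seqs, batch_first=True)  # shape [2, 5], zero-padded
print(padded)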
@@ -4690,51 +4699,283 @@ def test_llm_from_dataloader(
         def policy(td):
             if str2str:
                 if not td.shape:
-                    td["action"] = "<nothing>"
+                    td[LLMEnv._DEFAULT_ACTION_KEY] = "<nothing>"
                 else:
-                    td["action"] = NonTensorStack(
+                    td[LLMEnv._DEFAULT_ACTION_KEY] = NonTensorStack(
                         *["<nothing>" for _ in range(td.shape[0])]
                     )
             else:
-                td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
+                td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
+                    td.shape + (1,), dtype=torch.int64
+                )
             return td
 
         if batched:
             # Tell the env that we want 3 sub-envs
             r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[3]))
             assert r.ndim == 2
             if str2str:
-                assert isinstance(r[0, 0]["observation"], str)
-                assert isinstance(r[0, 1]["observation"], str)
+                assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str)
+                assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str)
                 assert (
-                    r[0, 0]["observation"]
-                    == r[0, 1]["observation"][: -len(r[0, 0]["action"])]
+                    r[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r[0, 1][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_KEY])
+                    ]
                 )
                 assert (
-                    r[0, 1]["observation"]
-                    == r[0, 2]["observation"][: -len(r[0, 1]["action"])]
+                    r[0, 1][LLMEnv._DEFAULT_STR_KEY]
+                    == r[0, 2][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_KEY])
+                    ]
                 )
                 assert (
-                    r[-1, 0]["observation"]
-                    == r[-1, 1]["observation"][: -len(r[-1, 0]["action"])]
+                    r[-1, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_KEY])
+                    ]
                 )
                 assert (
-                    r[-1, 1]["observation"]
-                    == r[-1, 2]["observation"][: -len(r[-1, 1]["action"])]
+                    r[-1, 1][LLMEnv._DEFAULT_STR_KEY]
+                    == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][
+                        : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_KEY])
+                    ]
                 )
             else:
-                assert (r[0, 0]["observation"] == r[0, 1]["observation"][:-1]).all()
-                assert (r[0, 1]["observation"] == r[0, 2]["observation"][:-1]).all()
                 assert (
-                    r[-1, 0]["observation"] == r[-1, 1]["observation"][:-1]
+                    r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
+                ).all()
+                assert (
+                    r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[0, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
                 ).all()
                 assert (
-                    r[-1, 1]["observation"] == r[-1, 2]["observation"][:-1]
+                    r[-1, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
+                ).all()
+                assert (
+                    r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r[-1, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1]
                 ).all()
         else:
             r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
             assert r.ndim == 1
 
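The rewritten assertions above encode the autoregressive contract the test relies on: every step appends the action to the running prompt, so the observation at step t equals the observation at step t+1 with the last action stripped (string slicing in str2str mode, [:-1] on the token tensor otherwise). A pure-Python sketch of that prefix property, independent of TorchRL:

# Pure-Python sketch of the prefix property checked above.
action = "<nothing>"
history = ["a string!"]
for _ in range(3):
    history.append(history[-1] + action)  # each step appends the action

for prev, nxt in zip(history, history[1:]):
    assert prev == nxt[: -len(action)]    # previous text is a prefix of the next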

+    @pytest.mark.parametrize(
+        "str2str,stack_method",
+        [
+            [True, None],
+            [False, "as_padded_tensor"],
+            # TODO: a bit experimental, fails with check_env_specs
+            # [False, "as_nested_tensor"],
+            [False, None],
+        ],
+    )
+    @pytest.mark.parametrize("batched", [True, False])
+    @pytest.mark.parametrize("device", [None, "cpu"])
+    @pytest.mark.parametrize("batch_size", [0, 4])
+    @pytest.mark.parametrize("repeats", [3])
+    def test_llm_from_dataloader_repeats(
+        self, str2str, batched, stack_method, device, batch_size, repeats
+    ):
+        if str2str:
+            kwargs = {
+                "dataloader": self.DummyDataLoader(batch_size=batch_size),
+                "data_keys": [LLMEnv._DEFAULT_STR_KEY],
+                "example_data": "a string!",
+                "repeats": repeats,
+            }
+        else:
+            if stack_method is None:
+                stack_method = as_padded_tensor
+            kwargs = {
+                "dataloader": self.DummyTensorDataLoader(
+                    padding=True, batch_size=batch_size
+                ),
+                "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
+                "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
+                "stack_method": stack_method,
+                "repeats": repeats,
+            }
+        kwargs.update(
+            {
+                "str2str": str2str,
+                "device": device,
+                "has_attention": False,
+                "no_stack": False,
+            }
+        )
+        env = LLMEnv.from_dataloader(**kwargs)
+        assert env.transform.repeats == repeats
+
+        max_steps = 3
+        env.append_transform(StepCounter(max_steps=max_steps))
+
+        def policy(td):
+            if str2str:
+                if not td.shape:
+                    td[LLMEnv._DEFAULT_ACTION_KEY] = "<nothing>"
+                else:
+                    td[LLMEnv._DEFAULT_ACTION_KEY] = NonTensorStack(
+                        *["<nothing>" for _ in range(td.shape[0])]
+                    )
+            else:
+                td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
+                    td.shape + (1,), dtype=torch.int64
+                )
+            return td
+
+        if batched:
+            r = env.rollout(
+                100,
+                policy,
+                tensordict=TensorDict(batch_size=[3]),
+                break_when_any_done=False,
+            )
+        else:
+            r = env.rollout(100, policy, break_when_any_done=False)
+        # check that r at reset is always the same
+        r_reset = r[..., ::max_steps]
+        if not batched:
+            if str2str:
+                assert (
+                    r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[..., 1][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[..., 2][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
+                    != r_reset[..., 3][LLMEnv._DEFAULT_STR_KEY]
+                )
+            else:
+                assert (
+                    r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[..., 1][LLMEnv._DEFAULT_TOKEN_KEY]
+                ).all()
+                assert (
+                    r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[..., 2][LLMEnv._DEFAULT_TOKEN_KEY]
+                ).all()
+                assert (
+                    r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    != r_reset[..., 3][LLMEnv._DEFAULT_TOKEN_KEY]
+                ).any()
+        else:
+            # When batched, each block contains the 3 reset packs
+            if str2str:
+                assert (
+                    r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[1, 0][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    == r_reset[2, 0][LLMEnv._DEFAULT_STR_KEY]
+                )
+                assert (
+                    r_reset[0, 0][LLMEnv._DEFAULT_STR_KEY]
+                    != r_reset[0, 1][LLMEnv._DEFAULT_STR_KEY]
+                )
+            else:
+                assert (
+                    r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[1, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                ).all()
+                assert (
+                    r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    == r_reset[2, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                ).all()
+                assert (
+                    r_reset[0, 0][LLMEnv._DEFAULT_TOKEN_KEY]
+                    != r_reset[0, 1][LLMEnv._DEFAULT_TOKEN_KEY]
+                ).any()
+
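The new test_llm_from_dataloader_repeats above uses StepCounter(max_steps=3) to force a reset every three steps and then checks that, with repeats=3, the first three resets are fed the same dataloader sample while the fourth reset moves on to the next one. A pure-Python sketch of that scheduling, independent of TorchRL:

# Pure-Python sketch of the repeat scheduling verified above: with repeats=3,
# each dataloader sample backs three consecutive resets.
samples = ["s0", "s1", "s2"]
repeats = 3
served_at_reset = [s for s in samples for _ in range(repeats)]
assert served_at_reset[0] == served_at_reset[1] == served_at_reset[2]
assert served_at_reset[0] != served_at_reset[3]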
+    @pytest.mark.parametrize(
+        "str2str,stack_method",
+        [
+            [True, None],
+            [False, "as_padded_tensor"],
+        ],
+    )
+    @pytest.mark.parametrize("batched", [True])
+    @pytest.mark.parametrize("device", [None])
+    @pytest.mark.parametrize("batch_size", [4])
+    @pytest.mark.parametrize("repeats", [3])
+    @pytest.mark.parametrize(
+        "assign_reward,assign_done", [[True, False], [True, True], [False, True]]
+    )
+    def test_done_and_reward(
+        self,
+        str2str,
+        batched,
+        stack_method,
+        device,
+        batch_size,
+        repeats,
+        assign_reward,
+        assign_done,
+    ):
+        with pytest.raises(
+            ValueError, match="str2str"
+        ) if str2str else contextlib.nullcontext():
+            if str2str:
+                kwargs = {
+                    "dataloader": self.DummyDataLoader(batch_size=batch_size),
+                    "data_keys": [LLMEnv._DEFAULT_STR_KEY],
+                    "example_data": "a string!",
+                    "repeats": repeats,
+                    "assign_reward": assign_reward,
+                    "assign_done": assign_done,
+                }
+            else:
+                if stack_method is None:
+                    stack_method = as_padded_tensor
+                kwargs = {
+                    "dataloader": self.DummyTensorDataLoader(
+                        padding=True, batch_size=batch_size
+                    ),
+                    "data_keys": [LLMEnv._DEFAULT_TOKEN_KEY],
+                    "data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
+                    "stack_method": stack_method,
+                    "repeats": repeats,
+                    "assign_reward": assign_reward,
+                    "assign_done": assign_done,
+                }
+            kwargs.update(
+                {
+                    "str2str": str2str,
+                    "device": device,
+                    "has_attention": False,
+                    "no_stack": False,
+                }
+            )
+            env = LLMEnv.from_dataloader(**kwargs)
+            # We want to make sure that transforms that rely on the done state work appropriately
+            env.append_transform(StepCounter(max_steps=10))
+
+            def policy(td):
+                td[LLMEnv._DEFAULT_ACTION_KEY] = torch.ones(
+                    td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
+                )
+                return td
+
+            if batched:
+                r = env.rollout(
+                    100,
+                    policy,
+                    tensordict=TensorDict(batch_size=[3]),
+                    break_when_any_done=False,
+                )
+            else:
+                r = env.rollout(100, policy, break_when_any_done=False)
+            if assign_done:
+                assert "terminated" in r
+                assert "done" in r
+            print(r)
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
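test_done_and_reward wraps its whole body in pytest.raises(ValueError, match="str2str") when str2str=True and in contextlib.nullcontext() otherwise: combining str2str with assign_reward/assign_done is expected to raise, and on the tokenized path the test then checks that "terminated" and "done" entries appear in the rollout. The conditional-context idiom in isolation (stand-alone sketch, no TorchRL involved):

# Sketch of the conditional-expectation idiom used in test_done_and_reward.
import contextlib

import pytest


def run(str2str: bool) -> None:
    with pytest.raises(ValueError, match="str2str") if str2str else contextlib.nullcontext():
        if str2str:
            # stand-in for the error the env is expected to raise for this combination
            raise ValueError("str2str not supported with assign_reward/assign_done")


run(True)   # the ValueError is expected and absorbed by pytest.raises
run(False)  # nullcontext: the body just runs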
