Skip to content

Commit 9dad72b

Browse files
author
Vincent Moens
committed
Update (base update)
[ghstack-poisoned]
1 parent f4713f9 commit 9dad72b

File tree

15 files changed

+749
-279
lines changed

15 files changed

+749
-279
lines changed

docs/source/reference/data.rst

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,16 +1144,17 @@ Utils
11441144
:toctree: generated/
11451145
:template: rl_template.rst
11461146

1147-
MultiStep
1148-
consolidate_spec
1149-
check_no_exclusive_keys
1150-
contains_lazy_spec
1151-
Nested2TED
1147+
DensifyReward
11521148
Flat2TED
11531149
H5Combine
11541150
H5Split
1151+
MultiStep
1152+
Nested2TED
11551153
TED2Flat
11561154
TED2Nested
1155+
check_no_exclusive_keys
1156+
consolidate_spec
1157+
contains_lazy_spec
11571158

11581159
.. currentmodule:: torchrl.envs.transforms.rb_transforms
11591160

test/test_env.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4700,6 +4700,180 @@ def policy(td):
47004700
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
47014701
assert r.ndim == 1
47024702

4703+
@pytest.mark.parametrize(
4704+
"str2str,stack_method",
4705+
[
4706+
[True, None],
4707+
[False, "as_padded_tensor"],
4708+
# TODO: a bit experimental, fails with check_env_specs
4709+
# [False, "as_nested_tensor"],
4710+
[False, None],
4711+
],
4712+
)
4713+
@pytest.mark.parametrize("batched", [True, False])
4714+
@pytest.mark.parametrize("device", [None, "cpu"])
4715+
@pytest.mark.parametrize("batch_size", [0, 4])
4716+
@pytest.mark.parametrize("repeats", [3])
4717+
def test_llm_from_dataloader_repeats(
4718+
self, str2str, batched, stack_method, device, batch_size, repeats
4719+
):
4720+
if str2str:
4721+
kwargs = {
4722+
"dataloader": self.DummyDataLoader(batch_size=batch_size),
4723+
"data_keys": ["observation"],
4724+
"example_data": "a string!",
4725+
"repeats": repeats,
4726+
}
4727+
else:
4728+
if stack_method is None:
4729+
stack_method = as_padded_tensor
4730+
kwargs = {
4731+
"dataloader": self.DummyTensorDataLoader(
4732+
padding=True, batch_size=batch_size
4733+
),
4734+
"data_keys": ["observation"],
4735+
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
4736+
"stack_method": stack_method,
4737+
"repeats": repeats,
4738+
}
4739+
kwargs.update({"str2str": str2str, "device": device})
4740+
env = LLMEnv.from_dataloader(**kwargs)
4741+
assert env.transform.repeats == repeats
4742+
4743+
max_steps = 3
4744+
env.append_transform(StepCounter(max_steps=max_steps))
4745+
4746+
def policy(td):
4747+
if str2str:
4748+
if not td.shape:
4749+
td["action"] = "<nothing>"
4750+
else:
4751+
td["action"] = NonTensorStack(
4752+
*["<nothing>" for _ in range(td.shape[0])]
4753+
)
4754+
else:
4755+
td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
4756+
return td
4757+
4758+
if batched:
4759+
r = env.rollout(
4760+
100,
4761+
policy,
4762+
tensordict=TensorDict(batch_size=[3]),
4763+
break_when_any_done=False,
4764+
)
4765+
else:
4766+
r = env.rollout(100, policy, break_when_any_done=False)
4767+
# check that r at reset is always the same
4768+
r_reset = r[..., ::max_steps]
4769+
if not batched:
4770+
if str2str:
4771+
assert r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
4772+
assert r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
4773+
assert r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
4774+
else:
4775+
assert (
4776+
r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
4777+
).all()
4778+
assert (
4779+
r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
4780+
).all()
4781+
assert (
4782+
r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
4783+
).any()
4784+
else:
4785+
# When batched, each block contains the 3 reset packs
4786+
if str2str:
4787+
assert r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
4788+
assert r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
4789+
assert r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
4790+
else:
4791+
assert (
4792+
r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
4793+
).all()
4794+
assert (
4795+
r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
4796+
).all()
4797+
assert (
4798+
r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
4799+
).any()
4800+
4801+
@pytest.mark.parametrize(
4802+
"str2str,stack_method",
4803+
[
4804+
[True, None],
4805+
[False, "as_padded_tensor"],
4806+
],
4807+
)
4808+
@pytest.mark.parametrize("batched", [True])
4809+
@pytest.mark.parametrize("device", [None])
4810+
@pytest.mark.parametrize("batch_size", [4])
4811+
@pytest.mark.parametrize("repeats", [3])
4812+
@pytest.mark.parametrize(
4813+
"assign_reward,assign_done", [[True, False], [True, True], [False, True]]
4814+
)
4815+
def test_done_and_reward(
4816+
self,
4817+
str2str,
4818+
batched,
4819+
stack_method,
4820+
device,
4821+
batch_size,
4822+
repeats,
4823+
assign_reward,
4824+
assign_done,
4825+
):
4826+
with pytest.raises(
4827+
ValueError, match="str2str"
4828+
) if str2str else contextlib.nullcontext():
4829+
if str2str:
4830+
kwargs = {
4831+
"dataloader": self.DummyDataLoader(batch_size=batch_size),
4832+
"data_keys": ["observation"],
4833+
"example_data": "a string!",
4834+
"repeats": repeats,
4835+
"assign_reward": assign_reward,
4836+
"assign_done": assign_done,
4837+
}
4838+
else:
4839+
if stack_method is None:
4840+
stack_method = as_padded_tensor
4841+
kwargs = {
4842+
"dataloader": self.DummyTensorDataLoader(
4843+
padding=True, batch_size=batch_size
4844+
),
4845+
"data_keys": ["observation"],
4846+
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
4847+
"stack_method": stack_method,
4848+
"repeats": repeats,
4849+
"assign_reward": assign_reward,
4850+
"assign_done": assign_done,
4851+
}
4852+
kwargs.update({"str2str": str2str, "device": device})
4853+
env = LLMEnv.from_dataloader(**kwargs)
4854+
# We want to make sure that transforms that rely on the done state work appropriately
4855+
env.append_transform(StepCounter(max_steps=10))
4856+
4857+
def policy(td):
4858+
td["action"] = torch.ones(
4859+
td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64
4860+
)
4861+
return td
4862+
4863+
if batched:
4864+
r = env.rollout(
4865+
100,
4866+
policy,
4867+
tensordict=TensorDict(batch_size=[3]),
4868+
break_when_any_done=False,
4869+
)
4870+
else:
4871+
r = env.rollout(100, policy, break_when_any_done=False)
4872+
if assign_done:
4873+
assert "terminated" in r
4874+
assert "done" in r
4875+
print(r)
4876+
47034877

47044878
if __name__ == "__main__":
47054879
args, unknown = argparse.ArgumentParser().parse_known_args()

0 commit comments

Comments
 (0)