Skip to content

Commit 904d639

Browse files
author
Vincent Moens
committed
[Feature] DataLoadingPrimer.repeat
ghstack-source-id: fb2e2b6 Pull Request resolved: #2822
1 parent 5ec9bc5 commit 904d639

File tree

4 files changed

+129
-9
lines changed

4 files changed

+129
-9
lines changed

test/test_env.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4763,6 +4763,104 @@ def policy(td):
47634763
r = env.rollout(10, policy, tensordict=TensorDict(batch_size=[]))
47644764
assert r.ndim == 1
47654765

4766+
@pytest.mark.parametrize(
4767+
"str2str,stack_method",
4768+
[
4769+
[True, None],
4770+
[False, "as_padded_tensor"],
4771+
# TODO: a bit experimental, fails with check_env_specs
4772+
# [False, "as_nested_tensor"],
4773+
[False, None],
4774+
],
4775+
)
4776+
@pytest.mark.parametrize("batched", [True, False])
4777+
@pytest.mark.parametrize("device", [None, "cpu"])
4778+
@pytest.mark.parametrize("batch_size", [0, 4])
4779+
@pytest.mark.parametrize("repeats", [3])
4780+
def test_llm_from_dataloader_repeats(
4781+
self, str2str, batched, stack_method, device, batch_size, repeats
4782+
):
4783+
if str2str:
4784+
kwargs = {
4785+
"dataloader": self.DummyDataLoader(batch_size=batch_size),
4786+
"data_keys": ["observation"],
4787+
"example_data": "a string!",
4788+
"repeats": repeats,
4789+
}
4790+
else:
4791+
if stack_method is None:
4792+
stack_method = as_padded_tensor
4793+
kwargs = {
4794+
"dataloader": self.DummyTensorDataLoader(
4795+
padding=True, batch_size=batch_size
4796+
),
4797+
"data_keys": ["observation"],
4798+
"data_specs": [Unbounded(shape=(-1,), dtype=torch.int64)],
4799+
"stack_method": stack_method,
4800+
"repeats": repeats,
4801+
}
4802+
kwargs.update({"str2str": str2str, "device": device})
4803+
env = LLMEnv.from_dataloader(**kwargs)
4804+
assert env.transform.repeats == repeats
4805+
4806+
max_steps = 3
4807+
env.append_transform(StepCounter(max_steps=max_steps))
4808+
4809+
def policy(td):
4810+
if str2str:
4811+
if not td.shape:
4812+
td["action"] = "<nothing>"
4813+
else:
4814+
td["action"] = NonTensorStack(
4815+
*["<nothing>" for _ in range(td.shape[0])]
4816+
)
4817+
else:
4818+
td["action"] = torch.ones(td.shape + (1,), dtype=torch.int64)
4819+
return td
4820+
4821+
if batched:
4822+
r = env.rollout(
4823+
100,
4824+
policy,
4825+
tensordict=TensorDict(batch_size=[3]),
4826+
break_when_any_done=False,
4827+
)
4828+
else:
4829+
r = env.rollout(100, policy, break_when_any_done=False)
4830+
# check that r at reset is always the same
4831+
r_reset = r[..., ::max_steps]
4832+
if not batched:
4833+
if str2str:
4834+
assert r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
4835+
assert r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
4836+
assert r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
4837+
else:
4838+
assert (
4839+
r_reset[..., 0]["observation"] == r_reset[..., 1]["observation"]
4840+
).all()
4841+
assert (
4842+
r_reset[..., 0]["observation"] == r_reset[..., 2]["observation"]
4843+
).all()
4844+
assert (
4845+
r_reset[..., 0]["observation"] != r_reset[..., 3]["observation"]
4846+
).any()
4847+
else:
4848+
# When batched, each block contains the 3 reset packs
4849+
if str2str:
4850+
assert r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
4851+
assert r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
4852+
assert r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
4853+
else:
4854+
assert (
4855+
r_reset[0, 0]["observation"] == r_reset[1, 0]["observation"]
4856+
).all()
4857+
assert (
4858+
r_reset[0, 0]["observation"] == r_reset[2, 0]["observation"]
4859+
).all()
4860+
assert (
4861+
r_reset[0, 0]["observation"] != r_reset[0, 1]["observation"]
4862+
).any()
4863+
47664864

47674865
if __name__ == "__main__":
47684866
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/envs/custom/llm.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def from_dataloader(
142142
example_data: Any = None,
143143
stack_method: Callable[[Any], Any]
144144
| Literal["as_nested_tensor", "as_padded_tensor"] = None,
145+
repeats: int | None = None,
145146
) -> LLMEnv:
146147
"""Creates an LLMEnv instance from a dataloader.
147148
@@ -165,6 +166,9 @@ def from_dataloader(
165166
example_data (Any, optional): Example data to use for initializing the primer. Defaults to ``None``.
166167
stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The
167168
method to use for stacking the data. Defaults to ``None``.
169+
repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in
170+
situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
171+
samples (rather than an advantage module).
168172
169173
Returns:
170174
LLMEnv: The created LLMEnv instance.
@@ -178,6 +182,7 @@ def from_dataloader(
178182
data_specs=data_specs,
179183
example_data=example_data,
180184
stack_method=stack_method,
185+
repeats=repeats,
181186
)
182187
env = LLMEnv(
183188
str2str=str2str,

torchrl/envs/transforms/llm.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ class DataLoadingPrimer(TensorDictPrimer):
103103
auto_batch_size (bool, optional): If ``True`` (default if `dataloader.batch_size > 0`), the batch size of the
104104
tensordict returned by the transform will be automatically determined assuming that there is a single batch
105105
dimension.
106+
repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in
107+
situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo
108+
samples (rather than an advantage module).
106109
107110
Attributes:
108111
dataloader (Iterable[Any]): The dataloader to load data from.
@@ -359,15 +362,21 @@ def __init__(
359362
| Literal["as_nested_tensor", "as_padded_tensor"] = None,
360363
use_buffer: bool | None = None,
361364
auto_batch_size: bool = True,
365+
repeats: int | None = None,
362366
):
363367
self.dataloader = dataloader
364-
if getattr(dataloader, "batch_size", 1) > 1 and use_buffer is None:
368+
if repeats is None:
369+
repeats = 0
370+
self.repeats = repeats
371+
if (
372+
getattr(dataloader, "batch_size", 1) > 1 and use_buffer is None
373+
) or repeats > 0:
365374
use_buffer = True
366375

367376
self.use_buffer = use_buffer
368377
# No auto_batch_size if we know we have a single element
369378
self.auto_batch_size = auto_batch_size and (
370-
getattr(dataloader, "dataloader", 1) > 0
379+
getattr(dataloader, "batch_size", 1) > 0
371380
)
372381
self.endless_dataloader = self._endless_iter(self.dataloader)
373382
if primers is None:
@@ -420,11 +429,13 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
420429
if not reset.any():
421430
raise RuntimeError("reset must have at least one True value.")
422431
if reset.ndim > 0:
423-
return self.stack_method(
424-
[self._load_from_dataloader() for i in range(reset.sum())]
425-
)
432+
loaded = [self._load_from_dataloader() for i in range(reset.sum())]
433+
return self.stack_method(loaded)
434+
426435
if self.use_buffer and len(self._queue) > 0:
427-
return self._queue.popleft()
436+
result = self._queue.popleft()
437+
return result
438+
428439
data = next(self.endless_dataloader)
429440
# Some heuristic here:
430441
# if data is a map, assume its keys match the keys in spec
@@ -450,7 +461,11 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
450461
f"Unrecognized data type: {type(data)} with keys {self.data_keys}."
451462
)
452463
if self.use_buffer:
453-
self._queue.extend(out.unbind(0))
464+
if not out.ndim:
465+
out = out.unsqueeze(0)
466+
self._queue.extend(
467+
[d for d in out.unbind(0) for _ in range(max(1, self.repeats))]
468+
)
454469
return self._queue.popleft()
455470
return out
456471

torchrl/envs/transforms/transforms.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7352,7 +7352,9 @@ def _reset(
73527352
else:
73537353
# It may be the case that reset did not provide a done state, in which case
73547354
# we fall back on the spec
7355-
done = self.parent.output_spec["full_done_spec", entry_name].zero()
7355+
done = self.parent.output_spec_unbatched[
7356+
"full_done_spec", entry_name
7357+
].zero(tensordict_reset.shape)
73567358
reset = torch.ones_like(done)
73577359

73587360
step_count = tensordict.get(step_count_key, default=None)
@@ -7362,7 +7364,7 @@ def _reset(
73627364
step_count = step_count.to(reset.device, non_blocking=True)
73637365

73647366
# zero the step count if reset is needed
7365-
step_count = torch.where(~expand_as_right(reset, step_count), step_count, 0)
7367+
step_count = torch.where(~reset, step_count.expand_as(reset), 0)
73667368
tensordict_reset.set(step_count_key, step_count)
73677369
if self.max_steps is not None:
73687370
truncated = step_count >= self.max_steps

0 commit comments

Comments
 (0)