Skip to content

Commit 93ba865

Browse files
author
Vincent Moens
committed
[Refactor] Fix repeats order
ghstack-source-id: 0bedd5c Pull Request resolved: #2887
1 parent 6f634c6 commit 93ba865

File tree

2 files changed

+12
-20
lines changed

2 files changed

+12
-20
lines changed

test/test_env.py

Lines changed: 10 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -4861,31 +4861,22 @@ def policy(td):
48614861
r_reset = r[..., ::max_steps]
48624862
if not batched:
48634863
if str2str:
4864+
all_strings = r_reset.view(-1)[LLMEnv._DEFAULT_STR_KEY]
4865+
assert sum(s == all_strings[0] for s in all_strings) == repeats
4866+
assert sum(s == all_strings[repeats] for s in all_strings) == repeats
48644867
assert (
4865-
r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
4866-
== r_reset[..., 1][LLMEnv._DEFAULT_STR_KEY]
4868+
sum(s == all_strings[repeats * 2] for s in all_strings) == repeats
48674869
)
4870+
else:
4871+
all_tokens = r_reset.view(-1)[LLMEnv._DEFAULT_TOKEN_KEY]
4872+
assert sum((s == all_tokens[0]).all() for s in all_tokens) == repeats
48684873
assert (
4869-
r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
4870-
== r_reset[..., 2][LLMEnv._DEFAULT_STR_KEY]
4874+
sum((s == all_tokens[repeats]).all() for s in all_tokens) == repeats
48714875
)
48724876
assert (
4873-
r_reset[..., 0][LLMEnv._DEFAULT_STR_KEY]
4874-
!= r_reset[..., 3][LLMEnv._DEFAULT_STR_KEY]
4877+
sum((s == all_tokens[repeats * 2]).all() for s in all_tokens)
4878+
== repeats
48754879
)
4876-
else:
4877-
assert (
4878-
r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
4879-
== r_reset[..., 1][LLMEnv._DEFAULT_TOKEN_KEY]
4880-
).all()
4881-
assert (
4882-
r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
4883-
== r_reset[..., 2][LLMEnv._DEFAULT_TOKEN_KEY]
4884-
).all()
4885-
assert (
4886-
r_reset[..., 0][LLMEnv._DEFAULT_TOKEN_KEY]
4887-
!= r_reset[..., 3][LLMEnv._DEFAULT_TOKEN_KEY]
4888-
).any()
48894880
else:
48904881
# When batched, each block contains the 3 reset packs
48914882
if str2str:

torchrl/envs/transforms/llm.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -562,7 +562,8 @@ def _load_from_dataloader(self, reset: torch.Tensor | None = None):
562562
if not out.ndim:
563563
out = out.unsqueeze(0)
564564
self._queue.extend(
565-
[d for _ in range(max(1, self.repeats)) for d in out.unbind(0)]
565+
[d for d in out.unbind(0) for _ in range(max(1, self.repeats))]
566+
# [d for _ in range(max(1, self.repeats)) for d in out.unbind(0)]
566567
)
567568
return self._queue.popleft()
568569
return out

0 commit comments

Comments
 (0)