Skip to content

Commit b36919e

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
2 parents a737d09 + 305f265 commit b36919e

File tree

8 files changed

+703
-306
lines changed

8 files changed

+703
-306
lines changed

test/test_collector.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,39 @@ def make_env():
690690
del env
691691

692692

693+
@pytest.mark.parametrize(
694+
"break_when_any_done,break_when_all_done",
695+
[[True, False], [False, True], [False, False]],
696+
)
697+
@pytest.mark.parametrize("n_envs", [1, 4])
698+
def test_collector_outplace_policy(n_envs, break_when_any_done, break_when_all_done):
699+
def policy_inplace(td):
700+
td.set("action", torch.ones(td.shape + (1,)))
701+
return td
702+
703+
def policy_outplace(td):
704+
return td.empty().set("action", torch.ones(td.shape + (1,)))
705+
706+
if n_envs == 1:
707+
env = CountingEnv(10)
708+
else:
709+
env = SerialEnv(
710+
n_envs,
711+
[functools.partial(CountingEnv, 10 + i) for i in range(n_envs)],
712+
)
713+
env.reset()
714+
c_inplace = SyncDataCollector(
715+
env, policy_inplace, frames_per_batch=10, total_frames=100
716+
)
717+
d_inplace = torch.cat(list(c_inplace), dim=0)
718+
env.reset()
719+
c_outplace = SyncDataCollector(
720+
env, policy_outplace, frames_per_batch=10, total_frames=100
721+
)
722+
d_outplace = torch.cat(list(c_outplace), dim=0)
723+
assert_allclose_td(d_inplace, d_outplace)
724+
725+
693726
# Deprecated reset_when_done
694727
# @pytest.mark.parametrize("num_env", [1, 2])
695728
# @pytest.mark.parametrize("env_name", ["vec"])

0 commit comments

Comments
 (0)