Skip to content

Commit c83c04c

Browse files
author
Vincent Moens
committed
[BugFix] Fix PEnv device copies
ghstack-source-id: df39fd2 Pull Request resolved: #2840 (cherry picked from commit 6e40548)
1 parent edc284f commit c83c04c

File tree

2 files changed

+47
-35
lines changed

2 files changed

+47
-35
lines changed

test/test_env.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1599,6 +1599,34 @@ def test_parallel_env_device(
15991599
env_serial.close()
16001600
env0.close()
16011601

1602+
@pytest.mark.skipif(not _has_gym, reason="no gym")
@pytest.mark.parametrize("env_device", [None, "cpu"])
def test_parallel_env_device_vs_no_device(self, maybe_fork_ParallelEnv, env_device):
    """Rollouts must be identical whether ParallelEnv's device is None or "cpu".

    Regression test for device-copy handling in ParallelEnv: with identical
    seeds, a rollout from a device-less ParallelEnv and one with an explicit
    "cpu" device should produce the same tensordict.
    """

    def make_env() -> GymEnv:
        # Worker env placed on `env_device` (None or "cpu"); DoubleToFloat
        # casts the float64 Gym observations down to float32.
        env = GymEnv(PENDULUM_VERSIONED(), device=env_device)
        return env.append_transform(DoubleToFloat())

    # Reference rollout: ParallelEnv with no explicit device.
    parallel_env = maybe_fork_ParallelEnv(
        num_workers=1, create_env_fn=make_env, device=None
    )
    parallel_env.reset()
    parallel_env.set_seed(0)
    torch.manual_seed(0)

    parallel_rollout = parallel_env.rollout(max_steps=10)

    # Same rollout with an explicit "cpu" device — the code path that used
    # to diverge before the device-copy fix.
    parallel_env = maybe_fork_ParallelEnv(
        num_workers=1, create_env_fn=make_env, device="cpu"
    )
    parallel_env.reset()
    parallel_env.set_seed(0)
    torch.manual_seed(0)

    parallel_rollout_cpu = parallel_env.rollout(max_steps=10)
    # With identical seeding, both rollouts must match element-wise.
    assert_allclose_td(parallel_rollout, parallel_rollout_cpu)
1629+
16021630
@pytest.mark.skipif(not _has_gym, reason="no gym")
16031631
@pytest.mark.flaky(reruns=3, reruns_delay=1)
16041632
@pytest.mark.parametrize(

torchrl/envs/batched_envs.py

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,14 @@ def __init__(
374374

375375
is_spec_locked = EnvBase.is_spec_locked
376376

377+
def select_and_clone(self, name, tensor, selected_keys=None):
    """Return a fresh copy of ``tensor`` when ``name`` is a selected key.

    When ``selected_keys`` is not provided, ``self._selected_step_keys`` is
    used.  If the env has a device and the tensor lives on a different one,
    the copy is produced by a (possibly non-blocking) device transfer;
    otherwise a plain clone is returned.  Non-selected names yield ``None``,
    which ``named_apply(..., filter_empty=True)`` drops from the output.
    """
    keys = self._selected_step_keys if selected_keys is None else selected_keys
    if name not in keys:
        # Implicit drop: named_apply filters out None leaves.
        return None
    target = self.device
    if target is not None and tensor.device != target:
        return tensor.to(target, non_blocking=self.non_blocking)
    return tensor.clone()
384+
377385
@property
378386
def non_blocking(self):
379387
nb = self._non_blocking
@@ -1062,12 +1070,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
10621070
selected_output_keys = self._selected_reset_keys_filt
10631071

10641072
# select + clone creates 2 tds, but we can create one only
1065-
def select_and_clone(name, tensor):
1066-
if name in selected_output_keys:
1067-
return tensor.clone()
1068-
10691073
out = self.shared_tensordict_parent.named_apply(
1070-
select_and_clone,
1074+
lambda *args: self.select_and_clone(
1075+
*args, selected_keys=selected_output_keys
1076+
),
10711077
nested_keys=True,
10721078
filter_empty=True,
10731079
)
@@ -1135,14 +1141,14 @@ def _step(
11351141
# will be modified in-place at further steps
11361142
device = self.device
11371143

1138-
def select_and_clone(name, tensor):
1139-
if name in self._selected_step_keys:
1140-
return tensor.clone()
1144+
selected_keys = self._selected_step_keys
11411145

11421146
if partial_steps is not None:
11431147
next_td = TensorDict.lazy_stack([next_td[i] for i in workers_range])
11441148
out = next_td.named_apply(
1145-
select_and_clone, nested_keys=True, filter_empty=True
1149+
lambda *args: self.select_and_clone(*args, selected_keys),
1150+
nested_keys=True,
1151+
filter_empty=True,
11461152
)
11471153
if out_tds is not None:
11481154
out.update(
@@ -1841,20 +1847,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
18411847
next_td = shared_tensordict_parent.get("next")
18421848
device = self.device
18431849

1844-
if next_td.device != device and device is not None:
1845-
1846-
def select_and_clone(name, tensor):
1847-
if name in self._selected_step_keys:
1848-
return tensor.to(device, non_blocking=self.non_blocking)
1849-
1850-
else:
1851-
1852-
def select_and_clone(name, tensor):
1853-
if name in self._selected_step_keys:
1854-
return tensor.clone()
1855-
18561850
out = next_td.named_apply(
1857-
select_and_clone,
1851+
self.select_and_clone,
18581852
nested_keys=True,
18591853
filter_empty=True,
18601854
device=device,
@@ -2005,20 +1999,10 @@ def tentative_update(val, other):
20051999
selected_output_keys = self._selected_reset_keys_filt
20062000
device = self.device
20072001

2008-
if self.shared_tensordict_parent.device != device and device is not None:
2009-
2010-
def select_and_clone(name, tensor):
2011-
if name in selected_output_keys:
2012-
return tensor.to(device, non_blocking=self.non_blocking)
2013-
2014-
else:
2015-
2016-
def select_and_clone(name, tensor):
2017-
if name in selected_output_keys:
2018-
return tensor.clone()
2019-
20202002
out = self.shared_tensordict_parent.named_apply(
2021-
select_and_clone,
2003+
lambda *args: self.select_and_clone(
2004+
*args, selected_keys=selected_output_keys
2005+
),
20222006
nested_keys=True,
20232007
filter_empty=True,
20242008
device=device,

0 commit comments

Comments
 (0)