Skip to content

Commit 382430d

Browse files
author
Vincent Moens
committed
[Quality] Better device checks
ghstack-source-id: 7174415 Pull Request resolved: #2909
1 parent 96c3003 commit 382430d

File tree

3 files changed

+80
-31
lines changed

3 files changed

+80
-31
lines changed

test/test_env.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,47 +1217,52 @@ def test_env_with_batch_size(
12171217
@pytest.mark.skipif(not _has_dmc, reason="no dm_control")
@pytest.mark.parametrize("env_task", ["stand,stand,stand", "stand,walk,stand"])
@pytest.mark.parametrize("share_individual_td", [True, False])
@pytest.mark.parametrize("device", get_default_devices())
def test_multi_task_serial_parallel(
    self, env_task, share_individual_td, maybe_fork_ParallelEnv, device
):
    """Check that SerialEnv and ParallelEnv produce identical rollouts.

    Builds three humanoid DMControl envs (single- or multi-task depending
    on ``env_task``), rolls the serial and the parallel batched env out for
    50 steps with the same seed, and asserts the two rollouts match.
    """
    tasks = env_task.split(",")
    if len(tasks) == 1:
        single_task = True

        def env_make():
            return DMControlEnv("humanoid", tasks[0], device=device)

    elif len(set(tasks)) == 1 and len(tasks) == 3:
        single_task = True
        env_make = [lambda: DMControlEnv("humanoid", tasks[0], device=device)] * 3
    else:
        single_task = False
        # Bind `task` as a default argument to avoid the late-binding
        # closure pitfall (all lambdas would otherwise see the last task).
        env_make = [
            lambda task=task: DMControlEnv("humanoid", task, device=device)
            for task in tasks
        ]

    env_serial = SerialEnv(3, env_make, share_individual_td=share_individual_td)
    try:
        env_serial.start()
        assert env_serial._single_task is single_task

        env_serial.set_seed(0)
        torch.manual_seed(0)
        td_serial = env_serial.rollout(max_steps=50)
    finally:
        env_serial.close(raise_if_closed=False)

    # Construct BEFORE entering the try block: if the constructor raises,
    # the finally clause would otherwise fail with a NameError on
    # `env_parallel` and mask the original exception. This also mirrors
    # how `env_serial` is handled above.
    env_parallel = maybe_fork_ParallelEnv(
        3, env_make, share_individual_td=share_individual_td
    )
    try:
        env_parallel.start()
        assert env_parallel._single_task is single_task

        env_parallel.set_seed(0)
        torch.manual_seed(0)
        td_parallel = env_parallel.rollout(max_steps=50)

        assert_allclose_td(td_serial, td_parallel)
    finally:
        env_parallel.close(raise_if_closed=False)

12621267
@pytest.mark.skipif(not _has_dmc, reason="no dm_control")
12631268
def test_multitask(self, maybe_fork_ParallelEnv):

test/test_libs.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7+
import collections
78
import functools
89
import gc
910
import importlib.util
@@ -1762,6 +1763,43 @@ def test_dmcontrol(self, env_name, task, frame_skip, from_pixels, pixels_only):
17621763
assert final_seed0 == final_seed2
17631764
assert_allclose_td(rollout0, rollout2)
17641765

1766+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
@pytest.mark.parametrize("env_name,task", [["cheetah", "run"]])
@pytest.mark.parametrize("frame_skip", [1, 3])
@pytest.mark.parametrize(
    "from_pixels,pixels_only", [[True, True], [True, False], [False, False]]
)
def test_dmcontrol_device_consistency(
    self, env_name, task, frame_skip, from_pixels, pixels_only
):
    """A CPU and a CUDA DMControlEnv must produce identical trajectories.

    The CPU env is rolled out first; its action sequence is then replayed
    step by step on the CUDA env, and the two rollouts are compared after
    moving the CUDA one back to CPU.
    """
    shared_kwargs = dict(
        frame_skip=frame_skip,
        from_pixels=from_pixels,
        pixels_only=pixels_only,
    )
    env_cpu = DMControlEnv(env_name, task, device="cpu", **shared_kwargs)
    env_cuda = DMControlEnv(env_name, task, device="cuda", **shared_kwargs)

    env_cpu.set_seed(0)
    rollout_cpu = env_cpu.rollout(100, break_when_any_done=False)
    assert rollout_cpu.device == torch.device("cpu")

    # Queue up the exact CPU actions so the CUDA env replays them in order.
    pending_actions = collections.deque(rollout_cpu["action"].unbind(0))

    def replay_policy(td):
        return td.set("action", pending_actions.popleft())

    env_cuda.set_seed(0)
    rollout_cuda = env_cuda.rollout(100, replay_policy, break_when_any_done=False)
    assert rollout_cuda.device == torch.device("cuda:0")
    assert_allclose_td(rollout_cpu, rollout_cuda.cpu())
17651803
@pytest.mark.parametrize("env_name,task", [["cheetah", "run"]])
17661804
@pytest.mark.parametrize("frame_skip", [1, 3])
17671805
@pytest.mark.parametrize(

torchrl/modules/tensordict_module/exploration.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -159,23 +159,29 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
159159
action_tensordict = tensordict
160160
action_key = self.action_key
161161

162-
out = action_tensordict.get(action_key)
162+
action = action_tensordict.get(action_key)
163163
eps = self.eps
164-
cond = torch.rand(action_tensordict.shape, device=out.device) < eps
164+
device = eps.device
165+
action_device = action.device
166+
if action_device is not None and action_device != device:
167+
raise RuntimeError(
168+
f"Expected action and e-greedy module to be on the same device, but got {action.device=} and e-greedy device={device}."
169+
)
170+
cond = torch.rand(action_tensordict.shape, device=device) < eps
165171
# cond = torch.zeros(action_tensordict.shape, device=out.device, dtype=torch.bool).bernoulli_(eps)
166-
cond = expand_as_right(cond, out)
172+
cond = expand_as_right(cond, action)
167173
spec = self.spec
168174
if spec is not None:
169175
if isinstance(spec, Composite):
170176
spec = spec[self.action_key]
171-
if spec.shape != out.shape:
177+
if spec.shape != action.shape:
172178
# In batched envs if the spec is passed unbatched, the rand() will not
173179
# cover all batched dims
174180
if (
175181
not len(spec.shape)
176-
or out.shape[-len(spec.shape) :] == spec.shape
182+
or action.shape[-len(spec.shape) :] == spec.shape
177183
):
178-
spec = spec.expand(out.shape)
184+
spec = spec.expand(action.shape)
179185
else:
180186
raise ValueError(
181187
"Action spec shape does not match the action shape"
@@ -188,12 +194,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
188194
)
189195
spec.update_mask(action_mask)
190196
r = spec.rand()
191-
if r.device != out.device:
192-
r = r.to(out.device)
193-
out = torch.where(cond, r, out)
197+
if r.device != device:
198+
r = r.to(device)
199+
action = torch.where(cond, r, action)
194200
else:
195201
raise RuntimeError("spec must be provided to the exploration wrapper.")
196-
action_tensordict.set(action_key, out)
202+
action_tensordict.set(action_key, action)
197203
return tensordict
198204

199205

0 commit comments

Comments (0)