Commit 280297a

Author: Vincent Moens (committed)

[Feature] no_cuda_sync arg in collectors

ghstack-source-id: 9baba31
Pull Request resolved: #2727

1 parent dda0df1 commit 280297a

File tree: 2 files changed, +129 -8 lines changed

test/test_collector.py

Lines changed: 81 additions & 0 deletions

@@ -5,11 +5,14 @@
 from __future__ import annotations
 
 import argparse
+import contextlib
 import functools
 import gc
 import os
 
 import sys
+from typing import Optional
+from unittest.mock import patch
 
 import numpy as np
 import pytest
@@ -469,6 +472,84 @@ def test_output_device(self, main_device, storing_device):
             break
         assert data.device == storing_device
 
+    class CudaPolicy(TensorDictSequential):
+        def __init__(self, n_obs):
+            module = torch.nn.Linear(n_obs, n_obs, device="cuda")
+            module.weight.data.copy_(torch.eye(n_obs))
+            module.bias.data.fill_(0)
+            m0 = TensorDictModule(module, in_keys=["observation"], out_keys=["hidden"])
+            m1 = TensorDictModule(
+                lambda a: a + 1, in_keys=["hidden"], out_keys=["action"]
+            )
+            super().__init__(m0, m1)
+
+    class GoesThroughEnv(EnvBase):
+        def __init__(self, n_obs, device):
+            self.observation_spec = Composite(observation=Unbounded(n_obs))
+            self.action_spec = Unbounded(n_obs)
+            self.reward_spec = Unbounded(1)
+            self.full_done_specs = Composite(done=Unbounded(1, dtype=torch.bool))
+            super().__init__(device=device)
+
+        def _step(
+            self,
+            tensordict: TensorDictBase,
+        ) -> TensorDictBase:
+            a = tensordict["action"]
+            if self.device is not None:
+                assert a.device == self.device
+            out = tensordict.empty()
+            out["observation"] = tensordict["observation"] + (
+                a - tensordict["observation"]
+            )
+            out["reward"] = torch.zeros((1,), device=self.device)
+            out["done"] = torch.zeros((1,), device=self.device, dtype=torch.bool)
+            return out
+
+        def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+            return self.full_done_specs.zeros().update(self.observation_spec.zeros())
+
+        def _set_seed(self, seed: Optional[int]):
+            return seed
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device")
+    @pytest.mark.parametrize("env_device", ["cuda:0", "cpu"])
+    @pytest.mark.parametrize("storing_device", [None, "cuda:0", "cpu"])
+    @pytest.mark.parametrize("no_cuda_sync", [True, False])
+    def test_no_synchronize(self, env_device, storing_device, no_cuda_sync):
+        """Tests that no_cuda_sync avoids any call to torch.cuda.synchronize() and that the data is not corrupted."""
+        should_raise = not no_cuda_sync
+        should_raise = should_raise & (
+            (env_device == "cpu") or (storing_device == "cpu")
+        )
+        with patch("torch.cuda.synchronize") as mock_synchronize, pytest.raises(
+            AssertionError, match="Expected 'synchronize' to not have been called."
+        ) if should_raise else contextlib.nullcontext():
+            collector = SyncDataCollector(
+                create_env_fn=functools.partial(
+                    self.GoesThroughEnv, n_obs=1000, device=None
+                ),
+                policy=self.CudaPolicy(n_obs=1000),
+                frames_per_batch=100,
+                total_frames=1000,
+                env_device=env_device,
+                storing_device=storing_device,
+                policy_device="cuda:0",
+                no_cuda_sync=no_cuda_sync,
+            )
+            assert collector.env.device == torch.device(env_device)
+            i = 0
+            for d in collector:
+                for _d in d.unbind(0):
+                    u = _d["observation"].unique()
+                    assert u.numel() == 1, i
+                    assert u == i, i
+                    i += 1
+                    u = _d["next", "observation"].unique()
+                    assert u.numel() == 1, i
+                    assert u == i, i
+            mock_synchronize.assert_not_called()
+
 
 # @pytest.mark.skipif(
 #     IS_WINDOWS and PYTHON_3_10,

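With a CUDA device available, the new test can presumably be selected on its own with pytest's -k filter (exact invocation may vary):

    pytest test/test_collector.py -k test_no_synchronize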
torchrl/collectors/collectors.py

Lines changed: 48 additions & 8 deletions

@@ -440,6 +440,11 @@ class SyncDataCollector(DataCollectorBase):
         cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
             in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
             If a dictionary of kwargs is passed, it will be used to wrap the policy.
+        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
+            For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
+            or `ManiSkills <https://github.com/haosulab/ManiSkill/>`_), CUDA synchronization may cause unexpected
+            crashes.
+            Defaults to ``False``.
 
     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
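For context, a minimal usage sketch of the new argument with a CUDA-resident environment. The environment factory and policy below are hypothetical placeholders, not part of the commit; only the SyncDataCollector arguments mirror what the new test exercises.

    from torchrl.collectors import SyncDataCollector

    # make_cuda_env and cuda_policy are stand-ins for an EnvBase factory and a
    # policy module that both live on "cuda:0".
    collector = SyncDataCollector(
        create_env_fn=make_cuda_env,
        policy=cuda_policy,
        frames_per_batch=100,
        total_frames=1000,
        env_device="cuda:0",
        storing_device="cuda:0",
        policy_device="cuda:0",
        no_cuda_sync=True,  # skip explicit torch.cuda.synchronize() calls
    )
    for batch in collector:
        ...  # consume rollout batches as usual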
@@ -532,6 +537,7 @@ def __init__(
         trust_policy: bool = None,
         compile_policy: bool | Dict[str, Any] | None = None,
         cudagraph_policy: bool | Dict[str, Any] | None = None,
+        no_cuda_sync: bool = False,
         **kwargs,
     ):
         from torchrl.envs.batched_envs import BatchedEnvBase
from torchrl.envs.batched_envs import BatchedEnvBase
@@ -625,6 +631,7 @@ def __init__(
625631
else:
626632
self._sync_policy = _do_nothing
627633
self.device = device
634+
self.no_cuda_sync = no_cuda_sync
628635
# Check if we need to cast things from device to device
629636
# If the policy has a None device and the env too, no need to cast (we don't know
630637
# and assume the user knows what she's doing).
@@ -1010,12 +1017,16 @@ def iterator(self) -> Iterator[TensorDictBase]:
         Yields: TensorDictBase objects containing (chunks of) trajectories
 
         """
-        if self.storing_device and self.storing_device.type == "cuda":
+        if (
+            not self.no_cuda_sync
+            and self.storing_device
+            and self.storing_device.type == "cuda"
+        ):
             stream = torch.cuda.Stream(self.storing_device, priority=-1)
             event = stream.record_event()
             streams = [stream]
             events = [event]
-        elif self.storing_device is None:
+        elif not self.no_cuda_sync and self.storing_device is None:
             streams = []
             events = []
             # this way of checking cuda is robust to lazy stacks with mismatching shapes
@@ -1166,10 +1177,17 @@ def rollout(self) -> TensorDictBase:
                 else:
                     if self._cast_to_policy_device:
                         if self.policy_device is not None:
+                            # This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking
+                            non_blocking = (
+                                not self.no_cuda_sync
+                                or self.policy_device.type == "cuda"
+                            )
                             policy_input = self._shuttle.to(
-                                self.policy_device, non_blocking=True
+                                self.policy_device,
+                                non_blocking=non_blocking,
                             )
-                            self._sync_policy()
+                            if not self.no_cuda_sync:
+                                self._sync_policy()
                         elif self.policy_device is None:
                             # we know the tensordict has a device otherwise we would not be here
                             # we can pass this, clear_device_ must have been called earlier
@@ -1191,8 +1209,14 @@ def rollout(self) -> TensorDictBase:
 
                 if self._cast_to_env_device:
                     if self.env_device is not None:
-                        env_input = self._shuttle.to(self.env_device, non_blocking=True)
-                        self._sync_env()
+                        non_blocking = (
+                            not self.no_cuda_sync or self.env_device.type == "cuda"
+                        )
+                        env_input = self._shuttle.to(
+                            self.env_device, non_blocking=non_blocking
+                        )
+                        if not self.no_cuda_sync:
+                            self._sync_env()
                     elif self.env_device is None:
                         # we know the tensordict has a device otherwise we would not be here
                         # we can pass this, clear_device_ must have been called earlier
@@ -1216,10 +1240,16 @@ def rollout(self) -> TensorDictBase:
                         return
                 else:
                     if self.storing_device is not None:
+                        non_blocking = (
+                            not self.no_cuda_sync or self.storing_device.type == "cuda"
+                        )
                         tensordicts.append(
-                            self._shuttle.to(self.storing_device, non_blocking=True)
+                            self._shuttle.to(
+                                self.storing_device, non_blocking=non_blocking
+                            )
                         )
-                        self._sync_storage()
+                        if not self.no_cuda_sync:
+                            self._sync_storage()
                     else:
                         tensordicts.append(self._shuttle)

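Taken together, the three rollout() hunks above apply one rule to every device transfer. A paraphrased sketch of that rule (the variable names here are illustrative, not library API):

    # no_cuda_sync=False (default): copy non_blocking and then synchronize, as before.
    # no_cuda_sync=True: skip the synchronize entirely, and only keep the copy
    # non_blocking when the destination is a CUDA device, where that is safe.
    non_blocking = (not no_cuda_sync) or (destination_device.type == "cuda")
    data = shuttle.to(destination_device, non_blocking=non_blocking)
    if not no_cuda_sync:
        sync()  # the corresponding _sync_policy / _sync_env / _sync_storage hook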
@@ -1558,6 +1588,11 @@ class _MultiDataCollector(DataCollectorBase):
         cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
             in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
             If a dictionary of kwargs is passed, it will be used to wrap the policy.
+        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
+            For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
+            or `ManiSkills <https://github.com/haosulab/ManiSkill/>`_), CUDA synchronization may cause unexpected
+            crashes.
+            Defaults to ``False``.
 
     """

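The multiprocess collectors accept the same flag and forward it to each worker, as the _run_processes and _main_async_collector hunks below show. A sketch, assuming the usual MultiSyncDataCollector construction; the env factories and policy are the same hypothetical placeholders as in the earlier example:

    from torchrl.collectors import MultiSyncDataCollector

    collector = MultiSyncDataCollector(
        create_env_fn=[make_cuda_env, make_cuda_env],  # hypothetical factories
        policy=cuda_policy,                            # hypothetical policy on cuda:0
        frames_per_batch=200,
        total_frames=2000,
        policy_device="cuda:0",
        no_cuda_sync=True,
    )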
@@ -1597,6 +1632,7 @@ def __init__(
         trust_policy: bool = None,
         compile_policy: bool | Dict[str, Any] | None = None,
         cudagraph_policy: bool | Dict[str, Any] | None = None,
+        no_cuda_sync: bool = False,
     ):
         self.closed = True
         self.num_workers = len(create_env_fn)
@@ -1636,6 +1672,7 @@ def __init__(
         self.env_device = env_devices
 
         del storing_device, env_device, policy_device, device
+        self.no_cuda_sync = no_cuda_sync
 
         self._use_buffers = use_buffers
         self.replay_buffer = replay_buffer
@@ -1909,6 +1946,7 @@ def _run_processes(self) -> None:
                 "cudagraph_policy": self.cudagraphed_policy_kwargs
                 if self.cudagraphed_policy
                 else False,
+                "no_cuda_sync": self.no_cuda_sync,
             }
             proc = _ProcessNoWarn(
                 target=_main_async_collector,
@@ -2914,6 +2952,7 @@ def _main_async_collector(
     trust_policy: bool = False,
     compile_policy: bool = False,
     cudagraph_policy: bool = False,
+    no_cuda_sync: bool = False,
 ) -> None:
     pipe_parent.close()
     # init variables that will be cleared when closing
@@ -2943,6 +2982,7 @@ def _main_async_collector(
         trust_policy=trust_policy,
         compile_policy=compile_policy,
         cudagraph_policy=cudagraph_policy,
+        no_cuda_sync=no_cuda_sync,
     )
     use_buffers = inner_collector._use_buffers
     if verbose:
