@@ -440,6 +440,11 @@ class SyncDataCollector(DataCollectorBase):
         cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
             in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
             If a dictionary of kwargs is passed, it will be used to wrap the policy.
+        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
+            For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
+            or `ManiSkill <https://github.com/haosulab/ManiSkill/>`_), CUDA synchronization may cause unexpected
+            crashes.
+            Defaults to ``False``.

     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
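
A minimal usage sketch of the new flag (not part of the patch; assumes a CUDA machine, and uses a Gym env as a stand-in for a CUDA-native simulator such as IsaacLab or ManiSkill; the `RandomPolicy` import path is assumed from current torchrl):

```python
# Minimal sketch, assuming a CUDA-capable machine. With a CUDA-native
# simulator the data never leaves the GPU, so the explicit syncs the
# collector normally issues can be skipped via no_cuda_sync=True.
import torch
from torchrl.collectors import RandomPolicy, SyncDataCollector
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("Pendulum-v1", device="cuda:0")
collector = SyncDataCollector(
    env,
    RandomPolicy(env.action_spec),
    frames_per_batch=64,
    total_frames=256,
    device="cuda:0",
    no_cuda_sync=True,  # the new flag introduced by this patch
)
for data in collector:
    assert data.device == torch.device("cuda:0")
collector.shutdown()
```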
@@ -532,6 +537,7 @@ def __init__(
         trust_policy: bool = None,
         compile_policy: bool | Dict[str, Any] | None = None,
         cudagraph_policy: bool | Dict[str, Any] | None = None,
+        no_cuda_sync: bool = False,
         **kwargs,
     ):
         from torchrl.envs.batched_envs import BatchedEnvBase
@@ -625,6 +631,7 @@ def __init__(
         else:
             self._sync_policy = _do_nothing
         self.device = device
+        self.no_cuda_sync = no_cuda_sync
         # Check if we need to cast things from device to device
         # If the policy has a None device and the env too, no need to cast (we don't know
         # and assume the user knows what she's doing).
@@ -1010,12 +1017,16 @@ def iterator(self) -> Iterator[TensorDictBase]:
         Yields: TensorDictBase objects containing (chunks of) trajectories

         """
-        if self.storing_device and self.storing_device.type == "cuda":
+        if (
+            not self.no_cuda_sync
+            and self.storing_device
+            and self.storing_device.type == "cuda"
+        ):
             stream = torch.cuda.Stream(self.storing_device, priority=-1)
             event = stream.record_event()
             streams = [stream]
             events = [event]
-        elif self.storing_device is None:
+        elif not self.no_cuda_sync and self.storing_device is None:
             streams = []
             events = []
             # this way of checking cuda is robust to lazy stacks with mismatching shapes
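
The branch above sets up the side-stream machinery that `no_cuda_sync` now bypasses. As a standalone illustration of the pattern (a sketch, not torchrl code), work queued on a high-priority stream is fenced with an event:

```python
import torch

# Sketch of the stream/event pattern: queue work on a high-priority side
# stream, record an event, and synchronize only when the result is needed.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    stream = torch.cuda.Stream(device, priority=-1)  # negative = higher priority
    with torch.cuda.stream(stream):
        out = torch.randn(1024, device=device).sum()  # runs on the side stream
    event = stream.record_event()  # marks completion of the queued kernels
    event.synchronize()            # host blocks here until `out` is ready
    print(out.item())
```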
@@ -1166,10 +1177,17 @@ def rollout(self) -> TensorDictBase:
             else:
                 if self._cast_to_policy_device:
                     if self.policy_device is not None:
+                        # This is unsafe if the shuttle is in pin_memory -- otherwise cuda will be happy with non_blocking
+                        non_blocking = (
+                            not self.no_cuda_sync
+                            or self.policy_device.type == "cuda"
+                        )
                         policy_input = self._shuttle.to(
-                            self.policy_device, non_blocking=True
+                            self.policy_device,
+                            non_blocking=non_blocking,
                         )
-                        self._sync_policy()
+                        if not self.no_cuda_sync:
+                            self._sync_policy()
                     elif self.policy_device is None:
                         # we know the tensordict has a device otherwise we would not be here
                         # we can pass this, clear_device_ must have been called earlier
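
The new `pin_memory` comment is the crux: a `non_blocking=True` copy from pinned host memory returns before the DMA transfer completes, so skipping the sync is only safe when source and destination are stream-ordered CUDA memory. A standalone sketch of the hazard being guarded against:

```python
import torch

if torch.cuda.is_available():
    # H2D copy from pinned memory is asynchronous with non_blocking=True:
    # the call returns immediately, and the host must synchronize before
    # relying on the destination from another stream (or reusing `host`).
    host = torch.ones(1 << 20, pin_memory=True)
    dev = host.to("cuda", non_blocking=True)  # async DMA, not yet complete
    torch.cuda.synchronize()                  # after this, `dev` is valid
    assert dev.sum().item() == float(1 << 20)
```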
@@ -1191,8 +1209,14 @@ def rollout(self) -> TensorDictBase:

             if self._cast_to_env_device:
                 if self.env_device is not None:
-                    env_input = self._shuttle.to(self.env_device, non_blocking=True)
-                    self._sync_env()
+                    non_blocking = (
+                        not self.no_cuda_sync or self.env_device.type == "cuda"
+                    )
+                    env_input = self._shuttle.to(
+                        self.env_device, non_blocking=non_blocking
+                    )
+                    if not self.no_cuda_sync:
+                        self._sync_env()
                 elif self.env_device is None:
                     # we know the tensordict has a device otherwise we would not be here
                     # we can pass this, clear_device_ must have been called earlier
@@ -1216,10 +1240,16 @@ def rollout(self) -> TensorDictBase:
                 return
             else:
                 if self.storing_device is not None:
+                    non_blocking = (
+                        not self.no_cuda_sync or self.storing_device.type == "cuda"
+                    )
                     tensordicts.append(
-                        self._shuttle.to(self.storing_device, non_blocking=True)
+                        self._shuttle.to(
+                            self.storing_device, non_blocking=non_blocking
+                        )
                     )
-                    self._sync_storage()
+                    if not self.no_cuda_sync:
+                        self._sync_storage()
                 else:
                     tensordicts.append(self._shuttle)

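
The same `non_blocking` guard now appears in all three cast sites (policy, env, storage). Its truth table is easy to check in plain Python:

```python
# Truth table for the guard shared by the three casts above:
#   non_blocking = (not no_cuda_sync) or (device_type == "cuda")
for no_cuda_sync in (False, True):
    for device_type in ("cuda", "cpu"):
        non_blocking = not no_cuda_sync or device_type == "cuda"
        print(f"no_cuda_sync={no_cuda_sync} device={device_type} -> non_blocking={non_blocking}")
# Only no_cuda_sync=True with a non-CUDA target forces a blocking copy: with
# the explicit syncs disabled, a non-blocking copy to CPU could be read
# before the transfer lands.
```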
@@ -1558,6 +1588,11 @@ class _MultiDataCollector(DataCollectorBase):
         cudagraph_policy (bool or Dict[str, Any], optional): if ``True``, the policy will be wrapped
             in :class:`~tensordict.nn.CudaGraphModule` with default kwargs.
             If a dictionary of kwargs is passed, it will be used to wrap the policy.
+        no_cuda_sync (bool): if ``True``, explicit CUDA synchronization calls will be bypassed.
+            For environments running directly on CUDA (`IsaacLab <https://github.com/isaac-sim/IsaacLab/>`_
+            or `ManiSkill <https://github.com/haosulab/ManiSkill/>`_), CUDA synchronization may cause unexpected
+            crashes.
+            Defaults to ``False``.

     """

@@ -1597,6 +1632,7 @@ def __init__(
         trust_policy: bool = None,
         compile_policy: bool | Dict[str, Any] | None = None,
         cudagraph_policy: bool | Dict[str, Any] | None = None,
+        no_cuda_sync: bool = False,
     ):
         self.closed = True
         self.num_workers = len(create_env_fn)
@@ -1636,6 +1672,7 @@ def __init__(
         self.env_device = env_devices

         del storing_device, env_device, policy_device, device
+        self.no_cuda_sync = no_cuda_sync

         self._use_buffers = use_buffers
         self.replay_buffer = replay_buffer
@@ -1909,6 +1946,7 @@ def _run_processes(self) -> None:
                 "cudagraph_policy": self.cudagraphed_policy_kwargs
                 if self.cudagraphed_policy
                 else False,
+                "no_cuda_sync": self.no_cuda_sync,
             }
             proc = _ProcessNoWarn(
                 target=_main_async_collector,
@@ -2914,6 +2952,7 @@ def _main_async_collector(
     trust_policy: bool = False,
     compile_policy: bool = False,
     cudagraph_policy: bool = False,
+    no_cuda_sync: bool = False,
 ) -> None:
     pipe_parent.close()
     # init variables that will be cleared when closing
@@ -2943,6 +2982,7 @@ def _main_async_collector(
         trust_policy=trust_policy,
         compile_policy=compile_policy,
         cudagraph_policy=cudagraph_policy,
+        no_cuda_sync=no_cuda_sync,
     )
     use_buffers = inner_collector._use_buffers
     if verbose: