Skip to content

Commit 6b48e08

Browse files
author
Vincent Moens
committed
[BugFix] Robust torch.cuda.is_current_stream_capturing calls
ghstack-source-id: 8335587 Pull-Request-resolved: #2950
1 parent a31dca3 commit 6b48e08

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

torchrl/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,13 @@
5050
auto_unwrap_transformed_env,
5151
compile_with_warmup,
5252
implement_for,
53+
logger,
5354
set_auto_unwrap_transformed_env,
5455
timeit,
5556
)
5657

58+
torchrl_logger = logger
59+
5760
# Filter warnings in subprocesses: True by default given the multiple optional
5861
# deps of the library. This can be turned on via `torchrl.filter_warnings_subprocess = False`.
5962
filter_warnings_subprocess = True
@@ -108,4 +111,6 @@ def _inv(self):
108111
"implement_for",
109112
"set_auto_unwrap_transformed_env",
110113
"timeit",
114+
"logger",
115+
"torchrl_logger",
111116
]

torchrl/_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,3 +1122,17 @@ def auto_unwrap_transformed_env(allow_none=False):
11221122
elif _AUTO_UNWRAP is None:
11231123
return _DEFAULT_AUTO_UNWRAP
11241124
return strtobool(_AUTO_UNWRAP) if isinstance(_AUTO_UNWRAP, str) else _AUTO_UNWRAP
1125+
1126+
1127+
def safe_is_current_stream_capturing():
1128+
"""A safe proxy to torch.cuda.is_current_stream_capturing."""
1129+
if not torch.cuda.is_available():
1130+
return False
1131+
try:
1132+
return torch.cuda.is_current_stream_capturing()
1133+
except Exception as error:
1134+
warnings.warn(
1135+
f"torch.cuda.is_current_stream_capturing() exited unexpectedly with the error message {error=}. "
1136+
f"Returning False by default."
1137+
)
1138+
return False

torchrl/modules/distributions/continuous.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from torch.distributions import constraints
1616
from torch.distributions.transforms import _InverseTransform
1717

18+
from torchrl._utils import safe_is_current_stream_capturing
1819
from torchrl.modules.distributions.truncated_normal import (
1920
TruncatedNormal as _TruncatedNormal,
2021
)
@@ -358,7 +359,7 @@ def __init__(
358359
event_dims = min(1, loc.ndim)
359360

360361
err_msg = "TanhNormal high values must be strictly greater than low values"
361-
if not is_compiling() and not torch.cuda.is_current_stream_capturing():
362+
if not is_compiling() and not safe_is_current_stream_capturing():
362363
if isinstance(high, torch.Tensor) or isinstance(low, torch.Tensor):
363364
if not (high > low).all():
364365
raise RuntimeError(err_msg)
@@ -377,7 +378,7 @@ def __init__(
377378
low = torch.as_tensor(low, device=loc.device)
378379
elif low.device != loc.device:
379380
low = low.to(loc.device)
380-
if not is_compiling() and not torch.cuda.is_current_stream_capturing():
381+
if not is_compiling() and not safe_is_current_stream_capturing():
381382
self.non_trivial_max = (high != 1.0).any()
382383
self.non_trivial_min = (low != -1.0).any()
383384
else:

0 commit comments

Comments
 (0)