File tree Expand file tree Collapse file tree 3 files changed +22
-2
lines changed Expand file tree Collapse file tree 3 files changed +22
-2
lines changed Original file line number Diff line number Diff line change 50
50
auto_unwrap_transformed_env ,
51
51
compile_with_warmup ,
52
52
implement_for ,
53
+ logger ,
53
54
set_auto_unwrap_transformed_env ,
54
55
timeit ,
55
56
)
56
57
58
+ torchrl_logger = logger
59
+
57
60
# Filter warnings in subprocesses: True by default given the multiple optional
58
61
# deps of the library. This can be turned on via `torchrl.filter_warnings_subprocess = False`.
59
62
filter_warnings_subprocess = True
@@ -108,4 +111,6 @@ def _inv(self):
108
111
"implement_for" ,
109
112
"set_auto_unwrap_transformed_env" ,
110
113
"timeit" ,
114
+ "logger" ,
115
+ "torchrl_logger" ,
111
116
]
Original file line number Diff line number Diff line change @@ -1122,3 +1122,17 @@ def auto_unwrap_transformed_env(allow_none=False):
1122
1122
elif _AUTO_UNWRAP is None :
1123
1123
return _DEFAULT_AUTO_UNWRAP
1124
1124
return strtobool (_AUTO_UNWRAP ) if isinstance (_AUTO_UNWRAP , str ) else _AUTO_UNWRAP
1125
+
1126
+
1127
+ def safe_is_current_stream_capturing ():
1128
+ """A safe proxy to torch.cuda.is_current_stream_capturing."""
1129
+ if not torch .cuda .is_available ():
1130
+ return False
1131
+ try :
1132
+ return torch .cuda .is_current_stream_capturing ()
1133
+ except Exception as error :
1134
+ warnings .warn (
1135
+ f"torch.cuda.is_current_stream_capturing() exited unexpectedly with the error message { error = } . "
1136
+ f"Returning False by default."
1137
+ )
1138
+ return False
Original file line number Diff line number Diff line change 15
15
from torch .distributions import constraints
16
16
from torch .distributions .transforms import _InverseTransform
17
17
18
+ from torchrl ._utils import safe_is_current_stream_capturing
18
19
from torchrl .modules .distributions .truncated_normal import (
19
20
TruncatedNormal as _TruncatedNormal ,
20
21
)
@@ -358,7 +359,7 @@ def __init__(
358
359
event_dims = min (1 , loc .ndim )
359
360
360
361
err_msg = "TanhNormal high values must be strictly greater than low values"
361
- if not is_compiling () and not torch . cuda . is_current_stream_capturing ():
362
+ if not is_compiling () and not safe_is_current_stream_capturing ():
362
363
if isinstance (high , torch .Tensor ) or isinstance (low , torch .Tensor ):
363
364
if not (high > low ).all ():
364
365
raise RuntimeError (err_msg )
@@ -377,7 +378,7 @@ def __init__(
377
378
low = torch .as_tensor (low , device = loc .device )
378
379
elif low .device != loc .device :
379
380
low = low .to (loc .device )
380
- if not is_compiling () and not torch . cuda . is_current_stream_capturing ():
381
+ if not is_compiling () and not safe_is_current_stream_capturing ():
381
382
self .non_trivial_max = (high != 1.0 ).any ()
382
383
self .non_trivial_min = (low != - 1.0 ).any ()
383
384
else :
You can’t perform that action at this time.
0 commit comments