Skip to content

Commit 1474f85

Browse files
author
Vincent Moens
committed
[Refactor] Allow safe-tanh for torch >= 2.6.0
ghstack-source-id: 92df195 Pull Request resolved: #2580
1 parent 600760f commit 1474f85

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

torchrl/modules/distributions/continuous.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,9 @@
1111

1212
import numpy as np
1313
import torch
14+
from packaging import version
1415
from torch import distributions as D, nn
1516

16-
try:
17-
from torch.compiler import assume_constant_result
18-
except ImportError:
19-
from torch._dynamo import assume_constant_result
20-
2117
from torch.distributions import constraints
2218
from torch.distributions.transforms import _InverseTransform
2319

@@ -36,11 +32,20 @@
3632
# speeds up distribution construction
3733
D.Distribution.set_default_validate_args(False)
3834

35+
try:
36+
from torch.compiler import assume_constant_result
37+
except ImportError:
38+
from torch._dynamo import assume_constant_result
39+
3940
try:
4041
from torch.compiler import is_dynamo_compiling
4142
except ImportError:
4243
from torch._dynamo import is_compiling as is_dynamo_compiling
4344

45+
TORCH_VERSION_PRE_2_6 = version.parse(torch.__version__).base_version < version.parse(
46+
"2.6.0"
47+
)
48+
4449

4550
class IndependentNormal(D.Independent):
4651
"""Implements a Normal distribution with location scaling.
@@ -437,7 +442,7 @@ def __init__(
437442
self.high = high
438443

439444
if safe_tanh:
440-
if is_dynamo_compiling():
445+
if is_dynamo_compiling() and TORCH_VERSION_PRE_2_6:
441446
_err_compile_safetanh()
442447
t = SafeTanhTransform()
443448
else:
@@ -772,8 +777,8 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor:
772777

773778
def _err_compile_safetanh():
774779
raise RuntimeError(
775-
"safe_tanh=True in TanhNormal is not compatible with torch.compile. To deactivate it, pass"
776-
"safe_tanh=False. "
780+
"safe_tanh=True in TanhNormal is not compatible with torch.compile with torch pre 2.6.0. "
781+
"To deactivate it, pass safe_tanh=False. "
777782
"If you are using a ProbabilisticTensorDictModule, this can be done via "
778783
"`distribution_kwargs={'safe_tanh': False}`. "
779784
"See https://github.com/pytorch/pytorch/issues/133529 for more details."

0 commit comments

Comments
 (0)