|
11 | 11 |
|
12 | 12 | import numpy as np
|
13 | 13 | import torch
|
| 14 | +from packaging import version |
14 | 15 | from torch import distributions as D, nn
|
15 | 16 |
|
16 |
| -try: |
17 |
| - from torch.compiler import assume_constant_result |
18 |
| -except ImportError: |
19 |
| - from torch._dynamo import assume_constant_result |
20 |
| - |
21 | 17 | from torch.distributions import constraints
|
22 | 18 | from torch.distributions.transforms import _InverseTransform
|
23 | 19 |
|
|
36 | 32 | # speeds up distribution construction
|
37 | 33 | D.Distribution.set_default_validate_args(False)
|
38 | 34 |
|
| 35 | +try: |
| 36 | + from torch.compiler import assume_constant_result |
| 37 | +except ImportError: |
| 38 | + from torch._dynamo import assume_constant_result |
| 39 | + |
39 | 40 | try:
|
40 | 41 | from torch.compiler import is_dynamo_compiling
|
41 | 42 | except ImportError:
|
42 | 43 | from torch._dynamo import is_compiling as is_dynamo_compiling
|
43 | 44 |
|
| 45 | +TORCH_VERSION_PRE_2_6 = version.parse(torch.__version__).base_version < version.parse( |
| 46 | + "2.6.0" |
| 47 | +) |
| 48 | + |
44 | 49 |
|
45 | 50 | class IndependentNormal(D.Independent):
|
46 | 51 | """Implements a Normal distribution with location scaling.
|
@@ -437,7 +442,7 @@ def __init__(
|
437 | 442 | self.high = high
|
438 | 443 |
|
439 | 444 | if safe_tanh:
|
440 |
| - if is_dynamo_compiling(): |
| 445 | + if is_dynamo_compiling() and TORCH_VERSION_PRE_2_6: |
441 | 446 | _err_compile_safetanh()
|
442 | 447 | t = SafeTanhTransform()
|
443 | 448 | else:
|
@@ -772,8 +777,8 @@ def _uniform_sample_delta(dist: Delta, size=None) -> torch.Tensor:
|
772 | 777 |
|
773 | 778 | def _err_compile_safetanh():
|
774 | 779 | raise RuntimeError(
|
775 |
| - "safe_tanh=True in TanhNormal is not compatible with torch.compile. To deactivate it, pass" |
776 |
| - "safe_tanh=False. " |
| 780 | + "safe_tanh=True in TanhNormal is not compatible with torch.compile with torch pre 2.6.0. " |
| 781 | + "To deactivate it, pass safe_tanh=False. " |
777 | 782 | "If you are using a ProbabilisticTensorDictModule, this can be done via "
|
778 | 783 | "`distribution_kwargs={'safe_tanh': False}`. "
|
779 | 784 | "See https://github.com/pytorch/pytorch/issues/133529 for more details."
|
|
0 commit comments