|
7 | 7 | import functools
|
8 | 8 | import itertools
|
9 | 9 | import operator
|
| 10 | + |
| 11 | +import sys |
10 | 12 | import warnings
|
11 | 13 | from copy import deepcopy
|
12 | 14 | from dataclasses import asdict, dataclass
|
13 | 15 |
|
14 |
| -from packaging import version as pack_version |
| 16 | +import numpy as np |
| 17 | +import pytest |
| 18 | +import torch |
| 19 | +from _utils_internal import ( # noqa |
| 20 | + dtype_fixture, |
| 21 | + get_available_devices, |
| 22 | + get_default_devices, |
| 23 | +) |
| 24 | +from mocking_classes import ContinuousActionConvMockEnv |
| 25 | + |
| 26 | +from packaging import version, version as pack_version |
| 27 | + |
| 28 | +from tensordict import assert_allclose_td, TensorDict, TensorDictBase |
15 | 29 | from tensordict._C import unravel_keys
|
16 | 30 | from tensordict.nn import (
|
17 | 31 | CompositeDistribution,
|
18 | 32 | InteractionType,
|
| 33 | + NormalParamExtractor, |
19 | 34 | ProbabilisticTensorDictModule,
|
20 | 35 | ProbabilisticTensorDictModule as ProbMod,
|
21 | 36 | ProbabilisticTensorDictSequential,
|
22 | 37 | ProbabilisticTensorDictSequential as ProbSeq,
|
| 38 | + TensorDictModule, |
23 | 39 | TensorDictModule as Mod,
|
24 | 40 | TensorDictSequential,
|
25 | 41 | TensorDictSequential as Seq,
|
26 | 42 | )
|
27 |
| -from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type |
28 |
| -from torchrl.modules.models import QMixer |
29 |
| - |
30 |
| -_has_functorch = True |
31 |
| -try: |
32 |
| - import functorch as ft # noqa |
33 |
| - |
34 |
| - make_functional_with_buffers = ft.make_functional_with_buffers |
35 |
| - FUNCTORCH_ERR = "" |
36 |
| -except ImportError as err: |
37 |
| - _has_functorch = False |
38 |
| - FUNCTORCH_ERR = str(err) |
39 |
| - |
40 |
| -import numpy as np |
41 |
| -import pytest |
42 |
| -import torch |
43 |
| -from _utils_internal import ( # noqa |
44 |
| - dtype_fixture, |
45 |
| - get_available_devices, |
46 |
| - get_default_devices, |
47 |
| -) |
48 |
| -from mocking_classes import ContinuousActionConvMockEnv |
49 |
| -from packaging import version |
50 |
| - |
51 |
| -# from torchrl.data.postprocs.utils import expand_as_right |
52 |
| -from tensordict import assert_allclose_td, TensorDict, TensorDictBase |
53 |
| -from tensordict.nn import NormalParamExtractor, TensorDictModule |
54 | 43 | from tensordict.nn.utils import Buffer
|
55 | 44 | from tensordict.utils import unravel_key
|
56 | 45 | from torch import autograd, nn
|
57 | 46 | from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded
|
58 | 47 | from torchrl.data.postprocs.postprocs import MultiStep
|
59 | 48 | from torchrl.envs.model_based.dreamer import DreamerEnv
|
60 | 49 | from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv
|
| 50 | +from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type |
61 | 51 | from torchrl.modules import (
|
62 | 52 | DistributionalQValueActor,
|
63 | 53 | OneHotCategorical,
|
|
66 | 56 | WorldModelWrapper,
|
67 | 57 | )
|
68 | 58 | from torchrl.modules.distributions.continuous import TanhDelta, TanhNormal
|
| 59 | +from torchrl.modules.models import QMixer |
69 | 60 | from torchrl.modules.models.model_based import (
|
70 | 61 | DreamerActor,
|
71 | 62 | ObsDecoder,
|
|
147 | 138 | _split_and_pad_sequence,
|
148 | 139 | )
|
149 | 140 |
|
| 141 | +_has_functorch = True |
| 142 | +try: |
| 143 | + import functorch as ft # noqa |
| 144 | + |
| 145 | + make_functional_with_buffers = ft.make_functional_with_buffers |
| 146 | + FUNCTORCH_ERR = "" |
| 147 | +except ImportError as err: |
| 148 | + _has_functorch = False |
| 149 | + FUNCTORCH_ERR = str(err) |
| 150 | + |
150 | 151 | TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
|
| 152 | +IS_WINDOWS = sys.platform == "win32" |
151 | 153 |
|
152 | 154 | # Capture all warnings
|
153 | 155 | pytestmark = [
|
@@ -15735,7 +15737,13 @@ def __init__(self):
|
15735 | 15737 | @pytest.mark.skipif(
|
15736 | 15738 | TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
|
15737 | 15739 | )
|
| 15740 | +@pytest.mark.skipif(IS_WINDOWS, reason="windows tests do not support compile") |
15738 | 15741 | def test_exploration_compile():
|
| 15742 | + try: |
| 15743 | + torch._dynamo.reset_code_caches() |
| 15744 | + except Exception: |
| 15745 | + # older versions of PT don't have that function |
| 15746 | + pass |
15739 | 15747 | m = ProbabilisticTensorDictModule(
|
15740 | 15748 | in_keys=["loc", "scale"],
|
15741 | 15749 | out_keys=["sample"],
|
|
0 commit comments