Skip to content

Commit 9332809

Browse files
author
Vincent Moens
committed
[CI] Fix winndows compile tests
ghstack-source-id: 2ab8ae3 Pull Request resolved: #2511
1 parent baba52b commit 9332809

File tree

1 file changed

+36
-28
lines changed

1 file changed

+36
-28
lines changed

test/test_cost.py

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,57 +7,47 @@
77
import functools
88
import itertools
99
import operator
10+
11+
import sys
1012
import warnings
1113
from copy import deepcopy
1214
from dataclasses import asdict, dataclass
1315

14-
from packaging import version as pack_version
16+
import numpy as np
17+
import pytest
18+
import torch
19+
from _utils_internal import ( # noqa
20+
dtype_fixture,
21+
get_available_devices,
22+
get_default_devices,
23+
)
24+
from mocking_classes import ContinuousActionConvMockEnv
25+
26+
from packaging import version, version as pack_version
27+
28+
from tensordict import assert_allclose_td, TensorDict, TensorDictBase
1529
from tensordict._C import unravel_keys
1630
from tensordict.nn import (
1731
CompositeDistribution,
1832
InteractionType,
33+
NormalParamExtractor,
1934
ProbabilisticTensorDictModule,
2035
ProbabilisticTensorDictModule as ProbMod,
2136
ProbabilisticTensorDictSequential,
2237
ProbabilisticTensorDictSequential as ProbSeq,
38+
TensorDictModule,
2339
TensorDictModule as Mod,
2440
TensorDictSequential,
2541
TensorDictSequential as Seq,
2642
)
27-
from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
28-
from torchrl.modules.models import QMixer
29-
30-
_has_functorch = True
31-
try:
32-
import functorch as ft # noqa
33-
34-
make_functional_with_buffers = ft.make_functional_with_buffers
35-
FUNCTORCH_ERR = ""
36-
except ImportError as err:
37-
_has_functorch = False
38-
FUNCTORCH_ERR = str(err)
39-
40-
import numpy as np
41-
import pytest
42-
import torch
43-
from _utils_internal import ( # noqa
44-
dtype_fixture,
45-
get_available_devices,
46-
get_default_devices,
47-
)
48-
from mocking_classes import ContinuousActionConvMockEnv
49-
from packaging import version
50-
51-
# from torchrl.data.postprocs.utils import expand_as_right
52-
from tensordict import assert_allclose_td, TensorDict, TensorDictBase
53-
from tensordict.nn import NormalParamExtractor, TensorDictModule
5443
from tensordict.nn.utils import Buffer
5544
from tensordict.utils import unravel_key
5645
from torch import autograd, nn
5746
from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded
5847
from torchrl.data.postprocs.postprocs import MultiStep
5948
from torchrl.envs.model_based.dreamer import DreamerEnv
6049
from torchrl.envs.transforms import TensorDictPrimer, TransformedEnv
50+
from torchrl.envs.utils import exploration_type, ExplorationType, set_exploration_type
6151
from torchrl.modules import (
6252
DistributionalQValueActor,
6353
OneHotCategorical,
@@ -66,6 +56,7 @@
6656
WorldModelWrapper,
6757
)
6858
from torchrl.modules.distributions.continuous import TanhDelta, TanhNormal
59+
from torchrl.modules.models import QMixer
6960
from torchrl.modules.models.model_based import (
7061
DreamerActor,
7162
ObsDecoder,
@@ -147,7 +138,18 @@
147138
_split_and_pad_sequence,
148139
)
149140

141+
_has_functorch = True
142+
try:
143+
import functorch as ft # noqa
144+
145+
make_functional_with_buffers = ft.make_functional_with_buffers
146+
FUNCTORCH_ERR = ""
147+
except ImportError as err:
148+
_has_functorch = False
149+
FUNCTORCH_ERR = str(err)
150+
150151
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
152+
IS_WINDOWS = sys.platform == "win32"
151153

152154
# Capture all warnings
153155
pytestmark = [
@@ -15735,7 +15737,13 @@ def __init__(self):
1573515737
@pytest.mark.skipif(
1573615738
TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
1573715739
)
15740+
@pytest.mark.skipif(IS_WINDOWS, reason="windows tests do not support compile")
1573815741
def test_exploration_compile():
15742+
try:
15743+
torch._dynamo.reset_code_caches()
15744+
except Exception:
15745+
# older versions of PT don't have that function
15746+
pass
1573915747
m = ProbabilisticTensorDictModule(
1574015748
in_keys=["loc", "scale"],
1574115749
out_keys=["sample"],

0 commit comments

Comments
 (0)