Skip to content

Commit 14dbdb2

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent 0e9d5f4 commit 14dbdb2

File tree

106 files changed

+328
-906
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

106 files changed

+328
-906
lines changed

.pre-commit-config.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,11 @@ repos:
4141
hooks:
4242
- id: pyupgrade
4343
args: [--py38-plus]
44+
45+
- repo: local
46+
hooks:
47+
- id: autoflake
48+
name: autoflake
49+
entry: autoflake --in-place --remove-unused-variables --remove-all-unused-imports
50+
language: system
51+
types: [python]

build_tools/setup_helpers/extension.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import platform
99
import subprocess
1010
from pathlib import Path
11-
from subprocess import CalledProcessError, STDOUT, check_output
11+
from subprocess import CalledProcessError, check_output, STDOUT
1212

1313
import torch
1414
from setuptools import Extension

sota-implementations/dreamer/dreamer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch
1212
import torch.cuda
1313
import tqdm
14+
1415
from dreamer_utils import (
1516
_default_device,
1617
dump_video,
@@ -20,6 +21,7 @@
2021
make_environments,
2122
make_replay_buffer,
2223
)
24+
2325
# mixed precision training
2426
from torch.amp import GradScaler
2527
from torch.nn.utils import clip_grad_norm_

sota-implementations/redq/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
)
1818
from torch import distributions as d, nn, optim
1919
from torch.optim.lr_scheduler import CosineAnnealingLR
20-
from torchrl._utils import VERBOSE, logger as torchrl_logger
20+
21+
from torchrl._utils import logger as torchrl_logger, VERBOSE
2122
from torchrl.collectors.collectors import DataCollectorBase
2223
from torchrl.data import (
2324
LazyMemmapStorage,
@@ -35,10 +36,12 @@
3536
Compose,
3637
DMControlEnv,
3738
DoubleToFloat,
39+
env_creator,
3840
EnvBase,
3941
EnvCreator,
4042
FlattenObservation,
4143
GrayScale,
44+
gSDENoise,
4245
GymEnv,
4346
InitTracker,
4447
NoopResetEnv,
@@ -50,8 +53,6 @@
5053
ToTensorImage,
5154
TransformedEnv,
5255
VecNorm,
53-
env_creator,
54-
gSDENoise,
5556
)
5657
from torchrl.envs.utils import ExplorationType, set_exploration_type
5758
from torchrl.modules import (

test/_utils_internal.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,21 @@
1717
import pytest
1818
import torch
1919
import torch.cuda
20-
from tensordict import NestedKey, TensorDict, TensorDictBase, tensorclass
20+
from tensordict import NestedKey, tensorclass, TensorDict, TensorDictBase
2121
from tensordict.nn import TensorDictModuleBase
2222
from torch import nn, vmap
23-
from torchrl._utils import (RL_WARNINGS, implement_for, logger as torchrl_logger, seed_generator)
23+
24+
from torchrl._utils import (
25+
implement_for,
26+
logger as torchrl_logger,
27+
RL_WARNINGS,
28+
seed_generator,
29+
)
2430
from torchrl.data.utils import CloudpickleWrapper
2531
from torchrl.envs import MultiThreadedEnv, ObservationNorm
2632
from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
2733
from torchrl.envs.libs.envpool import _has_envpool
28-
from torchrl.envs.libs.gym import GymEnv, _has_gym, gym_backend
34+
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv
2935
from torchrl.envs.transforms import (
3036
Compose,
3137
RewardClipping,

test/mocking_classes.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
import numpy as np
1111
import torch
1212
import torch.nn as nn
13-
from tensordict import TensorDict, TensorDictBase, tensorclass
13+
from tensordict import tensorclass, TensorDict, TensorDictBase
1414
from tensordict.nn import TensorDictModuleBase
15-
from tensordict.utils import NestedKey, expand_right
15+
from tensordict.utils import expand_right, NestedKey
16+
1617
from torchrl.data import (
1718
Binary,
1819
Bounded,
@@ -28,7 +29,11 @@
2829
from torchrl.envs import Transform
2930
from torchrl.envs.common import EnvBase
3031
from torchrl.envs.model_based.common import ModelBasedEnvBase
31-
from torchrl.envs.utils import (MarlGroupMapType, _terminated_or_truncated, check_marl_grouping)
32+
from torchrl.envs.utils import (
33+
_terminated_or_truncated,
34+
check_marl_grouping,
35+
MarlGroupMapType,
36+
)
3237

3338
spec_dict = {
3439
"bounded": Bounded,

test/smoke_test_deps.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
import pytest
1111

12-
from torchrl.envs.libs.gym import gym_backend
13-
1412

1513
def test_dm_control():
1614
import dm_control # noqa: F401

test/test_collector.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,34 @@
1717
import pytest
1818
import torch
1919
from packaging import version
20-
from tensordict import (LazyStackedTensorDict, NonTensorData, TensorDict, TensorDictBase, assert_allclose_td)
20+
from tensordict import (
21+
assert_allclose_td,
22+
LazyStackedTensorDict,
23+
NonTensorData,
24+
TensorDict,
25+
TensorDictBase,
26+
)
2127
from tensordict.nn import (
2228
CudaGraphModule,
2329
TensorDictModule,
2430
TensorDictModuleBase,
2531
TensorDictSequential,
2632
)
2733
from torch import nn
34+
2835
from torchrl._utils import (
2936
_make_ordinal_device,
3037
_replace_last,
3138
logger as torchrl_logger,
3239
prod,
3340
seed_generator,
3441
)
35-
from torchrl.collectors import SyncDataCollector, aSyncDataCollector
36-
from torchrl.collectors.collectors import (MultiSyncDataCollector, MultiaSyncDataCollector, _Interruptor)
42+
from torchrl.collectors import aSyncDataCollector, SyncDataCollector
43+
from torchrl.collectors.collectors import (
44+
_Interruptor,
45+
MultiaSyncDataCollector,
46+
MultiSyncDataCollector,
47+
)
3748
from torchrl.collectors.utils import split_trajectories
3849
from torchrl.data import (
3950
Composite,
@@ -54,9 +65,14 @@
5465
StepCounter,
5566
Transform,
5667
)
57-
from torchrl.envs.libs.gym import GymEnv, _has_gym, gym_backend, set_gym_backend
68+
from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, set_gym_backend
5869
from torchrl.envs.transforms import TransformedEnv, VecNorm
59-
from torchrl.envs.utils import (PARTIAL_MISSING_ERR, RandomPolicy, _aggregate_end_of_traj, check_env_specs)
70+
from torchrl.envs.utils import (
71+
_aggregate_end_of_traj,
72+
check_env_specs,
73+
PARTIAL_MISSING_ERR,
74+
RandomPolicy,
75+
)
6076
from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule
6177

6278
if os.getenv("PYTORCH_TEST_FBCODE"):

test/test_cost.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5859,7 +5859,6 @@ def test_crossq_tensordict_keys(self, td_est):
58595859

58605860
actor = self._create_mock_actor()
58615861
qvalue = self._create_mock_qvalue()
5862-
value = None
58635862

58645863
loss_fn = CrossQLoss(
58655864
actor_network=actor,

test/test_exploration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def test_nested(
381381
action_spec = env.action_spec
382382
d_act = action_spec.shape[-1]
383383

384-
net = nn.LazyLinear(d_act).to(device)
384+
nn.LazyLinear(d_act).to(device)
385385
policy = TensorDictModule(
386386
CountingEnvCountModule(action_spec=action_spec),
387387
in_keys=[("data", "states") if nested_obs_action else "observation"],

0 commit comments

Comments
 (0)