Skip to content

Commit 32f088c

Browse files
author
Vincent Moens
committed
[Lint] pyupgrade
ghstack-source-id: 027a8a8 Pull Request resolved: #2819
1 parent a2a28a9 commit 32f088c

File tree

169 files changed

+943
-1913
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

169 files changed

+943
-1913
lines changed

.github/unittest/helpers/coverage_run_parallel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def write_config(config_path: Path, argv: List[str]) -> None:
2828
argv: Arguments passed to this script, which need to be converted to config file entries
2929
"""
3030
assert not config_path.exists(), "Temporary coverage config exists already"
31-
cmdline = " ".join(shlex.quote(arg) for arg in argv[1:])
32-
with open(str(config_path), "wt", encoding="utf-8") as fh:
31+
cmdline = shlex.join(argv[1:])
32+
with open(str(config_path), "w", encoding="utf-8") as fh:
3333
fh.write(
3434
f"""# .coveragerc to control coverage.py
3535
[run]

.pre-commit-config.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,17 @@ repos:
3535
hooks:
3636
- id: pydocstyle
3737
files: ^torchrl/
38+
39+
- repo: https://github.com/asottile/pyupgrade
40+
rev: v3.9.0
41+
hooks:
42+
- id: pyupgrade
43+
args: [--py38-plus]
44+
45+
- repo: local
46+
hooks:
47+
- id: autoflake
48+
name: autoflake
49+
entry: autoflake --in-place --remove-unused-variables --remove-all-unused-imports
50+
language: system
51+
types: [python]

build_tools/setup_helpers/extension.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from setuptools import Extension
1515
from setuptools.command.build_ext import build_ext
1616

17-
1817
_THIS_DIR = Path(__file__).parent.resolve()
1918
_ROOT_DIR = _THIS_DIR.parent.parent.resolve()
2019
_TORCHRL_DIR = _ROOT_DIR / "torchrl"
@@ -130,7 +129,7 @@ def build_extension(self, ext):
130129
# using -j in the build_ext call, not supported by pip or PyPA-build.
131130
if hasattr(self, "parallel") and self.parallel:
132131
# CMake 3.12+ only.
133-
build_args += ["-j{}".format(self.parallel)]
132+
build_args += [f"-j{self.parallel}"]
134133

135134
if not os.path.exists(self.build_temp):
136135
os.makedirs(self.build_temp)

setup.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
def get_version():
3232
version_txt = os.path.join(cwd, "version.txt")
33-
with open(version_txt, "r") as f:
33+
with open(version_txt) as f:
3434
version = f.readline().strip()
3535
if os.getenv("TORCHRL_BUILD_VERSION"):
3636
version = os.getenv("TORCHRL_BUILD_VERSION")
@@ -64,8 +64,8 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
6464
def write_version_file(version):
6565
version_path = os.path.join(cwd, "torchrl", "version.py")
6666
with open(version_path, "w") as f:
67-
f.write("__version__ = '{}'\n".format(version))
68-
f.write("git_version = {}\n".format(repr(sha)))
67+
f.write(f"__version__ = '{version}'\n")
68+
f.write(f"git_version = {repr(sha)}\n")
6969

7070

7171
def _get_pytorch_version(is_nightly, is_local):
@@ -185,7 +185,7 @@ def _main(argv):
185185
version = get_version()
186186
write_version_file(version)
187187
TORCHRL_BUILD_VERSION = os.getenv("TORCHRL_BUILD_VERSION")
188-
logging.info("Building wheel {}-{}".format(package_name, version))
188+
logging.info(f"Building wheel {package_name}-{version}")
189189
logging.info(f"TORCHRL_BUILD_VERSION is {TORCHRL_BUILD_VERSION}")
190190

191191
is_local = TORCHRL_BUILD_VERSION is None

sota-implementations/a2c/a2c_atari.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
16-
def main(cfg: "DictConfig"): # noqa: F821
16+
def main(cfg: DictConfig): # noqa: F821
1717

1818
from copy import deepcopy
1919

sota-implementations/a2c/a2c_mujoco.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
@hydra.main(config_path="", config_name="config_mujoco", version_base="1.1")
16-
def main(cfg: "DictConfig"): # noqa: F821
16+
def main(cfg: DictConfig): # noqa: F821
1717

1818
from copy import deepcopy
1919

sota-implementations/cql/cql_offline.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,13 @@
1515

1616
import hydra
1717
import numpy as np
18-
1918
import torch
2019
import tqdm
2120
from tensordict.nn import CudaGraphModule
22-
2321
from torchrl._utils import timeit
2422
from torchrl.envs.utils import ExplorationType, set_exploration_type
2523
from torchrl.objectives import group_optimizers
2624
from torchrl.record.loggers import generate_exp_name, get_logger
27-
2825
from utils import (
2926
dump_video,
3027
log_metrics,
@@ -39,7 +36,7 @@
3936

4037

4138
@hydra.main(config_path="", config_name="offline_config", version_base="1.1")
42-
def main(cfg: "DictConfig"): # noqa: F821
39+
def main(cfg: DictConfig): # noqa: F821
4340
# Create logger
4441
exp_name = generate_exp_name("CQL-offline", cfg.logger.exp_name)
4542
logger = None

sota-implementations/cql/cql_online.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,10 @@
2121
import tqdm
2222
from tensordict import TensorDict
2323
from tensordict.nn import CudaGraphModule
24-
2524
from torchrl._utils import timeit
2625
from torchrl.envs.utils import ExplorationType, set_exploration_type
2726
from torchrl.objectives import group_optimizers
2827
from torchrl.record.loggers import generate_exp_name, get_logger
29-
3028
from utils import (
3129
dump_video,
3230
log_metrics,
@@ -42,7 +40,7 @@
4240

4341

4442
@hydra.main(version_base="1.1", config_path="", config_name="online_config")
45-
def main(cfg: "DictConfig"): # noqa: F821
43+
def main(cfg: DictConfig): # noqa: F821
4644
# Create logger
4745
exp_name = generate_exp_name("CQL-online", cfg.logger.exp_name)
4846
logger = None

sota-implementations/cql/discrete_cql_online.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,12 @@
1616

1717
import hydra
1818
import numpy as np
19-
2019
import torch
2120
import torch.cuda
2221
import tqdm
2322
from tensordict.nn import CudaGraphModule
24-
2523
from torchrl._utils import timeit
26-
2724
from torchrl.envs.utils import ExplorationType, set_exploration_type
28-
2925
from torchrl.record.loggers import generate_exp_name, get_logger
3026
from utils import (
3127
log_metrics,
@@ -41,7 +37,7 @@
4137

4238

4339
@hydra.main(version_base="1.1", config_path="", config_name="discrete_cql_config")
44-
def main(cfg: "DictConfig"): # noqa: F821
40+
def main(cfg: DictConfig): # noqa: F821
4541
device = cfg.optim.device
4642
if device in ("", None):
4743
if torch.cuda.is_available():

sota-implementations/crossq/crossq.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,15 @@
1515
import warnings
1616

1717
import hydra
18-
1918
import numpy as np
20-
2119
import torch
2220
import torch.cuda
2321
import tqdm
2422
from tensordict import TensorDict
2523
from tensordict.nn import CudaGraphModule
26-
2724
from torchrl._utils import timeit
2825
from torchrl.envs.utils import ExplorationType, set_exploration_type
2926
from torchrl.objectives import group_optimizers
30-
3127
from torchrl.record.loggers import generate_exp_name, get_logger
3228
from utils import (
3329
log_metrics,
@@ -43,7 +39,7 @@
4339

4440

4541
@hydra.main(version_base="1.1", config_path=".", config_name="config")
46-
def main(cfg: "DictConfig"): # noqa: F821
42+
def main(cfg: DictConfig): # noqa: F821
4743
device = cfg.network.device
4844
if device in ("", None):
4945
if torch.cuda.is_available():

0 commit comments

Comments
 (0)