Skip to content

Commit 36432e6

Browse files
author
Vincent Moens
committed
Update (base update)
[ghstack-poisoned]
1 parent 8c9dc05 commit 36432e6

File tree

220 files changed

+3116
-1685
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

220 files changed

+3116
-1685
lines changed

.github/unittest/helpers/coverage_run_parallel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def write_config(config_path: Path, argv: List[str]) -> None:
2828
argv: Arguments passed to this script, which need to be converted to config file entries
2929
"""
3030
assert not config_path.exists(), "Temporary coverage config exists already"
31-
cmdline = " ".join(shlex.quote(arg) for arg in argv[1:])
32-
with open(str(config_path), "wt", encoding="utf-8") as fh:
31+
cmdline = shlex.join(argv[1:])
32+
with open(str(config_path), "w", encoding="utf-8") as fh:
3333
fh.write(
3434
f"""# .coveragerc to control coverage.py
3535
[run]

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
echo '::endgroup::'
3636
3737
echo '::group::Install lint tools'
38-
pip install --progress-bar=off pre-commit
38+
pip install --progress-bar=off pre-commit autoflake
3939
echo '::endgroup::'
4040
4141
echo '::group::Lint Python source and configs'

.pre-commit-config.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,17 @@ repos:
3535
hooks:
3636
- id: pydocstyle
3737
files: ^torchrl/
38+
39+
- repo: https://github.com/asottile/pyupgrade
40+
rev: v3.9.0
41+
hooks:
42+
- id: pyupgrade
43+
args: [--py38-plus]
44+
45+
- repo: local
46+
hooks:
47+
- id: autoflake
48+
name: autoflake
49+
entry: autoflake --in-place --remove-unused-variables --remove-all-unused-imports
50+
language: system
51+
types: [python]

build_tools/setup_helpers/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,6 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from .extension import CMakeBuild, get_ext_modules # noqa
6+
from .extension import CMakeBuild, get_ext_modules
7+
8+
__all__ = ["CMakeBuild", "get_ext_modules"]

build_tools/setup_helpers/extension.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from setuptools import Extension
1515
from setuptools.command.build_ext import build_ext
1616

17-
1817
_THIS_DIR = Path(__file__).parent.resolve()
1918
_ROOT_DIR = _THIS_DIR.parent.parent.resolve()
2019
_TORCHRL_DIR = _ROOT_DIR / "torchrl"
@@ -130,7 +129,7 @@ def build_extension(self, ext):
130129
# using -j in the build_ext call, not supported by pip or PyPA-build.
131130
if hasattr(self, "parallel") and self.parallel:
132131
# CMake 3.12+ only.
133-
build_args += ["-j{}".format(self.parallel)]
132+
build_args += [f"-j{self.parallel}"]
134133

135134
if not os.path.exists(self.build_temp):
136135
os.makedirs(self.build_temp)

docs/source/reference/envs.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ TorchRL offers a series of custom built-in environments.
440440
ChessEnv
441441
PendulumEnv
442442
TicTacToeEnv
443+
LLMEnv
443444
LLMHashingEnv
444445

445446

@@ -1033,6 +1034,7 @@ to be able to create this other composition:
10331034
Compose
10341035
ConditionalSkip
10351036
Crop
1037+
DataLoadingPrimer
10361038
DTypeCastTransform
10371039
DeviceCastTransform
10381040
DiscreteActionProjection
@@ -1218,7 +1220,7 @@ Recorders are transforms that register data as they come in, for logging purpose
12181220

12191221
Helpers
12201222
-------
1221-
.. currentmodule:: torchrl.envs.utils
1223+
.. currentmodule:: torchrl.envs
12221224

12231225
.. autosummary::
12241226
:toctree: generated/

docs/source/reference/objectives.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ auto-completion to make their choice.
111111
:template: rl_template_noinherit.rst
112112

113113
LossModule
114+
add_random_module
114115

115116
DQN
116117
---

setup.cfg

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,7 @@ ignore-decorators =
4545
test_*
4646
; test/*.py
4747
; .circleci/*
48+
49+
[autoflake]
50+
per-file-ignores =
51+
torchrl/trainers/helpers/envs.py *

setup.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
def get_version():
3232
version_txt = os.path.join(cwd, "version.txt")
33-
with open(version_txt, "r") as f:
33+
with open(version_txt) as f:
3434
version = f.readline().strip()
3535
if os.getenv("TORCHRL_BUILD_VERSION"):
3636
version = os.getenv("TORCHRL_BUILD_VERSION")
@@ -64,8 +64,8 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
6464
def write_version_file(version):
6565
version_path = os.path.join(cwd, "torchrl", "version.py")
6666
with open(version_path, "w") as f:
67-
f.write("__version__ = '{}'\n".format(version))
68-
f.write("git_version = {}\n".format(repr(sha)))
67+
f.write(f"__version__ = '{version}'\n")
68+
f.write(f"git_version = {repr(sha)}\n")
6969

7070

7171
def _get_pytorch_version(is_nightly, is_local):
@@ -185,7 +185,7 @@ def _main(argv):
185185
version = get_version()
186186
write_version_file(version)
187187
TORCHRL_BUILD_VERSION = os.getenv("TORCHRL_BUILD_VERSION")
188-
logging.info("Building wheel {}-{}".format(package_name, version))
188+
logging.info(f"Building wheel {package_name}-{version}")
189189
logging.info(f"TORCHRL_BUILD_VERSION is {TORCHRL_BUILD_VERSION}")
190190

191191
is_local = TORCHRL_BUILD_VERSION is None

sota-implementations/a2c/a2c_atari.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
@hydra.main(config_path="", config_name="config_atari", version_base="1.1")
16-
def main(cfg: "DictConfig"): # noqa: F821
16+
def main(cfg: DictConfig): # noqa: F821
1717

1818
from copy import deepcopy
1919

0 commit comments

Comments
 (0)