Skip to content

Commit b1a9c44

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent 8512476 commit b1a9c44

File tree

4 files changed

+26
-16
lines changed

4 files changed

+26
-16
lines changed

setup.cfg

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,7 @@ ignore-decorators =
4545
test_*
4646
; test/*.py
4747
; .circleci/*
48+
49+
[autoflake]
50+
per-file-ignores =
51+
torchrl/trainers/helpers/envs.py *

test/test_helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
56

67
import argparse
78
import dataclasses
@@ -225,6 +226,7 @@ def test_timeit():
225226
@pytest.mark.skipif(not _has_hydra, reason="No hydra library found")
226227
@pytest.mark.parametrize("from_pixels", [(), ("from_pixels=True", "catframes=4")])
227228
def test_transformed_env_constructor_with_state_dict(from_pixels):
229+
228230
config_fields = [
229231
(config_field.name, config_field.type, config_field)
230232
for config_cls in (

torchrl/objectives/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from torchrl._utils import RL_WARNINGS
2323
from torchrl.envs.utils import ExplorationType, set_exploration_type
24-
from torchrl.modules import set_recurrent_mode
24+
from torchrl.modules.tensordict_module.rnn import set_recurrent_mode
2525
from torchrl.objectives.utils import ValueEstimators
2626
from torchrl.objectives.value import ValueEstimatorBase
2727

torchrl/trainers/helpers/envs.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5-
from __future__ import annotations
5+
6+
# Enabling `from __future__ import annotations` makes omegaconf unhappy with typing.Any.
7+
# Therefore it stays disabled and we need Optional and Union instead of the `X | Y` syntax.
8+
# from __future__ import annotations
69

710
from copy import copy
811
from dataclasses import dataclass, field as dataclass_field
9-
from typing import Any, Callable, Sequence
12+
from typing import Any, Callable, Optional, Sequence, Union
1013

1114
import torch
15+
from omegaconf import DictConfig
1216

1317
from torchrl._utils import logger as torchrl_logger, VERBOSE
1418
from torchrl.envs import ParallelEnv
@@ -212,18 +216,18 @@ def get_norm_state_dict(env):
212216
def transformed_env_constructor(
213217
cfg: DictConfig, # noqa: F821
214218
video_tag: str = "",
215-
logger: Logger | None = None,
216-
stats: dict | None = None,
219+
logger: Optional[Logger] = None, # noqa
220+
stats: Optional[dict] = None,
217221
norm_obs_only: bool = False,
218222
use_env_creator: bool = False,
219-
custom_env_maker: Callable | None = None,
220-
custom_env: EnvBase | None = None,
223+
custom_env_maker: Optional[Callable] = None,
224+
custom_env: Optional[EnvBase] = None,
221225
return_transformed_envs: bool = True,
222-
action_dim_gsde: int | None = None,
223-
state_dim_gsde: int | None = None,
224-
batch_dims: int | None = 0,
225-
obs_norm_state_dict: dict | None = None,
226-
) -> Callable | EnvCreator:
226+
action_dim_gsde: Optional[int] = None,
227+
state_dim_gsde: Optional[int] = None,
228+
batch_dims: Optional[int] = 0,
229+
obs_norm_state_dict: Optional[dict] = None,
230+
) -> Union[Callable, EnvCreator]:
227231
"""Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.
228232
229233
Args:
@@ -329,7 +333,7 @@ def make_transformed_env(**kwargs) -> TransformedEnv:
329333

330334
def parallel_env_constructor(
331335
cfg: DictConfig, **kwargs # noqa: F821
332-
) -> ParallelEnv | EnvCreator:
336+
) -> Union[ParallelEnv, EnvCreator]:
333337
"""Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.
334338
335339
Args:
@@ -374,7 +378,7 @@ def parallel_env_constructor(
374378
def get_stats_random_rollout(
375379
cfg: DictConfig, # noqa: F821
376380
proof_environment: EnvBase = None,
377-
key: str | None = None,
381+
key: Optional[str] = None,
378382
):
379383
"""Gathers stats (loc and scale) from an environment using random rollouts.
380384
@@ -452,7 +456,7 @@ def get_stats_random_rollout(
452456
def initialize_observation_norm_transforms(
453457
proof_environment: EnvBase,
454458
num_iter: int = 1000,
455-
key: str | tuple[str, ...] = None,
459+
key: Union[str, tuple[str, ...]] = None,
456460
):
457461
"""Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.
458462
@@ -532,7 +536,7 @@ class EnvConfig:
532536
# maximum steps per trajectory, frames per batch or any other factor in the algorithm,
533537
# e.g. if the total number of frames that has to be computed is 50e6 and the frame skip is 4
534538
# the actual number of frames retrieved will be 200e6. Default=1.
535-
reward_scaling: float | None = None
539+
reward_scaling: Any = None # noqa
536540
# scale of the reward.
537541
reward_loc: float = 0.0
538542
# location of the reward.

0 commit comments

Comments
 (0)