|
2 | 2 | #
|
3 | 3 | # This source code is licensed under the MIT license found in the
|
4 | 4 | # LICENSE file in the root directory of this source tree.
|
5 |
| -from __future__ import annotations |
| 5 | + |
| 6 | +# This makes omegaconf unhappy with typing.Any |
| 7 | +# Therefore we need Optional and Union |
| 8 | +# from __future__ import annotations |
6 | 9 |
|
7 | 10 | from copy import copy
|
8 | 11 | from dataclasses import dataclass, field as dataclass_field
|
9 |
| -from typing import Any, Callable, Sequence |
| 12 | +from typing import Any, Callable, Optional, Sequence, Union |
10 | 13 |
|
11 | 14 | import torch
|
| 15 | +from omegaconf import DictConfig |
12 | 16 |
|
13 | 17 | from torchrl._utils import logger as torchrl_logger, VERBOSE
|
14 | 18 | from torchrl.envs import ParallelEnv
|
@@ -212,18 +216,18 @@ def get_norm_state_dict(env):
|
212 | 216 | def transformed_env_constructor(
|
213 | 217 | cfg: DictConfig, # noqa: F821
|
214 | 218 | video_tag: str = "",
|
215 |
| - logger: Logger | None = None, |
216 |
| - stats: dict | None = None, |
| 219 | + logger: Optional[Logger] = None, # noqa |
| 220 | + stats: Optional[dict] = None, |
217 | 221 | norm_obs_only: bool = False,
|
218 | 222 | use_env_creator: bool = False,
|
219 |
| - custom_env_maker: Callable | None = None, |
220 |
| - custom_env: EnvBase | None = None, |
| 223 | + custom_env_maker: Optional[Callable] = None, |
| 224 | + custom_env: Optional[EnvBase] = None, |
221 | 225 | return_transformed_envs: bool = True,
|
222 |
| - action_dim_gsde: int | None = None, |
223 |
| - state_dim_gsde: int | None = None, |
224 |
| - batch_dims: int | None = 0, |
225 |
| - obs_norm_state_dict: dict | None = None, |
226 |
| -) -> Callable | EnvCreator: |
| 226 | + action_dim_gsde: Optional[int] = None, |
| 227 | + state_dim_gsde: Optional[int] = None, |
| 228 | + batch_dims: Optional[int] = 0, |
| 229 | + obs_norm_state_dict: Optional[dict] = None, |
| 230 | +) -> Union[Callable, EnvCreator]: |
227 | 231 | """Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.
|
228 | 232 |
|
229 | 233 | Args:
|
@@ -329,7 +333,7 @@ def make_transformed_env(**kwargs) -> TransformedEnv:
|
329 | 333 |
|
330 | 334 | def parallel_env_constructor(
|
331 | 335 | cfg: DictConfig, **kwargs # noqa: F821
|
332 |
| -) -> ParallelEnv | EnvCreator: |
| 336 | +) -> Union[ParallelEnv, EnvCreator]: |
333 | 337 | """Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.
|
334 | 338 |
|
335 | 339 | Args:
|
@@ -374,7 +378,7 @@ def parallel_env_constructor(
|
374 | 378 | def get_stats_random_rollout(
|
375 | 379 | cfg: DictConfig, # noqa: F821
|
376 | 380 | proof_environment: EnvBase = None,
|
377 |
| - key: str | None = None, |
| 381 | + key: Optional[str] = None, |
378 | 382 | ):
|
379 | 383 | """Gathers stas (loc and scale) from an environment using random rollouts.
|
380 | 384 |
|
@@ -452,7 +456,7 @@ def get_stats_random_rollout(
|
452 | 456 | def initialize_observation_norm_transforms(
|
453 | 457 | proof_environment: EnvBase,
|
454 | 458 | num_iter: int = 1000,
|
455 |
| - key: str | tuple[str, ...] = None, |
| 459 | + key: Union[str, tuple[str, ...]] = None, |
456 | 460 | ):
|
457 | 461 | """Calls :obj:`ObservationNorm.init_stats` on all uninitialized :obj:`ObservationNorm` instances of a :obj:`TransformedEnv`.
|
458 | 462 |
|
@@ -532,7 +536,7 @@ class EnvConfig:
|
532 | 536 | # maximum steps per trajectory, frames per batch or any other factor in the algorithm,
|
533 | 537 | # e.g. if the total number of frames that has to be computed is 50e6 and the frame skip is 4
|
534 | 538 | # the actual number of frames retrieved will be 200e6. Default=1.
|
535 |
| - reward_scaling: float | None = None |
| 539 | + reward_scaling: Any = None # noqa |
536 | 540 | # scale of the reward.
|
537 | 541 | reward_loc: float = 0.0
|
538 | 542 | # location of the reward.
|
|
0 commit comments