diff --git a/torchrl/envs/custom/pendulum.py b/torchrl/envs/custom/pendulum.py index 579faecc3c6..b530a01418e 100644 --- a/torchrl/envs/custom/pendulum.py +++ b/torchrl/envs/custom/pendulum.py @@ -269,11 +269,20 @@ def _reset(self, tensordict): batch_size = ( tensordict.batch_size if tensordict is not None else self.batch_size ) - if tensordict is None or tensordict.is_empty(): + if tensordict is None or "params" not in tensordict: # if no ``tensordict`` is passed, we generate a single set of hyperparameters # Otherwise, we assume that the input ``tensordict`` contains all the relevant # parameters to get started. tensordict = self.gen_params(batch_size=batch_size, device=self.device) + elif "th" in tensordict and "thdot" in tensordict: + # we can hard-reset the env too + return tensordict + out = self._reset_random_data( + tensordict.shape, batch_size, tensordict["params"] + ) + return out + + def _reset_random_data(self, shape, batch_size, params): high_th = torch.tensor(self.DEFAULT_X, device=self.device) high_thdot = torch.tensor(self.DEFAULT_Y, device=self.device) @@ -284,12 +293,12 @@ def _reset(self, tensordict): # of simulators run simultaneously. In other contexts, the initial # random state's shape will depend upon the environment batch-size instead. th = ( - torch.rand(tensordict.shape, generator=self.rng, device=self.device) + torch.rand(shape, generator=self.rng, device=self.device) * (high_th - low_th) + low_th ) thdot = ( - torch.rand(tensordict.shape, generator=self.rng, device=self.device) + torch.rand(shape, generator=self.rng, device=self.device) * (high_thdot - low_thdot) + low_thdot ) @@ -297,7 +306,7 @@ def _reset(self, tensordict): { "th": th, "thdot": thdot, - "params": tensordict["params"], + "params": params, }, batch_size=batch_size, ) diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index 0b5959bb900..feae60a1c59 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + from copy import copy, deepcopy import torch diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index fa7b50e2f3e..2ee42d19667 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -16,19 +16,7 @@ from enum import IntEnum from functools import wraps from textwrap import indent -from typing import ( - Any, - Callable, - Dict, - List, - Mapping, - Optional, - OrderedDict, - Sequence, - Tuple, - TypeVar, - Union, -) +from typing import Any, Callable, Mapping, OrderedDict, Sequence, TypeVar, Union import numpy as np @@ -663,7 +651,7 @@ def dump(self, **kwargs) -> None: def __repr__(self) -> str: return f"{self.__class__.__name__}(keys={self.in_keys})" - def set_container(self, container: Union[Transform, EnvBase]) -> None: + def set_container(self, container: Transform | EnvBase) -> None: if self.parent is not None: raise AttributeError( f"parent of transform {type(self)} already set. " @@ -738,7 +726,7 @@ def __setstate__(self, state): self.__dict__.update(state) @property - def parent(self) -> Optional[EnvBase]: + def parent(self) -> EnvBase | None: """Returns the parent env of the transform. The parent env is the env that contains all the transforms up until the current one. 
@@ -859,7 +847,7 @@ class TransformedEnv(EnvBase, metaclass=_TEnvPostInit): def __init__( self, env: EnvBase, - transform: Optional[Transform] = None, + transform: Transform | None = None, cache_specs: bool = True, *, auto_unwrap: bool | None = None, @@ -1070,7 +1058,7 @@ def _make_input_spec(self): self.__dict__["_input_spec"] = input_spec return input_spec - def rand_action(self, tensordict: Optional[TensorDictBase] = None) -> TensorDict: + def rand_action(self, tensordict: TensorDictBase | None = None) -> TensorDict: if type(self.base_env).rand_action is not EnvBase.rand_action: # TODO: this will fail if the transform modifies the input. # For instance, if an env overrides rand_action and we build a @@ -1166,16 +1154,16 @@ def select_and_clone(x, y): return next_tensordict def set_seed( - self, seed: Optional[int] = None, static_seed: bool = False - ) -> Optional[int]: + self, seed: int | None = None, static_seed: bool = False + ) -> int | None: """Set the seeds of the environment.""" return self.base_env.set_seed(seed, static_seed=static_seed) - def _set_seed(self, seed: Optional[int]): + def _set_seed(self, seed: int | None): """This method is not used in transformed envs.""" pass - def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): + def _reset(self, tensordict: TensorDictBase | None = None, **kwargs): if tensordict is not None: # We must avoid modifying the original tensordict so a shallow copy is necessary. # We just select the input data and reset signal, which is all we need. @@ -1390,7 +1378,7 @@ def __init__( "observation", "pixels", ] - super(ObservationTransform, self).__init__( + super().__init__( in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, @@ -1537,7 +1525,7 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: ) return reward_spec - def __getitem__(self, item: Union[int, slice, List]) -> Union: + def __getitem__(self, item: int | slice | list) -> Union: transform = self.transforms transform = transform[item] if not isinstance(transform, Transform): @@ -1590,7 +1578,7 @@ def append( self.transforms.append(transform) transform.set_container(self) - def set_container(self, container: Union[Transform, EnvBase]) -> None: + def set_container(self, container: Transform | EnvBase) -> None: self.reset_parent() super().set_container(container) for t in self.transforms: @@ -1737,9 +1725,9 @@ class ToTensorImage(ObservationTransform): def __init__( self, - from_int: Optional[bool] = None, + from_int: bool | None = None, unsqueeze: bool = False, - dtype: Optional[torch.device] = None, + dtype: torch.device | None = None, *, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, @@ -2105,7 +2093,7 @@ def _apply_transform( target_return = target_return return target_return else: - raise ValueError("Unknown mode: {}".format(self.mode)) + raise ValueError(f"Unknown mode: {self.mode}") def forward(self, tensordict: TensorDictBase) -> TensorDictBase: raise NotImplementedError( @@ -2707,10 +2695,10 @@ def __init__( self, dim: int | None = None, *args, - in_keys: Optional[Sequence[str]] = None, - out_keys: Optional[Sequence[str]] = None, - in_keys_inv: Optional[Sequence[str]] = None, - out_keys_inv: Optional[Sequence[str]] = None, + in_keys: Sequence[str] | None = None, + out_keys: Sequence[str] | None = None, + in_keys_inv: Sequence[str] | None = None, + out_keys_inv: Sequence[str] | None = None, **kwargs, ): if dim is None: @@ -3009,8 +2997,8 @@ class ObservationNorm(ObservationTransform): def 
__init__( self, - loc: Optional[float, torch.Tensor] = None, - scale: Optional[float, torch.Tensor] = None, + loc: float | torch.Tensor | None = None, + scale: float | torch.Tensor | None = None, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, in_keys_inv: Sequence[NestedKey] | None = None, @@ -3074,10 +3062,10 @@ def initialized(self): def init_stats( self, num_iter: int, - reduce_dim: Union[int, Tuple[int]] = 0, - cat_dim: Optional[int] = None, - key: Optional[NestedKey] = None, - keep_dims: Optional[Tuple[int]] = None, + reduce_dim: int | tuple[int] = 0, + cat_dim: int | None = None, + key: NestedKey | None = None, + keep_dims: tuple[int] | None = None, ) -> None: """Initializes the loc and scale stats of the parent environment. @@ -3410,7 +3398,7 @@ def __init__( def make_rb_transform_and_sampler( self, batch_size: int, **sampler_kwargs - ) -> Tuple[Transform, "torchrl.data.replay_buffers.SliceSampler"]: # noqa: F821 + ) -> tuple[Transform, torchrl.data.replay_buffers.SliceSampler]: # noqa: F821 """Creates a transform and sampler to be used with a replay buffer when storing frame-stacked data. This method helps reduce redundancy in stored data by avoiding the need to @@ -3837,8 +3825,8 @@ class RewardScaling(Transform): def __init__( self, - loc: Union[float, torch.Tensor], - scale: Union[float, torch.Tensor], + loc: float | torch.Tensor, + scale: float | torch.Tensor, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, standard_normal: bool = False, @@ -4483,7 +4471,7 @@ def __init__( else: self._sync_device = _do_nothing - def set_container(self, container: Union[Transform, EnvBase]) -> None: + def set_container(self, container: Transform | EnvBase) -> None: if self.orig_device is None: if isinstance(container, EnvBase): device = container.device @@ -4719,7 +4707,7 @@ def __init__( in_keys = sorted(in_keys, key=_sort_keys) if not isinstance(out_key, (str, tuple)): raise Exception("CatTensors requires out_key to be of type NestedKey") - super(CatTensors, self).__init__(in_keys=in_keys, out_keys=[out_key]) + super().__init__(in_keys=in_keys, out_keys=[out_key]) self.dim = dim self._del_keys = del_keys self._keys_to_exclude = None @@ -5197,7 +5185,7 @@ def __init__( hash_fn: Callable = None, seed: Any | None = None, use_raw_nontensor: bool = False, - repertoire: Tuple[Tuple[int], Any] = None, + repertoire: tuple[tuple[int], Any] = None, ): if hash_fn is None: hash_fn = Hash.reproducible_hash @@ -5329,9 +5317,9 @@ def __init__( in_keys_inv: Sequence[NestedKey] | None = None, out_keys_inv: Sequence[NestedKey] | None = None, *, - tokenizer: "transformers.PretrainedTokenizerBase" = None, # noqa: F821 + tokenizer: transformers.PretrainedTokenizerBase = None, # noqa: F821 use_raw_nontensor: bool = False, - additional_tokens: List[str] | None = None, + additional_tokens: list[str] | None = None, skip_special_tokens: bool = True, add_special_tokens: bool = False, padding: bool = True, @@ -5374,7 +5362,7 @@ def device(self): self._device = device return device - def call_tokenizer_fn(self, value: str | List[str]): + def call_tokenizer_fn(self, value: str | list[str]): device = self.device kwargs = {"add_special_tokens": self.add_special_tokens} if self.max_length is not None: @@ -5585,7 +5573,7 @@ def __init__( elif in_key_inv is not None and out_keys_inv is None: raise ValueError("in_key_inv was specified, but out_keys_inv was not") - super(Stack, self).__init__( + super().__init__( in_keys=in_keys, 
out_keys=[out_key], in_keys_inv=None if in_key_inv is None else [in_key_inv], @@ -5961,15 +5949,30 @@ class TensorDictPrimer(Transform): the TensorSpec domain (or a unit Gaussian if unbounded). Otherwise a fixed value will be assumed. Defaults to `False`. default_value (:obj:`float`, Callable, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): If non-random - filling is chosen, `default_value` will be used to populate the tensors. If `default_value` is a float, - all elements of the tensors will be set to that value. If it is a callable, this callable is expected to - return a tensor fitting the specs, and it will be used to generate the tensors. Finally, if `default_value` - is a dictionary of tensors or a dictionary of callables with keys matching those of the specs, these will - be used to generate the corresponding tensors. Defaults to `0.0`. + filling is chosen, `default_value` will be used to populate the tensors. + + - If `default_value` is a float or any other scalar, all elements of the tensors will be set to that value. + - If it is a callable and `single_default_value=False` (default), this callable is expected to return a tensor + fitting the specs (i.e., ``default_value()`` will be called independently for each leaf spec). + - If it is a callable and ``single_default_value=True``, then the callable will be called just once and the structure + of its returned TensorDict instance (or equivalent) is expected to match the provided specs. + The ``default_value`` must accept an optional `reset` keyword argument indicating which envs are to be reset. + The returned `TensorDict` must have as many elements as the number of envs to reset. + + .. seealso:: :class:`~torchrl.envs.DataLoadingPrimer` + + - Finally, if `default_value` is a dictionary of tensors or a dictionary of callables with keys matching + those of the specs, these will be used to generate the corresponding tensors. Defaults to `0.0`. + reset_key (NestedKey, optional): the reset key to be used as partial reset indicator. Must be unique. If not provided, defaults to the only reset key of the parent environment (if it has only one) and raises an exception otherwise. + single_default_value (bool, optional): if ``True`` and `default_value` is a callable, ``default_value`` is expected + to return a single tensordict matching the specs. If `False`, `default_value()` will be + called independently for each leaf. Defaults to ``False``. + call_before_env_reset (bool, optional): if ``True``, the tensordict is populated before `env.reset` is called. + Defaults to ``False``. **kwargs: each keyword argument corresponds to a key in the tensordict. The corresponding value has to be a TensorSpec instance indicating what the value must be. @@ -6065,18 +6068,20 @@ def __init__( self, random: bool | None = None, default_value: float | Callable - | Dict[NestedKey, float] - | Dict[NestedKey, Callable] = None, + | dict[NestedKey, float] + | dict[NestedKey, Callable] = None, reset_key: NestedKey | None = None, expand_specs: bool = None, + single_default_value: bool = False, + call_before_env_reset: bool = False, **kwargs, ): self.device = kwargs.pop("device", None) if primers is not None: if kwargs: raise RuntimeError( - "providing the primers as a dictionary is incompatible with extra keys provided " - "as kwargs." + f"providing the primers as a dictionary is incompatible with extra keys " + f"'{kwargs.keys()}' provided as kwargs."
) kwargs = primers if not isinstance(kwargs, Composite): @@ -6089,6 +6094,7 @@ def __init__( primers = Composite(kwargs, device=device, shape=shape, **extra_kwargs) self.primers = primers self.expand_specs = expand_specs + self.call_before_env_reset = call_before_env_reset if random and default_value: raise ValueError( @@ -6109,10 +6115,13 @@ def __init__( raise ValueError( "If a default_value dictionary is provided, it must match the primers keys." ) + elif single_default_value: + pass else: default_value = { key: default_value for key in self.primers.keys(True, True) } + self.single_default_value = single_default_value self.default_value = default_value self._validated = False self.reset_key = reset_key @@ -6225,6 +6234,14 @@ def _validate_value_tensor(self, value, spec): return True def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if self.single_default_value and callable(self.default_value): + tensordict.update(self.default_value()) + for key, spec in self.primers.items(True, True): + if not self._validated: + self._validate_value_tensor(tensordict.get(key), spec) + if not self._validated: + self._validated = True + return tensordict for key, spec in self.primers.items(True, True): if spec.shape[: len(tensordict.shape)] != tensordict.shape: raise RuntimeError( @@ -6271,6 +6288,27 @@ def _reset( shape. We allow for execution when the parent is missing, in which case the spec shape is assumed to match the tensordict's. """ + if self.call_before_env_reset: + return tensordict_reset + return self._reset_func(tensordict, tensordict_reset) + + def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase: + if not self.call_before_env_reset: + return tensordict + if tensordict is None: + parent = self.parent + if parent is not None: + device = parent.device + batch_size = parent.batch_size + else: + device = None + batch_size = () + tensordict = TensorDict(device=device, batch_size=batch_size) + return self._reset_func(tensordict, tensordict) + + def _reset_func( + self, tensordict, tensordict_reset: TensorDictBase + ) -> TensorDictBase: _reset = _get_reset(self.reset_key, tensordict) if ( self.parent @@ -6279,6 +6317,23 @@ def _reset( ): self.primers = self._expand_shape(self.primers) if _reset.any(): + if self.single_default_value and callable(self.default_value): + if not _reset.all(): + tensordict_reset = torch.where( + _reset, + self.default_value(reset=_reset), + tensordict_reset[_reset], + ) + else: + resets = self.default_value(reset=_reset) + tensordict_reset.update(resets) + + for key, spec in self.primers.items(True, True): + if not self._validated: + self._validate_value_tensor(tensordict_reset.get(key), spec) + self._validated = True + return tensordict_reset + for key, spec in self.primers.items(True, True): if self.random: shape = ( @@ -6429,11 +6484,11 @@ def __init__( self, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, - shared_td: Optional[TensorDictBase] = None, + shared_td: TensorDictBase | None = None, lock: mp.Lock = None, decay: float = 0.9999, eps: float = 1e-4, - shapes: List[torch.Size] = None, + shapes: list[torch.Size] = None, ) -> None: if lock is None: lock = mp.Lock() @@ -6617,7 +6672,7 @@ def _update(self, key, value, N) -> torch.Tensor: std = (_ssq / _count - mean.pow(2)).clamp_min(self.eps).sqrt() return (value - mean) / std.clamp_min(self.eps) - def to_observation_norm(self) -> Union[Compose, ObservationNorm]: + def to_observation_norm(self) -> Compose | ObservationNorm: 
"""Converts VecNorm into an ObservationNorm class that can be used at inference time. The :class:`~torchrl.envs.ObservationNorm` layer can be updated using the :meth:`~torch.nn.Module.state_dict` @@ -6704,7 +6759,7 @@ def scale(self): @staticmethod def build_td_for_shared_vecnorm( env: EnvBase, - keys: Optional[Sequence[str]] = None, + keys: Sequence[str] | None = None, memmap: bool = False, ) -> TensorDictBase: """Creates a shared tensordict for normalization across processes. @@ -6806,14 +6861,14 @@ def __repr__(self) -> str: f"eps={self.eps:4.4f}, in_keys={self.in_keys}, out_keys={self.out_keys})" ) - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: state = super().__getstate__() _lock = state.pop("lock", None) if _lock is not None: state["lock_placeholder"] = None return state - def __setstate__(self, state: Dict[str, Any]): + def __setstate__(self, state: dict[str, Any]): if "lock_placeholder" in state: state.pop("lock_placeholder") _lock = mp.Lock() @@ -7185,7 +7240,7 @@ class StepCounter(Transform): def __init__( self, - max_steps: Optional[int] = None, + max_steps: int | None = None, truncated_key: str | None = "truncated", step_count_key: str | None = "step_count", update_done: bool = True, @@ -7903,7 +7958,7 @@ def __init__( self, sub_seq_len: int, sample_dim: int = -1, - mask_key: Optional[NestedKey] = None, + mask_key: NestedKey | None = None, ): self.sub_seq_len = sub_seq_len if sample_dim > 0: @@ -8007,7 +8062,7 @@ def __init__(self, init_key: str = "is_init"): self.init_key = init_key super().__init__() - def set_container(self, container: Union[Transform, EnvBase]) -> None: + def set_container(self, container: Transform | EnvBase) -> None: self._init_keys = None return super().set_container(container) @@ -8436,10 +8491,10 @@ class Reward2GoTransform(Transform): def __init__( self, - gamma: Optional[Union[float, torch.Tensor]] = 1.0, + gamma: float | torch.Tensor | None = 1.0, in_keys: Sequence[NestedKey] | None = None, out_keys: Sequence[NestedKey] | None = None, - done_key: Optional[NestedKey] = "done", + done_key: NestedKey | None = "done", ): if in_keys is None: in_keys = [("next", "reward")] @@ -8652,7 +8707,7 @@ def __init__(self, final_name="final"): super().__init__() self._memo = {} - def set_container(self, container: Union[Transform, EnvBase]) -> None: + def set_container(self, container: Transform | EnvBase) -> None: out = super().set_container(container) self._done_keys = None self._obs_keys = None @@ -8748,7 +8803,7 @@ def _reset( return tensordict_reset @property - def done_keys(self) -> List[NestedKey]: + def done_keys(self) -> list[NestedKey]: keys = self.__dict__.get("_done_keys", None) if keys is None: keys = self.parent.done_keys @@ -8765,7 +8820,7 @@ def done_keys(self) -> List[NestedKey]: return keys @property - def obs_keys(self) -> List[NestedKey]: + def obs_keys(self) -> list[NestedKey]: keys = self.__dict__.get("_obs_keys", None) if keys is None: keys = list(self.parent.observation_spec.keys(True, True)) @@ -9537,7 +9592,7 @@ def transform_input_spec(self, input_spec: Composite) -> Composite: class AutoResetEnv(TransformedEnv): """A subclass for auto-resetting envs.""" - def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): + def _reset(self, tensordict: TensorDictBase | None = None, **kwargs): if tensordict is not None: # We must avoid modifying the original tensordict so a shallow copy is necessary. # We just select the input data and reset signal, which is all we need.