From 85d9d2f8001aa78f7adb7da4ac74959a31939bfc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 29 Jan 2025 17:37:35 -0800 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- torchrl/envs/custom/pendulum.py | 17 ++++++++++---- torchrl/envs/transforms/transforms.py | 34 +++++++++++++++++++++++---- 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/torchrl/envs/custom/pendulum.py b/torchrl/envs/custom/pendulum.py index 579faecc3c6..b530a01418e 100644 --- a/torchrl/envs/custom/pendulum.py +++ b/torchrl/envs/custom/pendulum.py @@ -269,11 +269,20 @@ def _reset(self, tensordict): batch_size = ( tensordict.batch_size if tensordict is not None else self.batch_size ) - if tensordict is None or tensordict.is_empty(): + if tensordict is None or "params" not in tensordict: # if no ``tensordict`` is passed, we generate a single set of hyperparameters # Otherwise, we assume that the input ``tensordict`` contains all the relevant # parameters to get started. tensordict = self.gen_params(batch_size=batch_size, device=self.device) + elif "th" in tensordict and "thdot" in tensordict: + # we can hard-reset the env too + return tensordict + out = self._reset_random_data( + tensordict.shape, batch_size, tensordict["params"] + ) + return out + + def _reset_random_data(self, shape, batch_size, params): high_th = torch.tensor(self.DEFAULT_X, device=self.device) high_thdot = torch.tensor(self.DEFAULT_Y, device=self.device) @@ -284,12 +293,12 @@ def _reset(self, tensordict): # of simulators run simultaneously. In other contexts, the initial # random state's shape will depend upon the environment batch-size instead. th = ( - torch.rand(tensordict.shape, generator=self.rng, device=self.device) + torch.rand(shape, generator=self.rng, device=self.device) * (high_th - low_th) + low_th ) thdot = ( - torch.rand(tensordict.shape, generator=self.rng, device=self.device) + torch.rand(shape, generator=self.rng, device=self.device) * (high_thdot - low_thdot) + low_thdot ) @@ -297,7 +306,7 @@ def _reset(self, tensordict): { "th": th, "thdot": thdot, - "params": tensordict["params"], + "params": params, }, batch_size=batch_size, ) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 65eda4bc6ec..3cfcfd0b589 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -5618,14 +5618,20 @@ class TensorDictPrimer(Transform): Defaults to `False`. default_value (float, Callable, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): If non-random filling is chosen, `default_value` will be used to populate the tensors. If `default_value` is a float, - all elements of the tensors will be set to that value. If it is a callable, this callable is expected to - return a tensor fitting the specs, and it will be used to generate the tensors. Finally, if `default_value` - is a dictionary of tensors or a dictionary of callables with keys matching those of the specs, these will - be used to generate the corresponding tensors. Defaults to `0.0`. + all elements of the tensors will be set to that value. + If it is a callable and `single_default_value=False` (default), this callable is expected to return a tensor + fitting the specs (ie, ``default_value()`` will be called independently for each leaf spec). If it is a + callable and ``single_default_value=True``, then the callable will be called just once and it is expected + that the structure of its returned TensorDict instance or equivalent will match the provided specs. + Finally, if `default_value` is a dictionary of tensors or a dictionary of callables with keys matching + those of the specs, these will be used to generate the corresponding tensors. Defaults to `0.0`. reset_key (NestedKey, optional): the reset key to be used as partial reset indicator. Must be unique. If not provided, defaults to the only reset key of the parent environment (if it has only one) and raises an exception otherwise. + single_default_value (bool, optional): if ``True`` and `default_value` is a callable, it will be expected that + ``default_value`` returns a single tensordict matching the specs. If `False`, `default_value()` will be + called independently for each leaf. Defaults to ``False``. **kwargs: each keyword argument corresponds to a key in the tensordict. The corresponding value has to be a TensorSpec instance indicating what the value must be. @@ -5725,6 +5731,7 @@ def __init__( | Dict[NestedKey, Callable] = None, reset_key: NestedKey | None = None, expand_specs: bool = None, + single_default_value: bool = False, **kwargs, ): self.device = kwargs.pop("device", None) @@ -5765,10 +5772,13 @@ def __init__( raise ValueError( "If a default_value dictionary is provided, it must match the primers keys." ) + elif single_default_value: + pass else: default_value = { key: default_value for key in self.primers.keys(True, True) } + self.single_default_value = single_default_value self.default_value = default_value self._validated = False self.reset_key = reset_key @@ -5881,6 +5891,14 @@ def _validate_value_tensor(self, value, spec): return True def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + if self.single_default_value and callable(self.default_value): + tensordict.update(self.default_value()) + for key, spec in self.primers.items(True, True): + if not self._validated: + self._validate_value_tensor(tensordict.get(key), spec) + if not self._validated: + self._validated = True + return tensordict for key, spec in self.primers.items(True, True): if spec.shape[: len(tensordict.shape)] != tensordict.shape: raise RuntimeError( @@ -5935,6 +5953,14 @@ def _reset( ): self.primers = self._expand_shape(self.primers) if _reset.any(): + if self.single_default_value and callable(self.default_value): + tensordict_reset.update(self.default_value()) + for key, spec in self.primers.items(True, True): + if not self._validated: + self._validate_value_tensor(tensordict_reset.get(key), spec) + self._validated = True + return tensordict_reset + for key, spec in self.primers.items(True, True): if self.random: shape = ( From ee3d70b4d55c3c78a9648cf71e196382f9bdbfa3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 28 Feb 2025 15:41:05 +0000 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- torchrl/envs/transforms/rlhf.py | 2 ++ torchrl/envs/transforms/transforms.py | 4 ---- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index 0b5959bb900..feae60a1c59 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + from copy import copy, deepcopy import torch diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 8a823eab931..6eed2fd9318 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -19,13 +19,9 @@ from typing import ( Any, Callable, - Dict, - List, Mapping, - Optional, OrderedDict, Sequence, - Tuple, TypeVar, Union, )