Skip to content

Commit 152f282

Browse files
author
Vincent Moens
committed
[Feature] TensorDictPrimer with single default_value callable
ghstack-source-id: 172825a Pull Request resolved: #2732
1 parent 280297a commit 152f282

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

torchrl/envs/custom/pendulum.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,20 @@ def _reset(self, tensordict):
269269
batch_size = (
270270
tensordict.batch_size if tensordict is not None else self.batch_size
271271
)
272-
if tensordict is None or tensordict.is_empty():
272+
if tensordict is None or "params" not in tensordict:
273273
# if no ``tensordict`` is passed (or it carries no "params" entry), we generate a single set of hyperparameters
274274
# Otherwise, we assume that the input ``tensordict`` contains all the relevant
275275
# parameters to get started.
276276
tensordict = self.gen_params(batch_size=batch_size, device=self.device)
277+
elif "th" in tensordict and "thdot" in tensordict:
278+
# we can hard-reset the env too
279+
return tensordict
280+
out = self._reset_random_data(
281+
tensordict.shape, batch_size, tensordict["params"]
282+
)
283+
return out
284+
285+
def _reset_random_data(self, shape, batch_size, params):
277286

278287
high_th = torch.tensor(self.DEFAULT_X, device=self.device)
279288
high_thdot = torch.tensor(self.DEFAULT_Y, device=self.device)
@@ -284,20 +293,20 @@ def _reset(self, tensordict):
284293
# of simulators run simultaneously. In other contexts, the initial
285294
# random state's shape will depend upon the environment batch-size instead.
286295
th = (
287-
torch.rand(tensordict.shape, generator=self.rng, device=self.device)
296+
torch.rand(shape, generator=self.rng, device=self.device)
288297
* (high_th - low_th)
289298
+ low_th
290299
)
291300
thdot = (
292-
torch.rand(tensordict.shape, generator=self.rng, device=self.device)
301+
torch.rand(shape, generator=self.rng, device=self.device)
293302
* (high_thdot - low_thdot)
294303
+ low_thdot
295304
)
296305
out = TensorDict(
297306
{
298307
"th": th,
299308
"thdot": thdot,
300-
"params": tensordict["params"],
309+
"params": params,
301310
},
302311
batch_size=batch_size,
303312
)

torchrl/envs/transforms/transforms.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5618,14 +5618,20 @@ class TensorDictPrimer(Transform):
56185618
Defaults to `False`.
56195619
default_value (float, Callable, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): If non-random
56205620
filling is chosen, `default_value` will be used to populate the tensors. If `default_value` is a float,
5621-
all elements of the tensors will be set to that value. If it is a callable, this callable is expected to
5622-
return a tensor fitting the specs, and it will be used to generate the tensors. Finally, if `default_value`
5623-
is a dictionary of tensors or a dictionary of callables with keys matching those of the specs, these will
5624-
be used to generate the corresponding tensors. Defaults to `0.0`.
5621+
all elements of the tensors will be set to that value.
5622+
If it is a callable and `single_default_value=False` (default), this callable is expected to return a tensor
5623+
fitting the specs (i.e., ``default_value()`` will be called independently for each leaf spec). If it is a
5624+
callable and ``single_default_value=True``, then the callable will be called just once and it is expected
5625+
that the structure of its returned TensorDict instance or equivalent will match the provided specs.
5626+
Finally, if `default_value` is a dictionary of tensors or a dictionary of callables with keys matching
5627+
those of the specs, these will be used to generate the corresponding tensors. Defaults to `0.0`.
56255628
reset_key (NestedKey, optional): the reset key to be used as partial
56265629
reset indicator. Must be unique. If not provided, defaults to the
56275630
only reset key of the parent environment (if it has only one)
56285631
and raises an exception otherwise.
5632+
single_default_value (bool, optional): if ``True`` and ``default_value`` is a callable, it will be expected that
5633+
``default_value`` returns a single tensordict matching the specs. If ``False``, ``default_value()`` will be
5634+
called independently for each leaf. Defaults to ``False``.
56295635
**kwargs: each keyword argument corresponds to a key in the tensordict.
56305636
The corresponding value has to be a TensorSpec instance indicating
56315637
what the value must be.
@@ -5725,6 +5731,7 @@ def __init__(
57255731
| Dict[NestedKey, Callable] = None,
57265732
reset_key: NestedKey | None = None,
57275733
expand_specs: bool = None,
5734+
single_default_value: bool = False,
57285735
**kwargs,
57295736
):
57305737
self.device = kwargs.pop("device", None)
@@ -5765,10 +5772,13 @@ def __init__(
57655772
raise ValueError(
57665773
"If a default_value dictionary is provided, it must match the primers keys."
57675774
)
5775+
elif single_default_value:
5776+
pass
57685777
else:
57695778
default_value = {
57705779
key: default_value for key in self.primers.keys(True, True)
57715780
}
5781+
self.single_default_value = single_default_value
57725782
self.default_value = default_value
57735783
self._validated = False
57745784
self.reset_key = reset_key
@@ -5881,6 +5891,14 @@ def _validate_value_tensor(self, value, spec):
58815891
return True
58825892

58835893
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
5894+
if self.single_default_value and callable(self.default_value):
5895+
tensordict.update(self.default_value())
5896+
for key, spec in self.primers.items(True, True):
5897+
if not self._validated:
5898+
self._validate_value_tensor(tensordict.get(key), spec)
5899+
if not self._validated:
5900+
self._validated = True
5901+
return tensordict
58845902
for key, spec in self.primers.items(True, True):
58855903
if spec.shape[: len(tensordict.shape)] != tensordict.shape:
58865904
raise RuntimeError(
@@ -5935,6 +5953,14 @@ def _reset(
59355953
):
59365954
self.primers = self._expand_shape(self.primers)
59375955
if _reset.any():
5956+
if self.single_default_value and callable(self.default_value):
5957+
tensordict_reset.update(self.default_value())
5958+
for key, spec in self.primers.items(True, True):
5959+
if not self._validated:
5960+
self._validate_value_tensor(tensordict_reset.get(key), spec)
5961+
self._validated = True
5962+
return tensordict_reset
5963+
59385964
for key, spec in self.primers.items(True, True):
59395965
if self.random:
59405966
shape = (

0 commit comments

Comments
 (0)