@@ -5618,14 +5618,20 @@ class TensorDictPrimer(Transform):
5618
5618
Defaults to `False`.
5619
5619
default_value (float, Callable, Dict[NestedKey, float], Dict[NestedKey, Callable], optional): If non-random
5620
5620
filling is chosen, `default_value` will be used to populate the tensors. If `default_value` is a float,
5621
- all elements of the tensors will be set to that value. If it is a callable, this callable is expected to
5622
- return a tensor fitting the specs, and it will be used to generate the tensors. Finally, if `default_value`
5623
- is a dictionary of tensors or a dictionary of callables with keys matching those of the specs, these will
5624
- be used to generate the corresponding tensors. Defaults to `0.0`.
5621
+ all elements of the tensors will be set to that value.
5622
+ If it is a callable and `single_default_value=False` (default), this callable is expected to return a tensor
5623
+ fitting the specs (ie, ``default_value()`` will be called independently for each leaf spec). If it is a
5624
+ callable and ``single_default_value=True``, then the callable will be called just once and it is expected
5625
+ that the structure of its returned TensorDict instance or equivalent will match the provided specs.
5626
+ Finally, if `default_value` is a dictionary of tensors or a dictionary of callables with keys matching
5627
+ those of the specs, these will be used to generate the corresponding tensors. Defaults to `0.0`.
5625
5628
reset_key (NestedKey, optional): the reset key to be used as partial
5626
5629
reset indicator. Must be unique. If not provided, defaults to the
5627
5630
only reset key of the parent environment (if it has only one)
5628
5631
and raises an exception otherwise.
5632
+ single_default_value (bool, optional): if ``True`` and `default_value` is a callable, it will be expected that
5633
+ ``default_value`` returns a single tensordict matching the specs. If `False`, `default_value()` will be
5634
+ called independently for each leaf. Defaults to ``False``.
5629
5635
**kwargs: each keyword argument corresponds to a key in the tensordict.
5630
5636
The corresponding value has to be a TensorSpec instance indicating
5631
5637
what the value must be.
@@ -5725,6 +5731,7 @@ def __init__(
5725
5731
| Dict [NestedKey , Callable ] = None ,
5726
5732
reset_key : NestedKey | None = None ,
5727
5733
expand_specs : bool = None ,
5734
+ single_default_value : bool = False ,
5728
5735
** kwargs ,
5729
5736
):
5730
5737
self .device = kwargs .pop ("device" , None )
@@ -5765,10 +5772,13 @@ def __init__(
5765
5772
raise ValueError (
5766
5773
"If a default_value dictionary is provided, it must match the primers keys."
5767
5774
)
5775
+ elif single_default_value :
5776
+ pass
5768
5777
else :
5769
5778
default_value = {
5770
5779
key : default_value for key in self .primers .keys (True , True )
5771
5780
}
5781
+ self .single_default_value = single_default_value
5772
5782
self .default_value = default_value
5773
5783
self ._validated = False
5774
5784
self .reset_key = reset_key
@@ -5881,6 +5891,14 @@ def _validate_value_tensor(self, value, spec):
5881
5891
return True
5882
5892
5883
5893
def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
5894
+ if self .single_default_value and callable (self .default_value ):
5895
+ tensordict .update (self .default_value ())
5896
+ for key , spec in self .primers .items (True , True ):
5897
+ if not self ._validated :
5898
+ self ._validate_value_tensor (tensordict .get (key ), spec )
5899
+ if not self ._validated :
5900
+ self ._validated = True
5901
+ return tensordict
5884
5902
for key , spec in self .primers .items (True , True ):
5885
5903
if spec .shape [: len (tensordict .shape )] != tensordict .shape :
5886
5904
raise RuntimeError (
@@ -5935,6 +5953,14 @@ def _reset(
5935
5953
):
5936
5954
self .primers = self ._expand_shape (self .primers )
5937
5955
if _reset .any ():
5956
+ if self .single_default_value and callable (self .default_value ):
5957
+ tensordict_reset .update (self .default_value ())
5958
+ for key , spec in self .primers .items (True , True ):
5959
+ if not self ._validated :
5960
+ self ._validate_value_tensor (tensordict_reset .get (key ), spec )
5961
+ self ._validated = True
5962
+ return tensordict_reset
5963
+
5938
5964
for key , spec in self .primers .items (True , True ):
5939
5965
if self .random :
5940
5966
shape = (
0 commit comments