@@ -4602,10 +4602,11 @@ class TensorDictPrimer(Transform):
4602
4602
The corresponding value has to be a TensorSpec instance indicating
4603
4603
what the value must be.
4604
4604
4605
- When used in a TransfomedEnv, the spec shapes must match the envs shape if
4606
- the parent env is batch-locked (:obj:`env.batch_locked=True`).
4607
- If the env is not batch-locked (e.g. model-based envs), it is assumed that the batch is
4608
- given by the input tensordict instead.
4605
+ When used in a `TransformedEnv`, the spec shapes must match the environment's shape if
4606
+ the parent environment is batch-locked (`env.batch_locked=True`). If the spec shapes and
4607
+ parent shapes do not match, the spec shapes are modified in-place to match the leading
4608
+ dimensions of the parent's batch size. This adjustment is made for cases where the parent
4609
+ batch size dimension is not known during instantiation.
4609
4610
4610
4611
Examples:
4611
4612
>>> from torchrl.envs.libs.gym import GymEnv
@@ -4645,6 +4646,40 @@ class TensorDictPrimer(Transform):
4645
4646
tensor([[1., 1., 1.],
4646
4647
[1., 1., 1.]])
4647
4648
4649
+ Examples:
4650
+ >>> from torchrl.envs.libs.gym import GymEnv
4651
+ >>> from torchrl.envs import SerialEnv, TransformedEnv
4652
+ >>> from torchrl.modules.utils import get_primers_from_module
4653
+ >>> from torchrl.modules import GRUModule
4654
+ >>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
4655
+ >>> env = TransformedEnv(base_env)
4656
+ >>> model = GRUModule(input_size=2, hidden_size=2, in_key="observation", out_key="action")
4657
+ >>> primers = get_primers_from_module(model)
4658
+ >>> print(primers) # Primers shape is independent of the env batch size
4659
+ TensorDictPrimer(primers=Composite(
4660
+ recurrent_state: UnboundedContinuous(
4661
+ shape=torch.Size([1, 2]),
4662
+ space=ContinuousBox(
4663
+ low=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True),
4664
+ high=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True)),
4665
+ device=cpu,
4666
+ dtype=torch.float32,
4667
+ domain=continuous),
4668
+ device=None,
4669
+ shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
4670
+ >>> env.append_transform(primers)
4671
+ >>> print(env.reset()) # The primers are automatically expanded to match the env batch size
4672
+ TensorDict(
4673
+ fields={
4674
+ done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
4675
+ observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
4676
+ recurrent_state: Tensor(shape=torch.Size([2, 1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
4677
+ terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
4678
+ truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
4679
+ batch_size=torch.Size([2]),
4680
+ device=None,
4681
+ is_shared=False)
4682
+
4648
4683
.. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts,
4649
4684
like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
4650
4685
To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
@@ -4770,7 +4805,7 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
4770
4805
# We try to set the primer shape to the observation spec shape
4771
4806
self.primers.shape = observation_spec.shape
4772
4807
except ValueError:
4773
- # If we fail, we expnad them to that shape
4808
+ # If we fail, we expand them to that shape
4774
4809
self.primers = self._expand_shape(self.primers)
4775
4810
device = observation_spec.device
4776
4811
observation_spec.update(self.primers.clone().to(device))
@@ -4837,12 +4872,17 @@ def _reset(
4837
4872
) -> TensorDictBase:
4838
4873
"""Sets the default values in the input tensordict.
4839
4874
4840
- If the parent is batch-locked, we assume that the specs have the appropriate leading
4875
+ If the parent is batch-locked, we make sure the specs have the appropriate leading
4841
4876
shape. We allow for execution when the parent is missing, in which case the
4842
4877
spec shape is assumed to match the tensordict's.
4843
-
4844
4878
"""
4845
4879
_reset = _get_reset(self.reset_key, tensordict)
4880
+ if (
4881
+ self.parent
4882
+ and self.parent.batch_locked
4883
+ and self.primers.shape[: len(self.parent.shape)] != self.parent.batch_size
4884
+ ):
4885
+ self.primers = self._expand_shape(self.primers)
4846
4886
if _reset.any():
4847
4887
for key, spec in self.primers.items(True, True):
4848
4888
if self.random:
0 commit comments