@@ -4596,11 +4596,10 @@ class TensorDictPrimer(Transform):
4596
4596
The corresponding value has to be a TensorSpec instance indicating
4597
4597
what the value must be.
4598
4598
4599
- When used in a `TransformedEnv`, the spec shapes must match the environment's shape if
4600
- the parent environment is batch-locked (`env.batch_locked=True`). If the spec shapes and
4601
- parent shapes do not match, the spec shapes are modified in-place to match the leading
4602
- dimensions of the parent's batch size. This adjustment is made for cases where the parent
4603
- batch size dimension is not known during instantiation.
4599
+ When used in a TransformedEnv, the spec shapes must match the env's shape if
4600
+ the parent env is batch-locked (:obj:`env.batch_locked=True`).
4601
+ If the env is not batch-locked (e.g. model-based envs), it is assumed that the batch is
4602
+ given by the input tensordict instead.
4604
4603
4605
4604
Examples:
4606
4605
>>> from torchrl.envs.libs.gym import GymEnv
@@ -4640,40 +4639,6 @@ class TensorDictPrimer(Transform):
4640
4639
tensor([[1., 1., 1.],
4641
4640
[1., 1., 1.]])
4642
4641
4643
- Examples:
4644
- >>> from torchrl.envs.libs.gym import GymEnv
4645
- >>> from torchrl.envs import SerialEnv, TransformedEnv
4646
- >>> from torchrl.modules.utils import get_primers_from_module
4647
- >>> from torchrl.modules import GRUModule
4648
- >>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
4649
- >>> env = TransformedEnv(base_env)
4650
- >>> model = GRUModule(input_size=2, hidden_size=2, in_key="observation", out_key="action")
4651
- >>> primers = get_primers_from_module(model)
4652
- >>> print(primers) # Primers shape is independent of the env batch size
4653
- TensorDictPrimer(primers=Composite(
4654
- recurrent_state: UnboundedContinuous(
4655
- shape=torch.Size([1, 2]),
4656
- space=ContinuousBox(
4657
- low=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True),
4658
- high=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True)),
4659
- device=cpu,
4660
- dtype=torch.float32,
4661
- domain=continuous),
4662
- device=None,
4663
- shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
4664
- >>> env.append_transform(primers)
4665
- >>> print(env.reset()) # The primers are automatically expanded to match the env batch size
4666
- TensorDict(
4667
- fields={
4668
- done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
4669
- observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
4670
- recurrent_state: Tensor(shape=torch.Size([2, 1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
4671
- terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
4672
- truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
4673
- batch_size=torch.Size([2]),
4674
- device=None,
4675
- is_shared=False)
4676
-
4677
4642
.. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts,
4678
4643
like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
4679
4644
To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
@@ -4799,7 +4764,7 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
4799
4764
# We try to set the primer shape to the observation spec shape
4800
4765
self .primers .shape = observation_spec .shape
4801
4766
except ValueError :
4802
- # If we fail, we expand them to that shape
4767
+ # If we fail, we expand them to that shape
4803
4768
self .primers = self ._expand_shape (self .primers )
4804
4769
device = observation_spec .device
4805
4770
observation_spec .update (self .primers .clone ().to (device ))
@@ -4866,17 +4831,12 @@ def _reset(
4866
4831
) -> TensorDictBase :
4867
4832
"""Sets the default values in the input tensordict.
4868
4833
4869
- If the parent is batch-locked, we make sure the specs have the appropriate leading
4834
+ If the parent is batch-locked, we assume that the specs have the appropriate leading
4870
4835
shape. We allow for execution when the parent is missing, in which case the
4871
4836
spec shape is assumed to match the tensordict's.
4837
+
4872
4838
"""
4873
4839
_reset = _get_reset (self .reset_key , tensordict )
4874
- if (
4875
- self .parent
4876
- and self .parent .batch_locked
4877
- and self .primers .shape [: len (self .parent .shape )] != self .parent .batch_size
4878
- ):
4879
- self .primers = self ._expand_shape (self .primers )
4880
4840
if _reset .any ():
4881
4841
for key , spec in self .primers .items (True , True ):
4882
4842
if self .random :
0 commit comments