
Commit 83a7a57

albertbou92 and Vincent Moens authored
[BugFix] Allow expanding TensorDictPrimer transforms shape with parent batch size (#2552)
Co-authored-by: Vincent Moens <vmoens@meta.com>
1 parent 527a26a commit 83a7a57
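
In short, this commit lets TensorDictPrimer specs that were created without a batch dimension (as produced by get_primers_from_module) be expanded in place to the parent environment's batch size at reset time. A minimal sketch of the scenario, assembled from the docstring example added below (the module hyperparameters are illustrative):

    import torch
    from torchrl.envs import SerialEnv, TransformedEnv
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.modules import GRUModule
    from torchrl.modules.utils import get_primers_from_module

    # A recurrent module declares its hidden-state primer without a batch dim:
    # the spec is shaped [num_layers, hidden_size] = [1, 2] here.
    model = GRUModule(input_size=2, hidden_size=2, in_key="observation", out_key="action")
    primers = get_primers_from_module(model)

    # The parent env is batch-locked with batch_size [2]; before this fix the
    # shape mismatch was an error, now reset() expands the primer specs.
    env = TransformedEnv(SerialEnv(2, lambda: GymEnv("Pendulum-v1")))
    env.append_transform(primers)
    td = env.reset()
    assert td["recurrent_state"].shape == torch.Size([2, 1, 2])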

File tree

2 files changed (+75, -8 lines)


test/test_transforms.py

Lines changed: 28 additions & 1 deletion
@@ -159,6 +159,7 @@
 from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
 from torchrl.envs.utils import check_env_specs, step_mdp
 from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal
+from torchrl.modules.utils import get_primers_from_module

 IS_WIN = platform == "win32"
 if IS_WIN:
@@ -7088,7 +7089,7 @@ def test_tensordictprimer_batching(self, batched_class, break_when_any_done):

         env = TransformedEnv(
             batched_class(2, lambda: GymEnv(CARTPOLE_VERSIONED())),
-            TensorDictPrimer(mykey=Unbounded([2, 4])),
+            TensorDictPrimer(Composite({"mykey": Unbounded([2, 4])}, shape=[2])),
         )
         torch.manual_seed(0)
         env.set_seed(0)
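
The one-line change above switches the test to declare the leading batch dimension explicitly on the Composite that wraps the spec, instead of passing a keyword spec with no batch shape. A short sketch of the two constructions, assuming the Composite/Unbounded classes exported by torchrl.data:

    from torchrl.data import Composite, Unbounded
    from torchrl.envs import TensorDictPrimer

    # Keyword form: the internal Composite has an empty batch shape, so matching
    # a batch-locked parent of batch_size [2] relies on in-place correction.
    primer_implicit = TensorDictPrimer(mykey=Unbounded([2, 4]))

    # Explicit form: the leading dim shared with the parent env is declared up front.
    primer_explicit = TensorDictPrimer(Composite({"mykey": Unbounded([2, 4])}, shape=[2]))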
@@ -7170,6 +7171,32 @@ def test_dict_default_value(self):
             rollout_td.get(("next", "mykey2")) == torch.tensor(1, dtype=torch.int64)
         ).all

+    @pytest.mark.skipif(not _has_gym, reason="GYM not found")
+    def test_spec_shape_inplace_correction(self):
+        hidden_size = input_size = num_layers = 2
+        model = GRUModule(
+            input_size, hidden_size, num_layers, in_key="observation", out_key="action"
+        )
+        env = TransformedEnv(
+            SerialEnv(2, lambda: GymEnv(PENDULUM_VERSIONED())),
+        )
+        # These primers do not have the leading batch dimension,
+        # since the model is agnostic to the batch dimension that will be used.
+        primers = get_primers_from_module(model)
+        for primer in primers.primers:
+            assert primers.primers.get(primer).shape == torch.Size(
+                [num_layers, hidden_size]
+            )
+        env.append_transform(primers)
+        # Reset should add the batch dimension to the primers,
+        # since the parent exists and is batch_locked.
+        td = env.reset()
+        for primer in primers.primers:
+            assert primers.primers.get(primer).shape == torch.Size(
+                [2, num_layers, hidden_size]
+            )
+            assert td.get(primer).shape == torch.Size([2, num_layers, hidden_size])
+

 class TestTimeMaxPool(TransformBase):
     @pytest.mark.parametrize("T", [2, 4])

torchrl/envs/transforms/transforms.py

Lines changed: 47 additions & 7 deletions
@@ -4602,10 +4602,11 @@ class TensorDictPrimer(Transform):
             The corresponding value has to be a TensorSpec instance indicating
             what the value must be.

-    When used in a TransfomedEnv, the spec shapes must match the envs shape if
-    the parent env is batch-locked (:obj:`env.batch_locked=True`).
-    If the env is not batch-locked (e.g. model-based envs), it is assumed that the batch is
-    given by the input tensordict instead.
+    When used in a `TransformedEnv`, the spec shapes must match the environment's shape if
+    the parent environment is batch-locked (`env.batch_locked=True`). If the spec shapes and
+    parent shapes do not match, the spec shapes are modified in-place to match the leading
+    dimensions of the parent's batch size. This adjustment is made for cases where the parent
+    batch size dimension is not known during instantiation.

     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
@@ -4645,6 +4646,40 @@ class TensorDictPrimer(Transform):
         tensor([[1., 1., 1.],
                 [1., 1., 1.]])

+    Examples:
+        >>> from torchrl.envs.libs.gym import GymEnv
+        >>> from torchrl.envs import SerialEnv, TransformedEnv
+        >>> from torchrl.modules.utils import get_primers_from_module
+        >>> from torchrl.modules import GRUModule
+        >>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
+        >>> env = TransformedEnv(base_env)
+        >>> model = GRUModule(input_size=2, hidden_size=2, in_key="observation", out_key="action")
+        >>> primers = get_primers_from_module(model)
+        >>> print(primers) # Primers shape is independent of the env batch size
+        TensorDictPrimer(primers=Composite(
+            recurrent_state: UnboundedContinuous(
+                shape=torch.Size([1, 2]),
+                space=ContinuousBox(
+                    low=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True),
+                    high=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True)),
+                device=cpu,
+                dtype=torch.float32,
+                domain=continuous),
+            device=None,
+            shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
+        >>> env.append_transform(primers)
+        >>> print(env.reset()) # The primers are automatically expanded to match the env batch size
+        TensorDict(
+            fields={
+                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                recurrent_state: Tensor(shape=torch.Size([2, 1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+            batch_size=torch.Size([2]),
+            device=None,
+            is_shared=False)
+
     .. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts,
         like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
         To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
@@ -4770,7 +4805,7 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
             # We try to set the primer shape to the observation spec shape
             self.primers.shape = observation_spec.shape
         except ValueError:
-            # If we fail, we expnad them to that shape
+            # If we fail, we expand them to that shape
             self.primers = self._expand_shape(self.primers)
         device = observation_spec.device
         observation_spec.update(self.primers.clone().to(device))
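
For context, the _expand_shape fallback amounts to broadcasting each primer spec so it carries the target leading dimensions. A rough sketch of the idea, assuming the expand() method that TorchRL specs expose (the actual helper may differ):

    def expand_to_parent(primers, parent_batch_size):
        # Hypothetical stand-in for TensorDictPrimer._expand_shape: prepend the
        # parent's batch dims to the Composite holding the primer specs.
        return primers.expand(*parent_batch_size, *primers.shape)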
@@ -4837,12 +4872,17 @@ def _reset(
     ) -> TensorDictBase:
         """Sets the default values in the input tensordict.

-        If the parent is batch-locked, we assume that the specs have the appropriate leading
+        If the parent is batch-locked, we make sure the specs have the appropriate leading
         shape. We allow for execution when the parent is missing, in which case the
         spec shape is assumed to match the tensordict's.
-
         """
         _reset = _get_reset(self.reset_key, tensordict)
+        if (
+            self.parent
+            and self.parent.batch_locked
+            and self.primers.shape[: len(self.parent.shape)] != self.parent.batch_size
+        ):
+            self.primers = self._expand_shape(self.primers)
         if _reset.any():
             for key, spec in self.primers.items(True, True):
                 if self.random:
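
Note that the new guard only compares the leading dimensions of the primer specs against the parent's batch size, so specs that already carry the batch dims pass through untouched. A small illustration of the comparison with hypothetical shapes:

    import torch

    parent_batch_size = torch.Size([2])   # e.g. SerialEnv(2, ...)
    primer_shape = torch.Size([1, 2])     # primer created without batch dims

    # Leading dims differ from the parent batch size -> expand in place.
    assert primer_shape[: len(parent_batch_size)] != parent_batch_size

    expanded_shape = torch.Size([2, 1, 2])  # shape after _expand_shape
    assert expanded_shape[: len(parent_batch_size)] == parent_batch_size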
