
Commit 8a8b4c3

Author: Vincent Moens
Revert "[BugFix] Allow expanding TensorDictPrimer transforms shape with parent batch size" (#2544)
1 parent e9d1677 commit 8a8b4c3

2 files changed: 7 additions & 75 deletions

test/test_transforms.py

Lines changed: 0 additions & 28 deletions
@@ -159,7 +159,6 @@
 from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
 from torchrl.envs.utils import check_env_specs, step_mdp
 from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal
-from torchrl.modules.utils import get_primers_from_module
 
 IS_WIN = platform == "win32"
 if IS_WIN:
@@ -7164,33 +7163,6 @@ def test_dict_default_value(self):
             rollout_td.get(("next", "mykey2")) == torch.tensor(1, dtype=torch.int64)
         ).all
 
-    def test_spec_shape_inplace_correction(self):
-        hidden_size = input_size = num_layers = 2
-        model = GRUModule(
-            input_size, hidden_size, num_layers, in_key="observation", out_key="action"
-        )
-        env = TransformedEnv(
-            SerialEnv(2, lambda: GymEnv("Pendulum-v1")),
-        )
-        # These primers do not have the leading batch dimension
-        # since model is agnostic to batch dimension that will be used.
-        primers = get_primers_from_module(model)
-        for primer in primers.primers:
-            assert primers.primers.get(primer).shape == torch.Size(
-                [num_layers, hidden_size]
-            )
-        env.append_transform(primers)
-
-        # Reset should add the batch dimension to the primers
-        # since the parent exists and is batch_locked.
-        td = env.reset()
-
-        for primer in primers.primers:
-            assert primers.primers.get(primer).shape == torch.Size(
-                [2, num_layers, hidden_size]
-            )
-            assert td.get(primer).shape == torch.Size([2, num_layers, hidden_size])
-
 
 class TestTimeMaxPool(TransformBase):
     @pytest.mark.parametrize("T", [2, 4])
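The deleted test covered exactly the in-place expansion this commit reverts. What still holds after the revert is the module-level shape of the primers; below is a minimal sketch of that invariant, adapted from the removed test (illustrative only, not part of the commit):

    import torch
    from torchrl.modules import GRUModule
    from torchrl.modules.utils import get_primers_from_module

    hidden_size = input_size = num_layers = 2
    model = GRUModule(
        input_size, hidden_size, num_layers, in_key="observation", out_key="action"
    )
    # Primers derived from the module carry no leading batch dimension;
    # after this revert, reset() no longer expands them in place.
    primers = get_primers_from_module(model)
    for primer in primers.primers:
        assert primers.primers.get(primer).shape == torch.Size([num_layers, hidden_size])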

torchrl/envs/transforms/transforms.py

Lines changed: 7 additions & 47 deletions
@@ -4596,11 +4596,10 @@ class TensorDictPrimer(Transform):
     The corresponding value has to be a TensorSpec instance indicating
     what the value must be.
 
-    When used in a `TransformedEnv`, the spec shapes must match the environment's shape if
-    the parent environment is batch-locked (`env.batch_locked=True`). If the spec shapes and
-    parent shapes do not match, the spec shapes are modified in-place to match the leading
-    dimensions of the parent's batch size. This adjustment is made for cases where the parent
-    batch size dimension is not known during instantiation.
+    When used in a TransfomedEnv, the spec shapes must match the envs shape if
+    the parent env is batch-locked (:obj:`env.batch_locked=True`).
+    If the env is not batch-locked (e.g. model-based envs), it is assumed that the batch is
+    given by the input tensordict instead.
 
     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
@@ -4640,40 +4639,6 @@ class TensorDictPrimer(Transform):
         tensor([[1., 1., 1.],
                 [1., 1., 1.]])
 
-    Examples:
-        >>> from torchrl.envs.libs.gym import GymEnv
-        >>> from torchrl.envs import SerialEnv, TransformedEnv
-        >>> from torchrl.modules.utils import get_primers_from_module
-        >>> from torchrl.modules import GRUModule
-        >>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
-        >>> env = TransformedEnv(base_env)
-        >>> model = GRUModule(input_size=2, hidden_size=2, in_key="observation", out_key="action")
-        >>> primers = get_primers_from_module(model)
-        >>> print(primers) # Primers shape is independent of the env batch size
-        TensorDictPrimer(primers=Composite(
-            recurrent_state: UnboundedContinuous(
-                shape=torch.Size([1, 2]),
-                space=ContinuousBox(
-                    low=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True),
-                    high=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True)),
-                device=cpu,
-                dtype=torch.float32,
-                domain=continuous),
-            device=None,
-            shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
-        >>> env.append_transform(primers)
-        >>> print(env.reset()) # The primers are automatically expanded to match the env batch size
-        TensorDict(
-            fields={
-                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
-                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
-                recurrent_state: Tensor(shape=torch.Size([2, 1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
-                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
-                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
-            batch_size=torch.Size([2]),
-            device=None,
-            is_shared=False)
-
     .. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts,
         like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
         To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
@@ -4799,7 +4764,7 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
             # We try to set the primer shape to the observation spec shape
             self.primers.shape = observation_spec.shape
         except ValueError:
-            # If we fail, we expand them to that shape
+            # If we fail, we expnad them to that shape
            self.primers = self._expand_shape(self.primers)
         device = observation_spec.device
         observation_spec.update(self.primers.clone().to(device))
@@ -4866,17 +4831,12 @@ def _reset(
     ) -> TensorDictBase:
         """Sets the default values in the input tensordict.
 
-        If the parent is batch-locked, we make sure the specs have the appropriate leading
+        If the parent is batch-locked, we assume that the specs have the appropriate leading
         shape. We allow for execution when the parent is missing, in which case the
         spec shape is assumed to match the tensordict's.
+
         """
         _reset = _get_reset(self.reset_key, tensordict)
-        if (
-            self.parent
-            and self.parent.batch_locked
-            and self.primers.shape[: len(self.parent.shape)] != self.parent.batch_size
-        ):
-            self.primers = self._expand_shape(self.primers)
         if _reset.any():
             for key, spec in self.primers.items(True, True):
                 if self.random:
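With the in-place correction removed from `_reset`, the contract is again the one the updated docstring states: for a batch-locked parent env, the primer spec must carry the env's leading batch dimensions from the start. A minimal sketch of compliant usage (assuming the Unbounded spec class from torchrl.data available in recent TorchRL versions; illustrative, not part of the commit):

    import torch
    from torchrl.data import Unbounded  # assumed spec class, not named in this commit
    from torchrl.envs import SerialEnv, TransformedEnv
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.transforms import TensorDictPrimer

    base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))  # batch-locked, batch_size [2]
    env = TransformedEnv(base_env)
    # The leading dim of the spec matches the env batch size explicitly,
    # since reset() no longer expands mismatched primers in place.
    env.append_transform(TensorDictPrimer(mykey=Unbounded([2, 3])))
    td = env.reset()
    assert td["mykey"].shape == torch.Size([2, 3])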
