
Commit 5e81445

Vincent Moens and matteobettini authored
[BugFix] Fix shape setting in CompositeSpec (#1620)
Co-authored-by: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
1 parent (fdee633) · commit 5e81445

File tree

  test/test_specs.py
  torchrl/data/tensor_specs.py
  torchrl/envs/libs/gym.py

3 files changed: +26 -5 lines changed

test/test_specs.py

Lines changed: 17 additions & 0 deletions

@@ -651,6 +651,23 @@ def test_nested_composite_spec_update(self, shape, is_complete, device, dtype):
             }
         assert ts["nested_cp"]["act"] is not None
 
+    def test_change_batch_size(self, shape, is_complete, device, dtype):
+        ts = self._composite_spec(shape, is_complete, device, dtype)
+        ts["nested"] = CompositeSpec(
+            leaf=UnboundedContinuousTensorSpec(shape), shape=shape
+        )
+        ts = ts.expand(3, *shape)
+        assert ts["nested"].shape == (3, *shape)
+        assert ts["nested", "leaf"].shape == (3, *shape)
+        ts.shape = ()
+        # this does not change
+        assert ts["nested"].shape == (3, *shape)
+        assert ts.shape == ()
+        ts["nested"].shape = ()
+        ts.shape = (3,)
+        assert ts.shape == (3,)
+        assert ts["nested"].shape == (3,)
+
 
     @pytest.mark.parametrize("shape", [(), (2, 3)])
     @pytest.mark.parametrize("device", get_default_devices())

torchrl/data/tensor_specs.py

Lines changed: 2 additions & 2 deletions

@@ -3135,10 +3135,10 @@ def shape(self, value: torch.Size):
             raise RuntimeError("Cannot modify shape of locked composite spec.")
         for key, spec in self.items():
             if isinstance(spec, CompositeSpec):
-                if spec.shape[: self.ndim] != self.shape:
+                if spec.shape[: len(value)] != value:
                     spec.shape = value
             elif spec is not None:
-                if spec.shape[: self.ndim] != self.shape:
+                if spec.shape[: len(value)] != value:
                     raise ValueError(
                         f"The shape of the spec and the CompositeSpec mismatch during shape resetting: the "
                         f"{self.ndim} first dimensions should match but got self['{key}'].shape={spec.shape} and "

torchrl/envs/libs/gym.py

Lines changed: 7 additions & 3 deletions

@@ -2,11 +2,14 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
 import importlib
 import warnings
 from copy import copy
 from types import ModuleType
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Tuple
 from warnings import warn
 
 import numpy as np
@@ -310,7 +313,8 @@ def _gym_to_torchrl_spec_transform(
             categorical_action_encoding=categorical_action_encoding,
             remap_state_to_observation=remap_state_to_observation,
         )
-        return CompositeSpec(**spec_out)
+        # the batch-size must be set later
+        return CompositeSpec(spec_out)
     elif isinstance(spec, gym_spaces.dict.Dict):
         return _gym_to_torchrl_spec_transform(
             spec.spaces,
@@ -910,7 +914,7 @@ def info_dict_reader(self, value: callable):
         self._info_dict_reader = value
 
     def _reset(
-        self, tensordict: Optional[TensorDictBase] = None, **kwargs
+        self, tensordict: TensorDictBase | None = None, **kwargs
     ) -> TensorDictBase:
         if self._is_batched:
             # batched (aka 'vectorized') env reset is a bit special: envs are
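
Two notes on this file, for context. The added `from __future__ import annotations` postpones annotation evaluation, which is what lets `_reset` use the `TensorDictBase | None` union syntax on Python versions older than 3.10. And returning `CompositeSpec(spec_out)` builds the composite spec from a plain dict of sub-specs without fixing a batch size, which can then be set afterwards through the `shape` setter patched in tensor_specs.py above. A rough sketch of that construction pattern (not from the commit; keys and shapes are made up):

# Hedged sketch: build a CompositeSpec from a dict of sub-specs and set the
# batch size afterwards; keys and shapes are illustrative.
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

spec_out = {
    "observation": UnboundedContinuousTensorSpec(shape=(3, 4)),
    "reward": UnboundedContinuousTensorSpec(shape=(3, 1)),
}
composite = CompositeSpec(spec_out)   # no batch size fixed at construction time
composite.shape = (3,)                # set the leading (batch) dimension later
assert composite["observation"].shape == (3, 4)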

0 commit comments