Skip to content

Commit cac93eb

Browse files
author
Vincent Moens
committed
[Feature] automatically determine return_contiguous
ghstack-source-id: 6d1fc31 Pull Request resolved: #2724
1 parent 37a514d commit cac93eb

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

test/test_env.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3285,8 +3285,9 @@ def test_dynamic_rollout(self):
32853285
RuntimeError,
32863286
match="The environment specs are dynamic. Call rollout with return_contiguous=False",
32873287
):
3288-
rollout = env.rollout(4)
3289-
rollout = env.rollout(4, return_contiguous=False)
3288+
env.rollout(4, return_contiguous=True)
3289+
env.rollout(4)
3290+
env.rollout(4, return_contiguous=False)
32903291
check_env_specs(env, return_contiguous=False)
32913292

32923293
@pytest.mark.skipif(not _has_gym, reason="requires gym to be installed")

torchrl/envs/common.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2699,6 +2699,7 @@ def specs(self) -> Composite:
26992699

27002700
@property
27012701
def _has_dynamic_specs(self) -> bool:
2702+
# TODO: cache this value
27022703
return _has_dynamic_specs(self.specs)
27032704

27042705
def rollout(
@@ -2711,7 +2712,7 @@ def rollout(
27112712
auto_cast_to_device: bool = False,
27122713
break_when_any_done: bool | None = None,
27132714
break_when_all_done: bool | None = None,
2714-
return_contiguous: bool = True,
2715+
return_contiguous: bool | None = False,
27152716
tensordict: Optional[TensorDictBase] = None,
27162717
set_truncated: bool = False,
27172718
out=None,
@@ -2746,7 +2747,8 @@ def rollout(
27462747
break_when_all_done (bool, optional): if ``True``, break if all of the contained environments reach any
27472748
of the done states. If ``False``, break if at least one environment reaches any of the done states.
27482749
Default is ``False``.
2749-
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
2750+
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is `True` if
2751+
the env does not have dynamic specs, otherwise `False`.
27502752
tensordict (TensorDict, optional): if ``auto_reset`` is False, an initial
27512753
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
27522754
environment in those dimensions (if needed).
@@ -2957,7 +2959,8 @@ def rollout(
29572959
raise TypeError(
29582960
"Cannot have both break_when_all_done and break_when_any_done True at the same time."
29592961
)
2960-
2962+
if return_contiguous is None:
2963+
return_contiguous = not self._has_dynamic_specs
29612964
if policy is not None:
29622965
policy = _make_compatible_policy(
29632966
policy,

0 commit comments

Comments
 (0)