
Commit 99a95e3

vmoens and rohitnig authored
[Feature] Marking the time dimension (#1095)
Co-authored-by: Rohit Nigam <rohitnigam@meta.com>
Co-authored-by: Rohit Nigam <rohitnigam@gmail.com>
1 parent 39fe662 commit 99a95e3

File tree

5 files changed: +120 -4 lines changed

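In short: rollouts and collector batches now carry a named "time" dimension on their trailing batch dimension. A minimal sketch of the new behavior (assuming the gym backend is installed; the env name is only an example taken from the docstrings below):

# The last batch-dim name of a rollout is now "time".
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("Pendulum-v1")
rollout = env.rollout(max_steps=10)
assert rollout.names[-1] == "time"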

test/test_collector.py

Lines changed: 6 additions & 0 deletions
@@ -212,6 +212,7 @@ def env_fn(seed):
             device, policy_device, storing_device, d.device.type
         )
         break
+    assert d.names[-1] == "time"
 
     collector.shutdown()
 
@@ -231,6 +232,7 @@ def env_fn(seed):
             device, policy_device, storing_device, d.device.type
         )
         break
+    assert d.names[-1] == "time"
 
     ccollector.shutdown()
 
@@ -273,6 +275,7 @@ def env_fn(seed):
             b2 = d
         else:
             break
+        assert d.names[-1] == "time"
     with pytest.raises(AssertionError):
         assert_allclose_td(b1, b2)
     collector.shutdown()
@@ -292,6 +295,7 @@ def env_fn(seed):
             b2c = d
         else:
             break
+        assert d.names[-1] == "time"
     with pytest.raises(AssertionError):
         assert_allclose_td(b1c, b2c)
 
@@ -508,6 +512,7 @@ def env_fn():
         assert b.numel() == -(-frames_per_batch // num_env) * num_env
         if i == 5:
             break
+        assert b.names[-1] == "time"
     ccollector.shutdown()
 
     ccollector = MultiSyncDataCollector(
@@ -525,6 +530,7 @@ def env_fn():
         )
         if i == 5:
             break
+        assert b.names[-1] == "time"
     ccollector.shutdown()

test/test_distributed.py

Lines changed: 1 addition & 0 deletions
@@ -100,6 +100,7 @@ def _test_distributed_collector_basic(cls, queue, frames_per_batch):
     for data in collector:
         total += data.numel()
         assert data.numel() == frames_per_batch
+        assert data.names[-1] == "time"
     collector.shutdown()
     assert total == 1000
     queue.put("passed")

test/test_env.py

Lines changed: 3 additions & 0 deletions
@@ -156,12 +156,14 @@ def test_rollout(env_name, frame_skip, seed=0):
     env.set_seed(seed)
     env.reset()
     rollout1 = env.rollout(max_steps=100)
+    assert rollout1.names[-1] == "time"
 
     torch.manual_seed(seed)
     np.random.seed(seed)
     env.set_seed(seed)
     env.reset()
     rollout2 = env.rollout(max_steps=100)
+    assert rollout2.names[-1] == "time"
 
     assert_allclose_td(rollout1, rollout2)
 
@@ -231,6 +233,7 @@ def test_rollout_reset(env_name, frame_skip, parallel, truncated_key, seed=0):
     env = SerialEnv(3, envs)
     env.set_seed(100)
     out = env.rollout(100, break_when_any_done=False)
+    assert out.names[-1] == "time"
     assert out.shape == torch.Size([3, 100])
     assert (
         out["next", truncated_key].squeeze().sum(-1) == torch.tensor([5, 3, 2])

torchrl/collectors/collectors.py

Lines changed: 7 additions & 0 deletions
@@ -464,6 +464,12 @@ class SyncDataCollector(DataCollectorBase):
         is_shared=False)
     >>> del collector
 
+    The collector delivers batches of data that are marked with a ``"time"``
+    dimension.
+
+    Examples:
+        >>> assert data.names[-1] == "time"
+
     """
 
     def __init__(
@@ -665,6 +671,7 @@ def __init__(
                 device=self.storing_device,
             ),
         )
+        self._tensordict_out.refine_names(..., "time")
 
         if split_trajs is None:
             split_trajs = False
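The fragmentary ``Examples`` block above can be exercised end to end. A minimal sketch, assuming the gym backend and default collector settings; passing ``policy=None`` is assumed to fall back to a random policy:

from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.gym import GymEnv

collector = SyncDataCollector(
    GymEnv("Pendulum-v1"),
    policy=None,  # assumption: a None policy falls back to a random policy
    frames_per_batch=50,
    total_frames=100,
)
for data in collector:
    # Each delivered batch is marked with the "time" dimension name.
    assert data.names[-1] == "time"
collector.shutdown()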

torchrl/envs/common.py

Lines changed: 103 additions & 4 deletions
@@ -15,16 +15,16 @@
 import torch.nn as nn
 from tensordict.tensordict import TensorDict, TensorDictBase
 
+from torchrl._utils import prod, seed_generator
+
 from torchrl.data.tensor_specs import (
     CompositeSpec,
     DiscreteTensorSpec,
     TensorSpec,
     UnboundedContinuousTensorSpec,
 )
-
-from .._utils import prod, seed_generator
-from ..data.utils import DEVICE_TYPING
-from .utils import get_available_libraries, step_mdp
+from torchrl.data.utils import DEVICE_TYPING
+from torchrl.envs.utils import get_available_libraries, step_mdp
 
 LIBRARIES = get_available_libraries()
 
@@ -219,6 +219,13 @@ def batch_size(self) -> TensorSpec:
     def batch_size(self, value: torch.Size) -> None:
         self._batch_size = torch.Size(value)
 
+    def ndimension(self):
+        return len(self.batch_size)
+
+    @property
+    def ndim(self):
+        return self.ndimension()
+
     # Parent specs: input and output spec.
     @property
     def input_spec(self) -> TensorSpec:
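The new ``ndimension``/``ndim`` helpers simply expose ``len(env.batch_size)``; per the rollout docstring below, the "time" name therefore sits at index ``env.ndim`` of a rollout's names. A short sketch, assuming the gym backend is installed:

from torchrl.envs import SerialEnv
from torchrl.envs.libs.gym import GymEnv

env = SerialEnv(3, lambda: GymEnv("Pendulum-v1"))
assert env.ndim == len(env.batch_size) == 1
rollout = env.rollout(max_steps=5)
# names are [None, "time"]: batch dims stay unnamed, the time dim is named
assert rollout.names[env.ndim] == "time"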
@@ -661,6 +668,97 @@ def rollout(
         Returns:
             TensorDict object containing the resulting trajectory.
 
+        The data returned will be marked with a "time" dimension name for the
+        last dimension of the tensordict (at the ``env.ndim`` index).
+
+        Examples:
+            >>> from torchrl.envs.libs.gym import GymEnv
+            >>> from torchrl.envs.transforms import TransformedEnv, StepCounter
+            >>> env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=20))
+            >>> rollout = env.rollout(max_steps=1000)
+            >>> print(rollout)
+            TensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                    done: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    next: TensorDict(
+                        fields={
+                            done: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            observation: Tensor(shape=torch.Size([20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                            reward: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                            step_count: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                            truncated: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                        batch_size=torch.Size([20]),
+                        device=cpu,
+                        is_shared=False),
+                    observation: Tensor(shape=torch.Size([20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                    step_count: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                batch_size=torch.Size([20]),
+                device=cpu,
+                is_shared=False)
+            >>> print(rollout.names)
+            ['time']
+            >>> # with envs that contain more dimensions
+            >>> from torchrl.envs import SerialEnv
+            >>> env = SerialEnv(3, lambda: TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=20)))
+            >>> rollout = env.rollout(max_steps=1000)
+            >>> print(rollout)
+            TensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                    done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    next: TensorDict(
+                        fields={
+                            done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                            reward: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                            step_count: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                            truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                        batch_size=torch.Size([3, 20]),
+                        device=cpu,
+                        is_shared=False),
+                    observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                    step_count: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                batch_size=torch.Size([3, 20]),
+                device=cpu,
+                is_shared=False)
+            >>> print(rollout.names)
+            [None, 'time']
+
+        In some instances, contiguous tensordicts cannot be obtained because
+        they cannot be stacked. This can happen when the data returned at
+        each step may have a different shape, or when different environments
+        are executed together. In that case, ``return_contiguous=False``
+        will cause the returned tensordict to be a lazy stack of tensordicts:
+
+        Examples:
+            >>> rollout = env.rollout(4, return_contiguous=False)
+            >>> print(rollout)
+            LazyStackedTensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                    done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    next: LazyStackedTensorDict(
+                        fields={
+                            done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                            reward: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                            step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                            truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                        batch_size=torch.Size([3, 4]),
+                        device=cpu,
+                        is_shared=False),
+                    observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                    step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                batch_size=torch.Size([3, 4]),
+                device=cpu,
+                is_shared=False)
+            >>> print(rollout.names)
+            [None, 'time']
+
         """
         try:
             policy_device = next(policy.parameters()).device
 
@@ -718,6 +816,7 @@ def policy(td):
         batch_size = self.batch_size if tensordict is None else tensordict.batch_size
 
         out_td = torch.stack(tensordicts, len(batch_size))
+        out_td.refine_names(..., "time")
        if return_contiguous:
            return out_td.contiguous()
        return out_td
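Both ``refine_names`` calls in this commit rely on tensordict's named-dimension API. A minimal sketch of its semantics on a plain TensorDict, using the import path as it appears in this commit:

import torch
from tensordict.tensordict import TensorDict

# refine_names with a leading ellipsis names only the trailing dimension
# and leaves the leading batch dimensions unnamed (None).
td = TensorDict({"obs": torch.zeros(3, 20, 4)}, batch_size=[3, 20])
td.refine_names(..., "time")
assert td.names == [None, "time"]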
