 import torch.nn as nn
 from tensordict.tensordict import TensorDict, TensorDictBase
 
+from torchrl._utils import prod, seed_generator
+
 from torchrl.data.tensor_specs import (
     CompositeSpec,
     DiscreteTensorSpec,
     TensorSpec,
     UnboundedContinuousTensorSpec,
 )
-
-from .._utils import prod, seed_generator
-from ..data.utils import DEVICE_TYPING
-from .utils import get_available_libraries, step_mdp
+from torchrl.data.utils import DEVICE_TYPING
+from torchrl.envs.utils import get_available_libraries, step_mdp
 
 LIBRARIES = get_available_libraries()
 
@@ -219,6 +219,13 @@ def batch_size(self) -> TensorSpec:
     def batch_size(self, value: torch.Size) -> None:
         self._batch_size = torch.Size(value)
 
+    def ndimension(self):
+        return len(self.batch_size)
+
+    @property
+    def ndim(self):
+        return self.ndimension()
+
     # Parent specs: input and output spec.
     @property
     def input_spec(self) -> TensorSpec:
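A quick usage sketch of the new `ndim` / `ndimension` accessors (not part of the patch; the environment construction is illustrative only and reuses the batched `SerialEnv` setup from the docstring examples further down):

>>> from torchrl.envs import SerialEnv
>>> from torchrl.envs.libs.gym import GymEnv
>>> env = SerialEnv(3, lambda: GymEnv("Pendulum-v1"))
>>> env.batch_size
torch.Size([3])
>>> env.ndimension()   # number of batch dimensions of the env
1
>>> env.ndim           # property alias for ndimension()
1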
@@ -661,6 +668,97 @@ def rollout(
         Returns:
             TensorDict object containing the resulting trajectory.
 
+        The data returned will be marked with a "time" dimension name for the last
+        dimension of the tensordict (at the ``env.ndim`` index).
+
+        Examples:
+            >>> from torchrl.envs.libs.gym import GymEnv
+            >>> from torchrl.envs.transforms import TransformedEnv, StepCounter
+            >>> env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=20))
+            >>> rollout = env.rollout(max_steps=1000)
+            >>> print(rollout)
+            TensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                    done: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    next: TensorDict(
+                        fields={
+                            done: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            observation: Tensor(shape=torch.Size([20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                            reward: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                            step_count: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                            truncated: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                        batch_size=torch.Size([20]),
+                        device=cpu,
+                        is_shared=False),
+                    observation: Tensor(shape=torch.Size([20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                    step_count: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                batch_size=torch.Size([20]),
+                device=cpu,
+                is_shared=False)
+            >>> print(rollout.names)
+            ['time']
+            >>> # with envs that contain more dimensions
+            >>> from torchrl.envs import SerialEnv
+            >>> env = SerialEnv(3, lambda: TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=20)))
+            >>> rollout = env.rollout(max_steps=1000)
+            >>> print(rollout)
+            TensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                    done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    next: TensorDict(
+                        fields={
+                            done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                            reward: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                            step_count: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                            truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                        batch_size=torch.Size([3, 20]),
+                        device=cpu,
+                        is_shared=False),
+                    observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                    step_count: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                batch_size=torch.Size([3, 20]),
+                device=cpu,
+                is_shared=False)
+            >>> print(rollout.names)
+            [None, 'time']
+
+        In some instances, a contiguous tensordict cannot be obtained because
+        the per-step tensordicts cannot be stacked. This can happen when the data
+        returned at each step may have a different shape, or when different
+        environments are executed together. In that case, ``return_contiguous=False``
+        will cause the returned tensordict to be a lazy stack of tensordicts:
+
+        Examples:
+            >>> rollout = env.rollout(4, return_contiguous=False)
+            >>> print(rollout)
+            LazyStackedTensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                    done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    next: LazyStackedTensorDict(
+                        fields={
+                            done: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                            reward: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                            step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                            truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                        batch_size=torch.Size([3, 4]),
+                        device=cpu,
+                        is_shared=False),
+                    observation: Tensor(shape=torch.Size([3, 4, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                    step_count: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.int64, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([3, 4, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                batch_size=torch.Size([3, 4]),
+                device=cpu,
+                is_shared=False)
+            >>> print(rollout.names)
+            [None, 'time']
+
         """
         try:
             policy_device = next(policy.parameters()).device
@@ -718,6 +816,7 @@ def policy(td):
         batch_size = self.batch_size if tensordict is None else tensordict.batch_size
 
         out_td = torch.stack(tensordicts, len(batch_size))
+        out_td.refine_names(..., "time")
         if return_contiguous:
             return out_td.contiguous()
         return out_td
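For reference, a minimal sketch of what the added ``refine_names(..., "time")`` call does, shown on a toy tensordict rather than an actual rollout (the key and shapes below are made up for illustration):

>>> import torch
>>> from tensordict.tensordict import TensorDict
>>> td = TensorDict({"obs": torch.zeros(3, 20, 4)}, batch_size=[3, 20])
>>> _ = td.refine_names(..., "time")   # names only the last batch dimension
>>> td.names
[None, 'time']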