Commit d909444
[Feature] InitTracker transform (#962)
1 parent 2de55cb commit d909444

File tree: 8 files changed, +217 -15 lines

docs/source/reference/envs.rst
Lines changed: 18 additions & 1 deletion

@@ -49,7 +49,8 @@ With these, the following methods are implemented:
 
 - :meth:`env.reset`: a reset method that may (but not necessarily requires to) take
   a :class:`tensordict.TensorDict` input. It return the first tensordict of a rollout, usually
-  containing a :obj:`"done"` state and a set of observations.
+  containing a :obj:`"done"` state and a set of observations. If not present,
+  a `"reward"` key will be instantiated with 0s and the appropriate shape.
 - :meth:`env.step`: a step method that takes a :class:`tensordict.TensorDict` input
   containing an input action as well as other inputs (for model-based or stateless
   environments, for instance).

@@ -88,6 +89,21 @@ function.
 TorchRL's collectors and rollout methods will be looking for one of these
 keys when assessing if the env should be reset.
 
+.. note::
+
+  The `torchrl.collectors.utils.split_trajectories` function can be used to
+  slice adjacent trajectories. It relies on a ``"traj_ids"`` entry in the
+  input tensordict, or to the junction of ``"done"`` and ``"truncated"`` key
+  if the ``"traj_ids"`` is missing.
+
+
+.. note::
+
+  In some contexts, it can be useful to mark the first step of a trajectory.
+  TorchRL provides such functionality through the :class:`torchrl.envs.InitTracker`
+  transform.
+
+
 Our environment `tutorial <https://pytorch.org/rl/tutorials/pendulum.html>`_
 provides more information on how to design a custom environment from scratch.
 

@@ -309,6 +325,7 @@ to be able to create this other composition:
     FrameSkipTransform
     GrayScale
     gSDENoise
+    InitTracker
     NoopResetEnv
     ObservationNorm
     ObservationTransform

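To make the two documentation notes above concrete, here is a short usage sketch. It is not part of the commit: the environment name, rollout length, and printed values are illustrative only, the snippet assumes a Gym backend is installed, and the exact input layout expected by `split_trajectories` may differ between TorchRL versions.

    # Usage sketch (not part of this commit): InitTracker marks the first step
    # of every trajectory in a rollout collected without early stopping.
    from torchrl.envs import TransformedEnv
    from torchrl.envs.transforms import InitTracker
    from torchrl.envs.libs.gym import GymEnv  # assumes gym is installed

    env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
    rollout = env.rollout(400, break_when_any_done=False)

    # "is_init" is True at the first step of each trajectory and False elsewhere.
    print(rollout["is_init"].shape, rollout["is_init"].sum())

    # To slice the flat rollout back into trajectories, the docs above point to
    # torchrl.collectors.utils.split_trajectories, which uses a "traj_ids" entry
    # when present and otherwise falls back on the done/truncated signals.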
test/mocking_classes.py
Lines changed: 29 additions & 8 deletions

@@ -905,7 +905,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
                 (
                     *self.batch_size,
                     1,
-                )
+                ),
+                dtype=torch.int32,
             ),
             shape=self.batch_size,
         )

@@ -915,6 +916,14 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
                 1,
             )
         )
+        self.done_spec = DiscreteTensorSpec(
+            2,
+            dtype=torch.bool,
+            shape=(
+                *self.batch_size,
+                1,
+            ),
+        )
         self.input_spec = CompositeSpec(
             action=BinaryDiscreteTensorSpec(n=1, shape=[*self.batch_size, 1]),
             shape=self.batch_size,

@@ -978,19 +987,19 @@ def __init__(
         if max_steps is None:
             max_steps = torch.tensor(5)
         if start_val is None:
-            start_val = torch.zeros(())
+            start_val = torch.zeros((), dtype=torch.int32)
         if not max_steps.shape == self.batch_size:
             raise RuntimeError("batch_size and max_steps shape must match.")
 
         self.max_steps = max_steps
-        self.start_val = start_val
 
         self.observation_spec = CompositeSpec(
             observation=UnboundedContinuousTensorSpec(
                 (
                     *self.batch_size,
                     1,
-                )
+                ),
+                dtype=torch.int32,
             ),
             shape=self.batch_size,
         )

@@ -1000,6 +1009,14 @@ def __init__(
                 1,
             )
         )
+        self.done_spec = DiscreteTensorSpec(
+            2,
+            dtype=torch.bool,
+            shape=(
+                *self.batch_size,
+                1,
+            ),
+        )
         self.input_spec = CompositeSpec(
             action=BinaryDiscreteTensorSpec(n=1, shape=[*self.batch_size, 1]),
             shape=self.batch_size,

@@ -1008,20 +1025,24 @@ def __init__(
         self.count = torch.zeros(
             (*self.batch_size, 1), device=self.device, dtype=torch.int
         )
+        if start_val.numel() == self.batch_size.numel():
+            self.start_val = start_val.view(*self.batch_size, 1)
+        elif start_val.numel() <= 1:
+            self.start_val = start_val.expand_as(self.count)
 
     def _set_seed(self, seed: Optional[int]):
         torch.manual_seed(seed)
 
     def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
         if tensordict is not None and "_reset" in tensordict.keys():
             _reset = tensordict.get("_reset")
-            self.count[_reset] = self.start_val[_reset].unsqueeze(-1)
+            self.count[_reset] = self.start_val[_reset].view_as(self.count[_reset])
         else:
-            self.count[:] = self.start_val.unsqueeze(-1)
+            self.count[:] = self.start_val.view_as(self.count)
         return TensorDict(
             source={
                 "observation": self.count.clone(),
-                "done": self.count > self.max_steps.unsqueeze(-1),
+                "done": self.count > self.max_steps.view_as(self.count),
             },
             batch_size=self.batch_size,
             device=self.device,

@@ -1032,7 +1053,7 @@ def _step(
         tensordict: TensorDictBase,
     ) -> TensorDictBase:
         action = tensordict.get("action")
-        self.count += action.to(torch.int).unsqueeze(-1)
+        self.count += action.to(torch.int).view_as(self.count)
         tensordict = TensorDict(
             source={
                 "observation": self.count.clone(),

test/test_collector.py
Lines changed: 4 additions & 4 deletions

@@ -1223,17 +1223,17 @@ def test_initial_obs_consistency(env_class, seed=1):
     if env_class == CountingEnv:
         arange_0 = start_val + torch.arange(max_steps - 3)
         arange = start_val + torch.arange(2)
-        expected = torch.cat([arange_0, arange_0, arange]).float()
+        expected = torch.cat([arange_0, arange_0, arange])
     else:
         # the first env has a shorter horizon than the second
         arange_0 = start_val + torch.arange(max_steps - 3 - 1)
         arange = start_val + torch.arange(start_val)
-        expected_0 = torch.cat([arange_0, arange_0, arange]).float()
+        expected_0 = torch.cat([arange_0, arange_0, arange])
         arange_0 = start_val + torch.arange(max_steps - 3)
         arange = start_val + torch.arange(2)
-        expected_1 = torch.cat([arange_0, arange_0, arange]).float()
+        expected_1 = torch.cat([arange_0, arange_0, arange])
         expected = torch.stack([expected_0, expected_1])
-    assert torch.allclose(obs, expected)
+    assert torch.allclose(obs, expected.to(obs.dtype))
 
 
 def weight_reset(m):

test/test_transforms.py
Lines changed: 95 additions & 1 deletion

@@ -21,6 +21,7 @@
 )
 from mocking_classes import (
     ContinuousActionVecMockEnv,
+    CountingBatchedEnv,
     DiscreteActionConvMockEnvNumpy,
     MockBatchedLockedEnv,
     MockBatchedUnLockedEnv,

@@ -75,7 +76,7 @@
 from torchrl.envs.libs.gym import _has_gym, GymEnv
 from torchrl.envs.transforms import VecNorm
 from torchrl.envs.transforms.r3m import _R3MNet
-from torchrl.envs.transforms.transforms import _has_tv
+from torchrl.envs.transforms.transforms import _has_tv, InitTracker
 from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
 from torchrl.envs.utils import check_env_specs, step_mdp
 

@@ -6314,6 +6315,99 @@ def test_crop_mask(self, mask_key):
         assert tensordict_crop[mask_key].all()
 
 
+class TestInitTracker(TransformBase):
+    def test_single_trans_env_check(self):
+        env = CountingBatchedEnv(max_steps=torch.tensor([4, 5]), batch_size=[2])
+        env = TransformedEnv(env, InitTracker())
+        check_env_specs(env)
+
+    def test_serial_trans_env_check(self):
+        def make_env():
+            env = CountingBatchedEnv(max_steps=torch.tensor([4, 5]), batch_size=[2])
+            env = TransformedEnv(env, InitTracker())
+            return env
+
+        env = SerialEnv(2, make_env)
+        check_env_specs(env)
+
+    def test_parallel_trans_env_check(self):
+        def make_env():
+            env = CountingBatchedEnv(max_steps=torch.tensor([4, 5]), batch_size=[2])
+            env = TransformedEnv(env, InitTracker())
+            return env
+
+        env = ParallelEnv(2, make_env)
+        check_env_specs(env)
+
+    def test_trans_serial_env_check(self):
+        def make_env():
+            env = CountingBatchedEnv(max_steps=torch.tensor([4, 5]), batch_size=[2])
+            return env
+
+        env = SerialEnv(2, make_env)
+        env = TransformedEnv(env, InitTracker())
+        check_env_specs(env)
+
+    def test_trans_parallel_env_check(self):
+        def make_env():
+            env = CountingBatchedEnv(max_steps=torch.tensor([4, 5]), batch_size=[2])
+            return env
+
+        env = ParallelEnv(2, make_env)
+        env = TransformedEnv(env, InitTracker())
+        check_env_specs(env)
+
+    def test_transform_no_env(self):
+        with pytest.raises(
+            NotImplementedError, match="InitTracker cannot be executed without a parent"
+        ):
+            InitTracker()(None)
+
+    def test_transform_compose(self):
+        with pytest.raises(
+            NotImplementedError, match="InitTracker cannot be executed without a parent"
+        ):
+            Compose(InitTracker())(None)
+
+    def test_transform_env(self):
+        policy = lambda tensordict: tensordict.set(
+            "action", torch.ones(tensordict.shape, dtype=torch.int32)
+        )
+        env = CountingBatchedEnv(max_steps=torch.tensor([3, 4]), batch_size=[2])
+        env = TransformedEnv(env, InitTracker())
+        r = env.rollout(100, policy, break_when_any_done=False)
+        assert (r["is_init"].sum(1) == torch.tensor([25, 20])).all()
+
+    def test_transform_model(self):
+        with pytest.raises(
+            NotImplementedError, match="InitTracker cannot be executed without a parent"
+        ):
+            td = TensorDict({}, [])
+            chain = nn.Sequential(InitTracker())
+            chain(td)
+
+    def test_transform_rb(self):
+        batch = [1]
+        device = "cpu"
+        rb = ReplayBuffer(LazyTensorStorage(20))
+        rb.append_transform(InitTracker())
+        reward = torch.randn(*batch, 1, device=device)
+        misc = torch.randn(*batch, 1, device=device)
+        td = TensorDict(
+            {"misc": misc, "reward": reward},
+            batch,
+            device=device,
+        )
+        rb.extend(td)
+        with pytest.raises(
+            NotImplementedError, match="InitTracker cannot be executed without a parent"
+        ):
+            _ = rb.sample(20)
+
+    def test_transform_inverse(self):
+        raise pytest.skip("No inverse for InitTracker")
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/envs/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@
     FrameSkipTransform,
     GrayScale,
     gSDENoise,
+    InitTracker,
     NoopResetEnv,
     ObservationNorm,
     ObservationTransform,

torchrl/envs/common.py
Lines changed: 1 addition & 1 deletion

@@ -708,7 +708,7 @@ def policy(td):
                 exclude_action=False,
             )
             if not break_when_any_done and done.any():
-                _reset = done.squeeze(-1)
+                _reset = done.view(tensordict.shape)
                 tensordict.set("_reset", _reset)
                 self.reset(tensordict)
 

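For context on this one-line change: when `rollout` is called with `break_when_any_done=False`, the done mask is reshaped to the tensordict's batch shape and written under `"_reset"` so that only finished sub-environments are reset. The snippet below is a sketch (not part of the commit) that mimics this mechanism by hand with the CountingBatchedEnv mock from test/mocking_classes.py; it assumes it is run from the test directory, and the printed values are indicative only.

    # Sketch of the partial-reset mechanism used by rollout(break_when_any_done=False).
    import torch
    from torchrl.envs.utils import step_mdp
    from mocking_classes import CountingBatchedEnv  # test helper, assumed importable

    env = CountingBatchedEnv(max_steps=torch.tensor([3, 4]), batch_size=[2])
    td = env.reset()
    td.set("action", torch.ones(td.shape, dtype=torch.int32))
    td = step_mdp(env.step(td))                # both counters advance to 1
    td.set("_reset", torch.tensor([True, False]))
    td = env.reset(td)                         # only the first sub-env restarts
    print(td["observation"])                   # e.g. tensor([[0], [1]], dtype=torch.int32)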
torchrl/envs/transforms/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@
     FrameSkipTransform,
     GrayScale,
     gSDENoise,
+    InitTracker,
     NoopResetEnv,
     ObservationNorm,
     ObservationTransform,

torchrl/envs/transforms/transforms.py
Lines changed: 68 additions & 0 deletions

@@ -23,6 +23,7 @@
     CompositeSpec,
     ContinuousBox,
     DEVICE_TYPING,
+    DiscreteTensorSpec,
     OneHotDiscreteTensorSpec,
     TensorSpec,
     UnboundedContinuousTensorSpec,

@@ -56,6 +57,8 @@ def interpolation_fn(interpolation):  # noqa: D103
 IMAGE_KEYS = ["pixels"]
 _MAX_NOOPS_TRIALS = 10
 
+FORWARD_NOT_IMPLEMENTED = "class {} cannot be executed without a parent" "environment."
+
 
 def _apply_to_composite(function):
     def new_fun(self, observation_spec):

@@ -2960,6 +2963,11 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
         observation_spec.update(episode_specs)
         return observation_spec
 
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        raise NotImplementedError(
+            FORWARD_NOT_IMPLEMENTED.format(self.__class__.__name__)
+        )
+
 
 class StepCounter(Transform):
     """Counts the steps from a reset and sets the done state to True after a certain number of steps.

@@ -3370,3 +3378,63 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         arange = arange.view(arange_shape)
         idx = idx_0 + arange
         return tensordict.gather(dim=self.sample_dim, index=idx)
+
+
+class InitTracker(Transform):
+    """Reset tracker.
+
+    This transform populates the step/reset tensordict with a reset tracker entry
+    that is set to ``True`` whenever :meth:`~.reset` is called.
+
+    Args:
+        init_key (str, optional): the key to be used for the tracker entry.
+
+    Examples:
+        >>> from torchrl.envs.libs.gym import GymEnv
+        >>> env = TransformedEnv(GymEnv("Pendulum-v1"), InitTracker())
+        >>> td = env.reset()
+        >>> print(td["is_init"])
+        tensor(True)
+        >>> td = env.rand_step(td)
+        >>> print(td["next", "is_init"])
+        tensor(False)
+
+    """
+
+    def __init__(self, init_key: bool = "is_init"):
+        super().__init__(in_keys=[], out_keys=[init_key])
+
+    def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
+        if self.out_keys[0] not in tensordict.keys():
+            device = tensordict.device
+            if device is None:
+                device = torch.device("cpu")
+            tensordict.set(
+                self.out_keys[0],
+                torch.zeros(tensordict.shape, device=device, dtype=torch.bool),
+            )
+        return tensordict
+
+    def reset(self, tensordict: TensorDictBase) -> TensorDictBase:
+        device = tensordict.device
+        if device is None:
+            device = torch.device("cpu")
+        _reset = tensordict.get("_reset", None)
+        if _reset is None:
+            _reset = torch.ones(tensordict.shape, device=device, dtype=torch.bool)
+        tensordict.set(self.out_keys[0], _reset.clone())
+        return tensordict
+
+    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
+        observation_spec[self.out_keys[0]] = DiscreteTensorSpec(
+            2,
+            dtype=torch.bool,
+            device=self.parent.device,
+            shape=self.parent.batch_size,
+        )
+        return observation_spec
+
+    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        raise NotImplementedError(
+            FORWARD_NOT_IMPLEMENTED.format(self.__class__.__name__)
+        )

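As a usage note, not part of the commit: the "is_init" entry written by InitTracker is meant to be consumed downstream, for instance to reset a recurrent policy's hidden state at trajectory starts. The helper and the "hidden" key below are hypothetical and only illustrate that pattern.

    # Illustrative sketch: zero a (hypothetical) recurrent hidden state wherever
    # InitTracker has flagged the start of a trajectory.
    import torch
    from tensordict import TensorDict

    def mask_hidden_on_init(td, hidden_key="hidden"):
        """Zero out ``hidden_key`` for the batch entries where ``is_init`` is True."""
        is_init = td.get("is_init")          # written by InitTracker at reset time
        hidden = td.get(hidden_key)
        while is_init.dim() < hidden.dim():  # broadcast the flag over feature dims
            is_init = is_init.unsqueeze(-1)
        td.set(hidden_key, hidden.masked_fill(is_init, 0.0))
        return td

    # toy data: batch of 4 transitions, hidden size 8
    td = TensorDict(
        {"is_init": torch.tensor([True, False, False, True]), "hidden": torch.randn(4, 8)},
        batch_size=[4],
    )
    td = mask_hidden_on_init(td)
    assert (td["hidden"][0] == 0).all() and (td["hidden"][3] == 0).all()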