Skip to content

Commit b840a77

Browse files
author
Vincent Moens
committed
[Example] Efficient Trajectory Sampling with CompletedTrajRepertoire
ghstack-source-id: 4d5c587 Pull Request resolved: #2642
1 parent 2511c04 commit b840a77

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""Efficient Trajectory Sampling with CompletedTrajectoryRepertoire
7+
8+
This example demonstrates how to design a custom transform that filters trajectories during sampling,
9+
ensuring that only completed trajectories are present in sampled batches. This can be particularly useful
10+
when dealing with environments where some trajectories might be corrupted or never reach a done state,
11+
which could skew the learning process or lead to biased models. For instance, in robotics or autonomous
12+
driving, a trajectory might be interrupted due to external factors such as hardware failures or human
13+
intervention, resulting in incomplete or inconsistent data. By filtering out these incomplete trajectories,
14+
we can improve the quality of the training data and increase the robustness of our models.
15+
"""
16+
17+
import torch
18+
from tensordict import TensorDictBase
19+
from torchrl.data import LazyTensorStorage, ReplayBuffer
20+
from torchrl.envs import GymEnv, TrajCounter, Transform
21+
22+
23+
class CompletedTrajectoryRepertoire(Transform):
    """Replay-buffer transform that keeps only completed trajectories in samples.

    On insertion (``_inv_call``) the ids of the trajectories that reach a
    terminated step are recorded in a repertoire. On sampling (``forward``)
    every entry whose trajectory id is absent from the repertoire is dropped,
    so a sampled batch may contain fewer entries than were requested.

    The transform reads ``("next", "terminated")`` and
    ``("next", "traj_count")`` from the tensordicts it receives. Trajectories
    that are never flagged as terminated are never added to the repertoire.
    """

    def __init__(self):
        super().__init__()
        # Python-side set of the trajectory ids known to be completed.
        self.completed_trajectories = set()
        # Tensor mirror of the set, used for the vectorized membership test in
        # ``forward``. It starts *empty* (shape ``(0,)``): a 0-dim zero tensor
        # would broadcast as the value 0 and wrongly report trajectory 0 as
        # completed before anything has been recorded.
        self.repertoire_tensor = torch.zeros((0,), dtype=torch.int64)

    def _update_repertoire(self, tensordict: TensorDictBase) -> None:
        """Record the ids of the trajectories that terminate in ``tensordict``.

        Only steps flagged by ``("next", "terminated")`` are considered.
        """
        terminated = tensordict["next", "terminated"].squeeze(-1)
        finished_ids = tensordict["next", "traj_count"][terminated].view(-1)
        new_ids = set(finished_ids.tolist()) - self.completed_trajectories
        if not new_ids:
            # Nothing new to record: keep the current repertoire tensor.
            return
        self.completed_trajectories.update(new_ids)
        # Sorting only makes the tensor layout deterministic; membership tests
        # do not depend on the order.
        self.repertoire_tensor = torch.tensor(
            sorted(self.completed_trajectories), dtype=torch.int64
        )

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Update the repertoire from the data being written; return it unchanged."""
        self._update_repertoire(tensordict)
        return tensordict

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Return only the entries of ``tensordict`` whose trajectory is completed."""
        traj = tensordict["next", "traj_count"]
        # The repertoire is built on CPU; move it next to the sampled data so
        # the comparison below also works for non-CPU samples.
        repertoire = self.repertoire_tensor.to(traj.device)
        # A new trailing axis on the ids broadcasts against the 1-D repertoire;
        # ``any`` over that axis answers "is this id in the repertoire?".
        # With an empty repertoire the reduction yields False everywhere.
        has_traj = (traj.unsqueeze(-1) == repertoire).any(-1)
        # Collapse the mask to the batch shape so it can index the tensordict.
        has_traj = has_traj.view(tensordict.shape)
        return tensordict[has_traj]
57+
58+
59+
def main():
    """Run the example: fill a buffer and check that sampling skips the unfinished trajectory.

    Rollouts of 1000 steps are collected until one ends in the middle of a
    trajectory. That rollout is written to a replay buffer equipped with
    ``CompletedTrajectoryRepertoire``, and 10 samples are drawn to verify that
    the id of the unfinished trajectory never shows up.
    """
    # Create a CartPole environment with trajectory counting
    env = GymEnv("CartPole-v1").append_transform(TrajCounter())

    # Create a replay buffer with the completed trajectory repertoire transform
    buffer = ReplayBuffer(
        storage=LazyTensorStorage(1_000_000), transform=CompletedTrajectoryRepertoire()
    )

    try:
        # Collect 1000-step rollouts until the last step is not a done step,
        # i.e. until the rollout ends with an unfinished trajectory.
        while True:
            rollout = env.rollout(1000, break_when_any_done=False)
            if not rollout["next", "done"][-1].item():
                break
    finally:
        # Release the underlying simulator even if collection fails.
        env.close()

    # Extend the replay buffer with the rollout
    buffer.extend(rollout)

    # Get the last trajectory count
    last_traj_count = rollout[-1]["next", "traj_count"].item()
    print(f"Incomplete trajectory: {last_traj_count}")

    # Sample from the replay buffer 10 times
    for _ in range(10):
        sample_traj_counts = buffer.sample(32)["next", "traj_count"].unique()
        print(f"Sampled trajectories: {sample_traj_counts}")
        assert (
            last_traj_count not in sample_traj_counts
        ), "sampled data contains the incomplete trajectory"


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)