Commit e0c0d34

[Doc] Using R3M with a replay buffer (#820)
1 parent dc1584d commit e0c0d34

File tree

1 file changed (+30 −1 lines)


tutorials/sphinx-tutorials/pretrained_models.py

Lines changed: 30 additions & 1 deletion
@@ -73,7 +73,36 @@
 print("rollout, fine tuning:", rollout)

 ##############################################################################
-# The easyness with which we have swapped the transform from the env to the policy
+# The easiness with which we have swapped the transform from the env to the policy
 # is due to the fact that both behave like TensorDictModule: they have a set of `"in_keys"` and
 # `"out_keys"` that make it easy to read and write output in different context.
 #
+# To conclude this tutorial, let's have a look at how we could use R3M to read
+# images stored in a replay buffer (e.g. in an offline RL context). First, let's build our dataset:
+#
+from torchrl.data import LazyMemmapStorage, ReplayBuffer
+
+storage = LazyMemmapStorage(1000)
+rb = ReplayBuffer(storage=storage, transform=r3m)
+
+##############################################################################
+# We can now collect the data (random rollouts for our purpose) and fill the replay
+# buffer with it:
+#
+total = 0
+while total < 1000:
+    tensordict = base_env.rollout(1000)
+    rb.extend(tensordict)
+    total += tensordict.numel()
+
+##############################################################################
+# Let's check what our replay buffer storage looks like. It should not contain the "r3m_vec" entry
+# since we haven't used it yet:
+print("stored data:", storage._storage)
+
+##############################################################################
+# When sampling, the data will go through the R3M transform, giving us the processed data that we wanted.
+# In this way, we can train an algorithm offline on a dataset made of images:
+#
+batch = rb.sample(32)
+print("data after sampling:", batch)
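The key pattern in this patch is that the replay buffer stores *raw* data and only runs the transform when a batch is sampled. The diff itself needs torchrl plus the `r3m` and `base_env` objects built earlier in the tutorial, so here is a stdlib-only sketch of just that pattern; `SketchReplayBuffer`, `fake_r3m`, and the list-based "rollouts" are all illustrative stand-ins, not torchrl APIs:

```python
import random

random.seed(0)

# Sketch of the store-raw / transform-on-sample pattern from the diff above.
# All names here are hypothetical; this is not torchrl's implementation.

class SketchReplayBuffer:
    def __init__(self, capacity, transform=None):
        self.capacity = capacity
        self.transform = transform   # applied on sample(), never on extend()
        self._storage = []

    def extend(self, items):
        # store items untouched; the storage never sees transformed entries
        for item in items:
            if len(self._storage) < self.capacity:
                self._storage.append(item)

    def sample(self, batch_size):
        batch = random.sample(self._storage, batch_size)
        if self.transform is not None:
            # the transform runs here, on the way out of the buffer
            batch = [self.transform(item) for item in batch]
        return batch

# stand-in for the R3M transform: raw "pixels" -> an embedding entry
def fake_r3m(obs):
    return {"r3m_vec": sum(obs) / len(obs)}

rb = SketchReplayBuffer(capacity=1000, transform=fake_r3m)

# fill the buffer with random "rollouts", mirroring the while loop in the diff
total = 0
while total < 1000:
    rollout = [[random.random() for _ in range(4)] for _ in range(100)]
    rb.extend(rollout)
    total += len(rollout)

print("stored entries:", len(rb._storage))   # raw observations, no "r3m_vec"
batch = rb.sample(32)
print("sampled keys:", sorted(batch[0]))     # the transform output appears here
```

As in the diff, inspecting the storage shows only raw data, while every sampled item carries the `"r3m_vec"`-style entry, so expensive features are computed per batch rather than stored for every frame.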
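The storage class used in the patch, `LazyMemmapStorage(1000)`, takes only a capacity at construction time. A reasonable reading of the name (an assumption about the design, not torchrl's actual code) is that the buffer is backed by a memory-mapped file allocated lazily on first write. The stdlib sketch below, with a hypothetical `LazyMemmapFloatStorage` holding one float per slot, illustrates that idea:

```python
import mmap
import os
import struct
import tempfile

ITEM = struct.Struct("d")  # one float64 per slot, for simplicity

# Hypothetical sketch of a lazily-allocated, file-backed storage.
# This is an illustration of the concept, not torchrl's LazyMemmapStorage.
class LazyMemmapFloatStorage:
    def __init__(self, capacity):
        self.capacity = capacity
        self._mm = None  # no file and no memory until the first write

    def _lazy_init(self):
        fd, self._path = tempfile.mkstemp(suffix=".memmap")
        os.ftruncate(fd, self.capacity * ITEM.size)  # reserve space on disk
        self._mm = mmap.mmap(fd, self.capacity * ITEM.size)
        os.close(fd)  # the mmap keeps its own handle to the file

    def set(self, index, value):
        if self._mm is None:
            self._lazy_init()
        ITEM.pack_into(self._mm, index * ITEM.size, value)

    def get(self, index):
        return ITEM.unpack_from(self._mm, index * ITEM.size)[0]

storage = LazyMemmapFloatStorage(1000)
print("allocated before first write:", storage._mm is not None)
storage.set(0, 3.14)
storage.set(999, 2.71)
print("round trip:", storage.get(0), storage.get(999))
```

The practical payoff of this layout is that large image datasets live on disk rather than in RAM, and constructing an empty 1000-slot buffer costs nothing until data actually arrives.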
