
Commit b99fb6d

[Doc] Pretrained models tutorial (#814)
1 parent 4b1ad2b commit b99fb6d

File tree

4 files changed: +86 -4 lines changed

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ Tutorials
    :maxdepth: 2
 
    tutorials/torchrl_demo
+   tutorials/pretrained_models
    tutorials/tensordict_tutorial
    tutorials/tensordict_module
    tutorials/torch_envs

torchrl/envs/transforms/r3m.py

Lines changed: 4 additions & 3 deletions
@@ -161,7 +161,7 @@ class R3MTransform(Compose):
 
     Args:
         model_name (str): one of resnet50, resnet34 or resnet18
-        in_keys (list of str, optional): list of input keys. If left empty, the
+        in_keys (list of str): list of input keys. If left empty, the
             "pixels" key is assumed.
         out_keys (list of str, optional): list of output keys. If left empty,
             "r3m_vec" is assumed.
@@ -190,7 +190,7 @@ def __new__(cls, *args, **kwargs):
     def __init__(
         self,
         model_name: str,
-        in_keys: List[str] = None,
+        in_keys: List[str],
         out_keys: List[str] = None,
         size: int = 244,
         stack_images: bool = True,
@@ -199,7 +199,7 @@ def __init__(
         tensor_pixels_keys: List[str] = None,
     ):
         super().__init__()
-        self.in_keys = in_keys
+        self.in_keys = in_keys if in_keys is not None else ["pixels"]
         self.download = download
         self.download_path = download_path
         self.model_name = model_name
@@ -258,6 +258,7 @@ def _init(self):
                 out_keys = ["r3m_vec"]
             else:
                 out_keys = [f"r3m_vec_{i}" for i in range(len(in_keys))]
+            self.out_keys = out_keys
         elif stack_images and len(out_keys) != 1:
             raise ValueError(
                 f"out_key must be of length 1 if stack_images is True. Got out_keys={out_keys}"

torchrl/envs/transforms/vip.py

Lines changed: 2 additions & 1 deletion
@@ -174,7 +174,7 @@ def __init__(
         tensor_pixels_keys: List[str] = None,
     ):
         super().__init__()
-        self.in_keys = in_keys
+        self.in_keys = in_keys if in_keys is not None else ["pixels"]
         self.download = download
         self.download_path = download_path
         self.model_name = model_name
@@ -233,6 +233,7 @@ def _init(self):
                 out_keys = ["vip_vec"]
             else:
                 out_keys = [f"vip_vec_{i}" for i in range(len(in_keys))]
+            self.out_keys = out_keys
         elif stack_images and len(out_keys) != 1:
             raise ValueError(
                 f"out_key must be of length 1 if stack_images is True. Got out_keys={out_keys}"
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+"""
+Using pretrained models
+=======================
+This tutorial explains how to use pretrained models in TorchRL.
+"""
+##############################################################################
+# At the end of this tutorial, you will be able to use pretrained models
+# for efficient image representation, and to fine-tune them.
+#
+# TorchRL provides pretrained models that can be used either as transforms or as
+# components of the policy. As the semantics are the same, they can be used interchangeably
+# in either context. In this tutorial, we will be using R3M (https://arxiv.org/abs/2203.12601),
+# but other models (e.g. VIP) will work equally well.
+#
+import torch.cuda
+from tensordict.nn import TensorDictSequential
+from torch import nn
+from torchrl.envs import R3MTransform, TransformedEnv
+from torchrl.envs.libs.gym import GymEnv
+from torchrl.modules import Actor
+
+device = "cuda:0" if torch.cuda.device_count() else "cpu"
+
+##############################################################################
+# Let us first create an environment. For the sake of simplicity, we will be using
+# a common gym environment. In practice, this will work in more challenging, embodied
+# AI contexts (e.g. have a look at our Habitat wrappers).
+#
+base_env = GymEnv("Ant-v4", from_pixels=True, device=device)
+
+##############################################################################
+# Let us fetch our pretrained model. We ask for the pretrained version of the model through the
+# download=True flag. By default this is turned off.
+# Next, we will append our transform to the environment. In practice, each batch of data
+# collected will go through the transform and be mapped onto a "r3m_vec" entry
+# in the output tensordict. Our policy, consisting of a single-layer MLP, will then read this
+# vector and compute the corresponding action.
+#
+r3m = R3MTransform("resnet50", in_keys=["pixels"], download=True).to(device)
+env_transformed = TransformedEnv(base_env, r3m)
+net = nn.Sequential(
+    nn.LazyLinear(128), nn.Tanh(), nn.Linear(128, base_env.action_spec.shape[-1])
+)
+policy = Actor(net, in_keys=["r3m_vec"])
+
+##############################################################################
+# Let's check the number of parameters of the policy:
+#
+print("number of params:", len(list(policy.parameters())))
+
+##############################################################################
+# We collect a rollout of 32 steps and print its output:
+#
+rollout = env_transformed.rollout(32, policy)
+print("rollout with transform:", rollout)
+
+##############################################################################
+# For fine-tuning, we integrate the transform into the policy after making its parameters
+# trainable. In practice, it may be wiser to restrict this to a subset of the parameters (say,
+# the last layer of the MLP).
+#
+r3m.train()
+policy = TensorDictSequential(r3m, policy)
+print("number of params after r3m is integrated:", len(list(policy.parameters())))
+
+##############################################################################
+# Again, we collect a rollout with R3M. The structure of the output has changed slightly, as now
+# the environment returns pixels (and not an embedding). The embedding "r3m_vec" is an intermediate
+# result of our policy.
+#
+rollout = base_env.rollout(32, policy)
+print("rollout, fine tuning:", rollout)
+
+##############################################################################
+# The ease with which we have swapped the transform from the env to the policy
+# comes from the fact that both behave like a TensorDictModule: they have a set of `"in_keys"` and
+# `"out_keys"` that make it easy to read and write outputs in different contexts.
+#
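
The tutorial suggests restricting fine-tuning to a subset of the parameters; a minimal sketch of one way to do that with standard torch calls (requires_grad_), assuming the r3m, net and policy objects defined above:

    # Freeze the R3M backbone and all but the last layer of the policy MLP.
    for p in r3m.parameters():
        p.requires_grad_(False)
    for p in net.parameters():
        p.requires_grad_(False)
    for p in net[-1].parameters():  # last Linear of the nn.Sequential
        p.requires_grad_(True)

    trainable = [p for p in policy.parameters() if p.requires_grad]
    print("trainable parameter tensors:", len(trainable))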
