# -*- coding: utf-8 -*-
"""
Using pretrained models
=======================
This tutorial explains how to use pretrained models in TorchRL.
"""
##############################################################################
# At the end of this tutorial, you will be capable of using pretrained models
# for efficient image representation, and of fine-tuning them.
#
# TorchRL provides pretrained models that can be used either as transforms or as
# components of the policy. As the semantics are the same, they can be used
# interchangeably in one context or the other. In this tutorial, we will be using
# R3M (https://arxiv.org/abs/2203.12601), but other models (e.g. VIP) will work
# equally well.
#
import torch.cuda
from tensordict.nn import TensorDictSequential
from torch import nn
from torchrl.envs import R3MTransform, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import Actor

device = "cuda:0" if torch.cuda.device_count() else "cpu"

##############################################################################
# Let us first create an environment. For the sake of simplicity, we will be using
# a common gym environment. In practice, this will work in more challenging,
# embodied AI contexts (e.g. have a look at our Habitat wrappers).
#
base_env = GymEnv("Ant-v4", from_pixels=True, device=device)

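##############################################################################
# Since we passed from_pixels=True, the observation spec should contain a "pixels"
# entry, which is what the R3M transform will consume below. A quick check (the
# exact spec layout may vary across TorchRL versions):
#
print(base_env.observation_spec)
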
##############################################################################
# Let us fetch our pretrained model. We ask for the pretrained version of the model
# through the download=True flag. By default this is turned off.
# Next, we will append our transform to the environment. In practice, each batch of
# data collected will go through the transform and be mapped onto an "r3m_vec" entry
# in the output tensordict. Our policy, consisting of a small MLP with a single
# hidden layer, will then read this vector and compute the corresponding action.
#
r3m = R3MTransform("resnet50", in_keys=["pixels"], download=True).to(device)
env_transformed = TransformedEnv(base_env, r3m)
net = nn.Sequential(
    nn.LazyLinear(128), nn.Tanh(), nn.Linear(128, base_env.action_spec.shape[-1])
)
policy = Actor(net, in_keys=["r3m_vec"])

##############################################################################
# Let's check the number of parameter tensors of the policy (the lazy layer has
# not been initialized yet, so we count tensors rather than individual elements):
#
print("number of param tensors:", len(list(policy.parameters())))

##############################################################################
# We collect a rollout of 32 steps and print its output:
#
rollout = env_transformed.rollout(32, policy)
print("rollout with transform:", rollout)

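##############################################################################
# As a quick sanity check, we can look at the embedding entry itself. The exact
# feature size depends on the chosen backbone, so the shape printed below is
# indicative rather than guaranteed:
#
print("r3m_vec shape:", rollout["r3m_vec"].shape)
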
##############################################################################
# For fine-tuning, we integrate the transform in the policy after switching it to
# training mode. In practice, it may be wiser to restrict training to a subset of
# the parameters (say the last layer of the MLP); a possible way of doing this is
# sketched after this cell.
#
r3m.train()
policy = TensorDictSequential(r3m, policy)
print("number of param tensors after r3m is integrated:", len(list(policy.parameters())))

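##############################################################################
# As a sketch of the "subset of the parameters" idea mentioned above, one could
# freeze the R3M backbone and only optimize the small policy head. The optimizer
# choice below is an illustrative assumption, not part of the original recipe:
#
from torch.optim import Adam  # illustrative choice of optimizer

# freeze the pretrained backbone
for p in r3m.parameters():
    p.requires_grad_(False)
# optimize only the parameters that remain trainable (the MLP head)
optimizer = Adam([p for p in policy.parameters() if p.requires_grad], lr=3e-4)
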
##############################################################################
# Again, we collect a rollout, this time with R3M as part of the policy. The
# structure of the output has changed slightly: the environment now returns pixels
# (and not an embedding), while the embedding "r3m_vec" is an intermediate result
# of our policy.
#
rollout = base_env.rollout(32, policy)
print("rollout, fine tuning:", rollout)

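##############################################################################
# We can verify this by listing the top-level keys of the rollout. Assuming the
# default key names, it should contain both the raw "pixels" coming from the
# environment and the "r3m_vec" entry written by the policy:
#
print("rollout keys:", sorted(rollout.keys()))
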
##############################################################################
# The ease with which we have swapped the transform from the env to the policy
# is due to the fact that both behave like a TensorDictModule: they have a set of
# `"in_keys"` and `"out_keys"` that make it easy to read and write outputs in
# different contexts.
#
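# For instance, we can inspect how data flows through each of them (the key names
# shown are the defaults used in this tutorial):
#
print("r3m:", r3m.in_keys, "->", r3m.out_keys)
print("policy:", policy.in_keys, "->", policy.out_keys)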