Commit f07015d

[Doc] Tensordictmodule tutorial (#267)
1 parent a0af473 commit f07015d

File tree

4 files changed: +1382 −2 lines


tutorials/README.md

Lines changed: 4 additions & 2 deletions
@@ -4,8 +4,10 @@ Get a sense of TorchRL functionalities through our tutorials.

 For an overview of TorchRL, try the [TorchRL demo](demo.ipynb).

-Make sure you test the [TensorDict demo](tensordict.ipynb) to see what TensorDict
+Make sure you test the [TensorDict tutorial](tensordict.ipynb) to see what TensorDict
 is about and what it can do.

-Checkout the [environment demo](envs.ipynb) for a deep dive in the envs
+To understand how to use `TensorDict` with pytorch modules, make sure to check out the [TensorDictModule tutorial](tensordictmodule.ipynb).
+
+Checkout the [environment tutorial](envs.ipynb) for a deep dive in the envs
 functionalities.
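
For a quick sense of what the new TensorDictModule tutorial covers, here is a minimal sketch of wrapping a plain PyTorch module so that it reads its inputs from, and writes its outputs to, a TensorDict. The import paths follow the current `tensordict` package layout and may differ across TorchRL versions; the key names and sizes are hypothetical.

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# Wrap an ordinary nn.Module: entries named by in_keys are read from the
# TensorDict and passed to forward(); the outputs are written under out_keys.
wrapped = TensorDictModule(nn.Linear(3, 4), in_keys=["observation"], out_keys=["hidden"])

td = TensorDict({"observation": torch.randn(10, 3)}, batch_size=[10])
wrapped(td)                # runs the linear layer and writes "hidden" in place
print(td["hidden"].shape)  # torch.Size([10, 4])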

tutorials/media/transformer.png

367 KB

tutorials/src/transformer.py

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
import torch.nn as nn


class TokensToQKV(nn.Module):
    """Project target tokens (X_to) to queries and source tokens (X_from) to keys and values."""

    def __init__(self, to_dim, from_dim, latent_dim):
        super().__init__()
        self.q = nn.Linear(to_dim, latent_dim)
        self.k = nn.Linear(from_dim, latent_dim)
        self.v = nn.Linear(from_dim, latent_dim)

    def forward(self, X_to, X_from):
        Q = self.q(X_to)
        K = self.k(X_from)
        V = self.v(X_from)
        return Q, K, V


class SplitHeads(nn.Module):
    """Reshape (batch, seq, latent_dim) projections into (batch, num_heads, seq, latent_dim // num_heads)."""

    def __init__(self, num_heads):
        super().__init__()
        self.num_heads = num_heads

    def forward(self, Q, K, V):
        batch_size, to_num, latent_dim = Q.shape
        _, from_num, _ = K.shape
        d_tensor = latent_dim // self.num_heads
        Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
        K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
        return Q, K, V


class Attention(nn.Module):
    """Multi-head dot-product attention followed by an output projection back to to_dim."""

    def __init__(self, latent_dim, to_dim):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.out = nn.Linear(latent_dim, to_dim)

    def forward(self, Q, K, V):
        batch_size, n_heads, to_num, d_in = Q.shape
        # Note: the scores are scaled by d_in rather than the more common sqrt(d_in).
        attn = self.softmax(Q @ K.transpose(2, 3) / d_in)
        out = attn @ V
        # Merge the heads back together and project to the target dimension.
        out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in))
        return out, attn


class SkipLayerNorm(nn.Module):
    """Residual connection followed by layer normalization."""

    def __init__(self, to_len, to_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm((to_len, to_dim))

    def forward(self, x_0, x_1):
        return self.layer_norm(x_0 + x_1)


class FFN(nn.Module):
    """Position-wise feed-forward network with dropout."""

    def __init__(self, to_dim, hidden_dim, dropout_rate=0.2):
        super().__init__()
        self.FFN = nn.Sequential(
            nn.Linear(to_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, to_dim),
            nn.Dropout(dropout_rate),
        )

    def forward(self, X):
        return self.FFN(X)


class AttentionBlock(nn.Module):
    """QKV projection, multi-head attention and a residual layer norm, in one block."""

    def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__()
        self.tokens_to_qkv = TokensToQKV(to_dim, from_dim, latent_dim)
        self.split_heads = SplitHeads(num_heads)
        self.attention = Attention(latent_dim, to_dim)
        self.skip = SkipLayerNorm(to_len, to_dim)

    def forward(self, X_to, X_from):
        Q, K, V = self.tokens_to_qkv(X_to, X_from)
        Q, K, V = self.split_heads(Q, K, V)
        out, attention = self.attention(Q, K, V)
        out = self.skip(X_to, out)
        return out


class EncoderTransformerBlock(nn.Module):
    """Self-attention followed by a feed-forward network, each with a residual layer norm."""

    def __init__(self, to_dim, to_len, latent_dim, num_heads):
        super().__init__()
        self.attention_block = AttentionBlock(
            to_dim, to_len, to_dim, latent_dim, num_heads
        )
        self.FFN = FFN(to_dim, 4 * to_dim)
        self.skip = SkipLayerNorm(to_len, to_dim)

    def forward(self, X_to):
        X_to = self.attention_block(X_to, X_to)
        X_out = self.FFN(X_to)
        return self.skip(X_out, X_to)


class DecoderTransformerBlock(nn.Module):
    """Cross-attention over the encoder output followed by an encoder-style self-attention/FFN block."""

    def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__()
        self.attention_block = AttentionBlock(
            to_dim, to_len, from_dim, latent_dim, num_heads
        )
        self.encoder_block = EncoderTransformerBlock(
            to_dim, to_len, latent_dim, num_heads
        )

    def forward(self, X_to, X_from):
        X_to = self.attention_block(X_to, X_from)
        X_to = self.encoder_block(X_to)
        return X_to


class TransformerEncoder(nn.Module):
    """Stack of encoder blocks applied to the target sequence."""

    def __init__(self, num_blocks, to_dim, to_len, latent_dim, num_heads):
        super().__init__()
        self.encoder = nn.ModuleList(
            [
                EncoderTransformerBlock(to_dim, to_len, latent_dim, num_heads)
                for i in range(num_blocks)
            ]
        )

    def forward(self, X_to):
        for i in range(len(self.encoder)):
            X_to = self.encoder[i](X_to)
        return X_to


class TransformerDecoder(nn.Module):
    """Stack of decoder blocks; X_to is the decoded sequence, X_from the encoder output."""

    def __init__(self, num_blocks, to_dim, to_len, from_dim, latent_dim, num_heads):
        super().__init__()
        self.decoder = nn.ModuleList(
            [
                DecoderTransformerBlock(to_dim, to_len, from_dim, latent_dim, num_heads)
                for i in range(num_blocks)
            ]
        )

    def forward(self, X_to, X_from):
        for i in range(len(self.decoder)):
            X_to = self.decoder[i](X_to, X_from)
        return X_to


class Transformer(nn.Module):
    """Encoder-decoder transformer: encodes X_to, then decodes X_from while attending to it."""

    def __init__(
        self, num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
    ):
        super().__init__()
        self.encoder = TransformerEncoder(
            num_blocks, to_dim, to_len, latent_dim, num_heads
        )
        # The decoder treats X_from as its target sequence and cross-attends to the encoded X_to.
        self.decoder = TransformerDecoder(
            num_blocks, from_dim, from_len, to_dim, latent_dim, num_heads
        )

    def forward(self, X_to, X_from):
        X_to = self.encoder(X_to)
        X_out = self.decoder(X_from, X_to)
        return X_out
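
As a quick sanity check of the shapes these modules expect, here is a hedged usage sketch. The batch size, sequence lengths and dimensions below are made up for illustration; the TensorDictModule-based wiring presumably lives in the tutorial notebook, not in this file.

import torch

# Hypothetical sizes, for illustration only.
batch, to_len, from_len = 2, 10, 12
to_dim, from_dim, latent_dim, num_heads = 32, 32, 64, 4

model = Transformer(
    num_blocks=2,
    to_dim=to_dim,
    to_len=to_len,
    from_dim=from_dim,
    from_len=from_len,
    latent_dim=latent_dim,
    num_heads=num_heads,
)

X_to = torch.randn(batch, to_len, to_dim)        # sequence fed to the encoder
X_from = torch.randn(batch, from_len, from_dim)  # sequence decoded against it
out = model(X_to, X_from)
print(out.shape)  # torch.Size([2, 12, 32]), i.e. (batch, from_len, from_dim)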
