Commit 6655b2d

[Doc, BugFix] Fix tensordictmodule tutorial (#819)

1 parent b99fb6d commit 6655b2d

File tree

5 files changed: +169 −192 lines changed

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
@@ -182,5 +182,5 @@
 
 generate_knowledge_base_references("../../knowledge_base")
 generate_tutorial_references("../../tutorials/sphinx-tutorials/", "tutorial")
-generate_tutorial_references("../../tutorials/src/", "src")
+# generate_tutorial_references("../../tutorials/src/", "src")
 generate_tutorial_references("../../tutorials/media/", "media")

docs/source/content_generation.py

Lines changed: 3 additions & 3 deletions
@@ -6,7 +6,7 @@
 FILE_DIR = os.path.dirname(__file__)
 KNOWLEDGE_GEN_DIR = "reference/generated/knowledge_base"
 TUTORIALS_GEN_DIR = "reference/generated/tutorials"
-TUTORIALS_SRC_GEN_DIR = "reference/generated/tutorials/src"
+# TUTORIALS_SRC_GEN_DIR = "reference/generated/tutorials/src"
 TUTORIALS_MEDIA_GEN_DIR = "reference/generated/tutorials/media"
 
 

@@ -71,8 +71,8 @@ def generate_tutorial_references(tutorial_path: str, file_type: str) -> None:
     # Create target dir
     if file_type == "tutorial":
         target_path = os.path.join(FILE_DIR, TUTORIALS_GEN_DIR)
-    elif file_type == "src":
-        target_path = os.path.join(FILE_DIR, TUTORIALS_SRC_GEN_DIR)
+    # elif file_type == "src":
+    #     target_path = os.path.join(FILE_DIR, TUTORIALS_SRC_GEN_DIR)
     else:
         target_path = os.path.join(FILE_DIR, TUTORIALS_MEDIA_GEN_DIR)
     Path(target_path).mkdir(parents=True, exist_ok=True)
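
For orientation, here is a minimal sketch of how the remaining branch resolves target directories once the "src" case is commented out. The resolve_target helper and the hard-coded FILE_DIR value are illustrative assumptions, not part of the repository; in the real module FILE_DIR is derived from __file__.

import os

# Assumed values for this sketch only.
FILE_DIR = "docs/source"
TUTORIALS_GEN_DIR = "reference/generated/tutorials"
TUTORIALS_MEDIA_GEN_DIR = "reference/generated/tutorials/media"


def resolve_target(file_type: str) -> str:
    # Mirrors the branching left in generate_tutorial_references after this
    # change: anything that is not "tutorial" falls through to the media dir.
    if file_type == "tutorial":
        return os.path.join(FILE_DIR, TUTORIALS_GEN_DIR)
    return os.path.join(FILE_DIR, TUTORIALS_MEDIA_GEN_DIR)


print(resolve_target("tutorial"))  # docs/source/reference/generated/tutorials
print(resolve_target("media"))  # docs/source/reference/generated/tutorials/media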

tutorials/sphinx-tutorials/tensordict_module.py

Lines changed: 165 additions & 12 deletions
@@ -333,16 +333,171 @@ def forward(self, x):
 #
 # We have let the positional encoders aside for simplicity.
 #
-# Let's first import the classical transformers blocks
-# (see ``src/transformer.py`` for more details.)
-
-from tutorials.src.transformer import (
-    Attention,
-    FFN,
-    SkipLayerNorm,
-    SplitHeads,
-    TokensToQKV,
-)
+# Let's re-write the classical transformers blocks:
+
+
+class TokensToQKV(nn.Module):
+    def __init__(self, to_dim, from_dim, latent_dim):
+        super().__init__()
+        self.q = nn.Linear(to_dim, latent_dim)
+        self.k = nn.Linear(from_dim, latent_dim)
+        self.v = nn.Linear(from_dim, latent_dim)
+
+    def forward(self, X_to, X_from):
+        Q = self.q(X_to)
+        K = self.k(X_from)
+        V = self.v(X_from)
+        return Q, K, V
+
+
+class SplitHeads(nn.Module):
+    def __init__(self, num_heads):
+        super().__init__()
+        self.num_heads = num_heads
+
+    def forward(self, Q, K, V):
+        batch_size, to_num, latent_dim = Q.shape
+        _, from_num, _ = K.shape
+        d_tensor = latent_dim // self.num_heads
+        Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
+        K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
+        V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
+        return Q, K, V
+
+
+class Attention(nn.Module):
+    def __init__(self, latent_dim, to_dim):
+        super().__init__()
+        self.softmax = nn.Softmax(dim=-1)
+        self.out = nn.Linear(latent_dim, to_dim)
+
+    def forward(self, Q, K, V):
+        batch_size, n_heads, to_num, d_in = Q.shape
+        attn = self.softmax(Q @ K.transpose(2, 3) / d_in)
+        out = attn @ V
+        out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in))
+        return out, attn
+
+
+class SkipLayerNorm(nn.Module):
+    def __init__(self, to_len, to_dim):
+        super().__init__()
+        self.layer_norm = nn.LayerNorm((to_len, to_dim))
+
+    def forward(self, x_0, x_1):
+        return self.layer_norm(x_0 + x_1)
+
+
+class FFN(nn.Module):
+    def __init__(self, to_dim, hidden_dim, dropout_rate=0.2):
+        super().__init__()
+        self.FFN = nn.Sequential(
+            nn.Linear(to_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, to_dim),
+            nn.Dropout(dropout_rate),
+        )
+
+    def forward(self, X):
+        return self.FFN(X)
+
+
+class AttentionBlock(nn.Module):
+    def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
+        super().__init__()
+        self.tokens_to_qkv = TokensToQKV(to_dim, from_dim, latent_dim)
+        self.split_heads = SplitHeads(num_heads)
+        self.attention = Attention(latent_dim, to_dim)
+        self.skip = SkipLayerNorm(to_len, to_dim)
+
+    def forward(self, X_to, X_from):
+        Q, K, V = self.tokens_to_qkv(X_to, X_from)
+        Q, K, V = self.split_heads(Q, K, V)
+        out, attention = self.attention(Q, K, V)
+        out = self.skip(X_to, out)
+        return out
+
+
+class EncoderTransformerBlock(nn.Module):
+    def __init__(self, to_dim, to_len, latent_dim, num_heads):
+        super().__init__()
+        self.attention_block = AttentionBlock(
+            to_dim, to_len, to_dim, latent_dim, num_heads
+        )
+        self.FFN = FFN(to_dim, 4 * to_dim)
+        self.skip = SkipLayerNorm(to_len, to_dim)
+
+    def forward(self, X_to):
+        X_to = self.attention_block(X_to, X_to)
+        X_out = self.FFN(X_to)
+        return self.skip(X_out, X_to)
+
+
+class DecoderTransformerBlock(nn.Module):
+    def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
+        super().__init__()
+        self.attention_block = AttentionBlock(
+            to_dim, to_len, from_dim, latent_dim, num_heads
+        )
+        self.encoder_block = EncoderTransformerBlock(
+            to_dim, to_len, latent_dim, num_heads
+        )
+
+    def forward(self, X_to, X_from):
+        X_to = self.attention_block(X_to, X_from)
+        X_to = self.encoder_block(X_to)
+        return X_to
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(self, num_blocks, to_dim, to_len, latent_dim, num_heads):
+        super().__init__()
+        self.encoder = nn.ModuleList(
+            [
+                EncoderTransformerBlock(to_dim, to_len, latent_dim, num_heads)
+                for i in range(num_blocks)
+            ]
+        )
+
+    def forward(self, X_to):
+        for i in range(len(self.encoder)):
+            X_to = self.encoder[i](X_to)
+        return X_to
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(self, num_blocks, to_dim, to_len, from_dim, latent_dim, num_heads):
+        super().__init__()
+        self.decoder = nn.ModuleList(
+            [
+                DecoderTransformerBlock(to_dim, to_len, from_dim, latent_dim, num_heads)
+                for i in range(num_blocks)
+            ]
+        )
+
+    def forward(self, X_to, X_from):
+        for i in range(len(self.decoder)):
+            X_to = self.decoder[i](X_to, X_from)
+        return X_to
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self, num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
+    ):
+        super().__init__()
+        self.encoder = TransformerEncoder(
+            num_blocks, to_dim, to_len, latent_dim, num_heads
+        )
+        self.decoder = TransformerDecoder(
+            num_blocks, from_dim, from_len, to_dim, latent_dim, num_heads
+        )
+
+    def forward(self, X_to, X_from):
+        X_to = self.encoder(X_to)
+        X_out = self.decoder(X_from, X_to)
+        return X_out
+
 
 ###############################################################################
 # We first create the ``AttentionBlockTensorDict``, the attention block using

@@ -608,8 +763,6 @@ def __init__(
 # Benchmarking
 # ------------------------------
 
-from tutorials.src.transformer import Transformer
-
 ###############################################################################
 
 to_dim = 5
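
Since the transformer blocks now live inline in the tutorial, a dummy forward pass is enough to sanity-check the wiring. The sketch below is not part of the commit: the batch size and all dimensions are arbitrary (only to_dim = 5 echoes the benchmarking snippet above), and it assumes the classes from the diff above are in scope, with nn already imported as in the tutorial file.

import torch

# Arbitrary sizes for illustration; latent_dim must be divisible by num_heads.
num_blocks, num_heads, latent_dim = 2, 2, 16
to_dim, to_len = 5, 3  # "to" stream: feature dimension and sequence length
from_dim, from_len = 6, 4  # "from" stream: feature dimension and sequence length

model = Transformer(
    num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
)

X_to = torch.randn(8, to_len, to_dim)
X_from = torch.randn(8, from_len, from_dim)

out = model(X_to, X_from)
print(out.shape)  # torch.Size([8, 4, 6]), i.e. (batch, from_len, from_dim)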

tutorials/src/envs.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

tutorials/src/transformer.py

Lines changed: 0 additions & 164 deletions
This file was deleted.
