@@ -333,16 +333,171 @@ def forward(self, x):
 #
 # We have left the positional encoders aside for simplicity.
 #
-# Let's first import the classical transformers blocks
-# (see ``src/transformer.py`` for more details.)
-
-from tutorials.src.transformer import (
-    Attention,
-    FFN,
-    SkipLayerNorm,
-    SplitHeads,
-    TokensToQKV,
-)
+# Let's rewrite the classical transformer blocks:
+
+
+class TokensToQKV(nn.Module):
+    def __init__(self, to_dim, from_dim, latent_dim):
+        super().__init__()
+        self.q = nn.Linear(to_dim, latent_dim)
+        self.k = nn.Linear(from_dim, latent_dim)
+        self.v = nn.Linear(from_dim, latent_dim)
+
+    def forward(self, X_to, X_from):
+        Q = self.q(X_to)
+        K = self.k(X_from)
+        V = self.v(X_from)
+        return Q, K, V
+
+
+class SplitHeads(nn.Module):
+    def __init__(self, num_heads):
+        super().__init__()
+        self.num_heads = num_heads
+
+    def forward(self, Q, K, V):
+        batch_size, to_num, latent_dim = Q.shape
+        _, from_num, _ = K.shape
+        d_tensor = latent_dim // self.num_heads
+        Q = Q.reshape(batch_size, to_num, self.num_heads, d_tensor).transpose(1, 2)
+        K = K.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
+        V = V.reshape(batch_size, from_num, self.num_heads, d_tensor).transpose(1, 2)
+        return Q, K, V
+
+
+class Attention(nn.Module):
+    def __init__(self, latent_dim, to_dim):
+        super().__init__()
+        self.softmax = nn.Softmax(dim=-1)
+        self.out = nn.Linear(latent_dim, to_dim)
+
+    def forward(self, Q, K, V):
+        batch_size, n_heads, to_num, d_in = Q.shape
+        attn = self.softmax(Q @ K.transpose(2, 3) / d_in)
+        out = attn @ V
+        out = self.out(out.transpose(1, 2).reshape(batch_size, to_num, n_heads * d_in))
+        return out, attn
+
+
+class SkipLayerNorm(nn.Module):
+    def __init__(self, to_len, to_dim):
+        super().__init__()
+        self.layer_norm = nn.LayerNorm((to_len, to_dim))
+
+    def forward(self, x_0, x_1):
+        return self.layer_norm(x_0 + x_1)
+
+
+class FFN(nn.Module):
+    def __init__(self, to_dim, hidden_dim, dropout_rate=0.2):
+        super().__init__()
+        self.FFN = nn.Sequential(
+            nn.Linear(to_dim, hidden_dim),
+            nn.ReLU(),
+            nn.Linear(hidden_dim, to_dim),
+            nn.Dropout(dropout_rate),
+        )
+
+    def forward(self, X):
+        return self.FFN(X)
+
+
+class AttentionBlock(nn.Module):
+    def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
+        super().__init__()
+        self.tokens_to_qkv = TokensToQKV(to_dim, from_dim, latent_dim)
+        self.split_heads = SplitHeads(num_heads)
+        self.attention = Attention(latent_dim, to_dim)
+        self.skip = SkipLayerNorm(to_len, to_dim)
+
+    def forward(self, X_to, X_from):
+        Q, K, V = self.tokens_to_qkv(X_to, X_from)
+        Q, K, V = self.split_heads(Q, K, V)
+        out, attention = self.attention(Q, K, V)
+        out = self.skip(X_to, out)
+        return out
+
+
+class EncoderTransformerBlock(nn.Module):
+    def __init__(self, to_dim, to_len, latent_dim, num_heads):
+        super().__init__()
+        self.attention_block = AttentionBlock(
+            to_dim, to_len, to_dim, latent_dim, num_heads
+        )
+        self.FFN = FFN(to_dim, 4 * to_dim)
+        self.skip = SkipLayerNorm(to_len, to_dim)
+
+    def forward(self, X_to):
+        X_to = self.attention_block(X_to, X_to)
+        X_out = self.FFN(X_to)
+        return self.skip(X_out, X_to)
+
+
+class DecoderTransformerBlock(nn.Module):
+    def __init__(self, to_dim, to_len, from_dim, latent_dim, num_heads):
+        super().__init__()
+        self.attention_block = AttentionBlock(
+            to_dim, to_len, from_dim, latent_dim, num_heads
+        )
+        self.encoder_block = EncoderTransformerBlock(
+            to_dim, to_len, latent_dim, num_heads
+        )
+
+    def forward(self, X_to, X_from):
+        X_to = self.attention_block(X_to, X_from)
+        X_to = self.encoder_block(X_to)
+        return X_to
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(self, num_blocks, to_dim, to_len, latent_dim, num_heads):
+        super().__init__()
+        self.encoder = nn.ModuleList(
+            [
+                EncoderTransformerBlock(to_dim, to_len, latent_dim, num_heads)
+                for i in range(num_blocks)
+            ]
+        )
+
+    def forward(self, X_to):
+        for i in range(len(self.encoder)):
+            X_to = self.encoder[i](X_to)
+        return X_to
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(self, num_blocks, to_dim, to_len, from_dim, latent_dim, num_heads):
+        super().__init__()
+        self.decoder = nn.ModuleList(
+            [
+                DecoderTransformerBlock(to_dim, to_len, from_dim, latent_dim, num_heads)
+                for i in range(num_blocks)
+            ]
+        )
+
+    def forward(self, X_to, X_from):
+        for i in range(len(self.decoder)):
+            X_to = self.decoder[i](X_to, X_from)
+        return X_to
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self, num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
+    ):
+        super().__init__()
+        self.encoder = TransformerEncoder(
+            num_blocks, to_dim, to_len, latent_dim, num_heads
+        )
+        self.decoder = TransformerDecoder(
+            num_blocks, from_dim, from_len, to_dim, latent_dim, num_heads
+        )
+
+    def forward(self, X_to, X_from):
+        X_to = self.encoder(X_to)
+        X_out = self.decoder(X_from, X_to)
+        return X_out
+
 
 ###############################################################################
 # We first create the ``AttentionBlockTensorDict``, the attention block using
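As a quick sanity check of the blocks added in the hunk above, a forward pass on random inputs can be sketched as follows. This is a minimal sketch, not part of the diff: the batch size and all dimensions below are arbitrary placeholders, and it assumes the classes defined above are in scope with latent_dim divisible by num_heads.

import torch

# Placeholder sizes (assumptions for illustration only).
num_blocks, num_heads, latent_dim = 2, 2, 16
to_dim, to_len = 5, 3       # "to" (target) feature dimension and sequence length
from_dim, from_len = 6, 10  # "from" (source) feature dimension and sequence length

model = Transformer(
    num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
)
X_to = torch.randn(2, to_len, to_dim)
X_from = torch.randn(2, from_len, from_dim)

# The encoder processes X_to; the decoder attends from X_from to the encoded X_to,
# so the output lives in the "from" stream: (batch, from_len, from_dim).
out = model(X_to, X_from)
print(out.shape)  # torch.Size([2, 10, 6])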
@@ -608,8 +763,6 @@ def __init__(
 # Benchmarking
 # ------------------------------
 
-from tutorials.src.transformer import Transformer
-
 ###############################################################################
 
 to_dim = 5
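For the benchmark, the baseline is now the Transformer defined inline above rather than the removed tutorials.src.transformer import. A rough sketch of timing its forward pass is shown below; only to_dim = 5 appears in the visible context, so every other size and the use of torch.utils.benchmark are assumptions, not values from the diff.

import torch
from torch.utils.benchmark import Timer

# Placeholder sizes (only to_dim = 5 comes from the visible context).
to_dim, to_len, from_dim, from_len = 5, 3, 6, 10
latent_dim, num_heads, num_blocks, batch = 16, 2, 4, 8

model = Transformer(
    num_blocks, to_dim, to_len, from_dim, from_len, latent_dim, num_heads
)
X_to = torch.randn(batch, to_len, to_dim)
X_from = torch.randn(batch, from_len, from_dim)

# Time the forward pass of the baseline model.
timer = Timer(
    stmt="model(X_to, X_from)",
    globals={"model": model, "X_to": X_to, "X_from": X_from},
)
print(timer.timeit(100))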