 
 from dalle_pytorch.cache import Cached
 
-from rotary_embedding_torch import apply_rotary_emb
-
 # helpers
 
 def exists(val):
@@ -31,9 +29,18 @@ def stable_softmax(t, dim = -1, alpha = 32 ** 2):
     t = t - torch.amax(t, dim = dim, keepdim = True)
     return (t * alpha).softmax(dim = dim)
 
+def rotate_half(x):
+    d = x.shape[-1] // 2
+    return torch.cat([-x[..., d:], x[..., :d]], dim = -1)
+
+def apply_rotary_emb(freqs, t):
+    rot_dim = freqs.shape[-1]
+    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
+    t, t_right = t[..., :rot_dim], t[..., rot_dim:]
+    t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
+    return torch.cat((t, t_right), dim = -1)
+
 def apply_pos_emb(pos_emb, qkv):
-    n = qkv[0].shape[-2]
-    pos_emb = pos_emb[..., :n, :]
     return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))
 
 # classes
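The two inlined helpers above replace the apply_rotary_emb previously imported from rotary_embedding_torch: only the first rot_dim features of each head are rotated, the remainder passes through untouched. A quick sanity check of that behaviour (tensor shapes here are illustrative assumptions, not taken from the diff):

import torch

q = torch.randn(1, 8, 16, 64)        # (batch, heads, seq, dim_head)
freqs = torch.randn(1, 1, 16, 32)    # per-position frequencies, rot_dim = 32

q_rot = apply_rotary_emb(freqs, q)
assert q_rot.shape == q.shape
assert torch.equal(q_rot[..., 32:], q[..., 32:])   # features past rot_dim are unchanged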
@@ -49,7 +56,7 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
         self.stable = stable
         self.causal = causal
 
-        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = Cached(nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
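to_qkv is unwrapped here because the new forward pass below caches keys and values itself, while to_out stays behind the Cached wrapper. For orientation, a minimal wrapper that is merely consistent with how Cached is called in this file could look like the sketch below (an assumption for illustration only; the actual dalle_pytorch.cache.Cached may be implemented differently):

import torch
from torch import nn

class CachedSketch(nn.Module):
    # hypothetical stand-in for dalle_pytorch.cache.Cached: reruns a position-wise
    # layer only on tokens beyond the cached prefix and concatenates the results
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *, cache = None, cache_key = None):
        if cache is None:
            return self.fn(x)

        prefix = cache.get(cache_key)              # output cached from earlier calls
        if prefix is None:
            out = self.fn(x)                       # first call: run on the full sequence
        else:
            new = self.fn(x[:, prefix.shape[1]:])  # later calls: run only on new positions
            out = torch.cat((prefix, new), dim = 1)

        cache[cache_key] = out
        return out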
@@ -59,42 +66,49 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         b, n, _, h, device = *x.shape, self.heads, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
 
-        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+        qkv_key = f'{cache_key}_qkv'
+        if exists(cache) and qkv_key in cache:
+            qkv = self.to_qkv(x[..., n - 1:n, :]).chunk(3, dim = -1)
+            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
 
-        if exists(rotary_pos_emb):
-            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))
+            if exists(rotary_pos_emb):
+                q, k, v = apply_pos_emb(rotary_pos_emb[..., n - 1:n, :], (q, k, v))
 
-        q = q * self.scale
+            q *= self.scale
 
-        mask_value = max_neg_value(q)
-        dots_key = f'{cache_key}_dots'
-        if exists(cache) and dots_key in cache:
-            topleft = cache[dots_key]
-            top = F.pad(topleft, (0, 1), value = mask_value)
-            bottom = q[..., n - 1:n, :] @ k.swapaxes(-1, -2)
-            dots = torch.cat([top, bottom], dim = -2)
+            k_top, v_top = cache[qkv_key]
+            k = torch.cat([k_top, k], dim = -2)
+            v = torch.cat([v_top, v], dim = -2)
         else:
-            dots = q @ k.swapaxes(-1, -2)
+            qkv = self.to_qkv(x).chunk(3, dim = -1)
+            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+
+            if exists(rotary_pos_emb):
+                q, k, v = apply_pos_emb(rotary_pos_emb[..., :n, :], (q, k, v))
+
+            q *= self.scale
 
         if exists(cache):
-            cache[dots_key] = dots
+            cache[qkv_key] = (k, v)
 
-        if exists(mask):
-            mask = rearrange(mask, 'b j -> b () () j')
-            dots.masked_fill_(~mask, mask_value)
-            del mask
+        # mask_value = max_neg_value(q)
+        dots = q @ k.swapaxes(-1, -2)
 
-        if self.causal:
-            i, j = dots.shape[-2:]
-            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
-            dots.masked_fill_(mask, mask_value)
+        # if exists(mask): # TODO:
+        #     mask = rearrange(mask, 'b j -> b () () j')
+        #     dots.masked_fill_(~mask, mask_value)
+        #     del mask
+
+        # if self.causal: # TODO:
+        #     i, j = dots.shape[-2:]
+        #     mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+        #     dots.masked_fill_(mask, mask_value)
 
         attn = softmax(dots, dim = -1)
 
         out_key = f'{cache_key}_out'
         if exists(cache) and out_key in cache:
             top = cache[out_key]
-            bottom = attn[..., n - 1:n, :] @ v
+            bottom = attn @ v
             out = torch.cat([top, bottom], dim = -2)
         else:
             out = attn @ v
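With this change the cache holds per-layer key/value tensors (under f'{cache_key}_qkv') instead of the previous attention-score rows, so incremental decoding recomputes the qkv projection only for the newest position. A usage sketch of the cached path (the constructor arguments, the class name Attention, and the rotary_pos_emb shape are assumptions for illustration, not taken from the diff):

import torch

attn = Attention(dim = 512, seq_len = 64, heads = 8, dim_head = 64)
rotary_pos_emb = torch.randn(1, 1, 64, 32)    # assumed per-position rotary frequencies

cache = {}
x = torch.randn(1, 1, 512)                    # embedding of the first token

for _ in range(4):
    out = attn(x, rotary_pos_emb = rotary_pos_emb, cache = cache, cache_key = 'layer0')
    # after the first call, only x[..., -1:, :] is projected; its key/value rows are
    # appended to the tensors stored under cache['layer0_qkv']
    next_token = torch.randn(1, 1, 512)        # stand-in for the next decoded embedding
    x = torch.cat((x, next_token), dim = 1)    # the growing prefix is passed back in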