 import torch.nn.functional as F
 from einops import rearrange, repeat

-from dalle_pytorch.cache import Cached
-
 # helpers

 def exists(val):
@@ -41,6 +39,8 @@ def apply_rotary_emb(freqs, t):
     return torch.cat((t, t_right), dim = -1)

 def apply_pos_emb(pos_emb, qkv):
+    n = qkv[0].shape[-2]
+    pos_emb = pos_emb[..., :n, :]
     return tuple(map(lambda t: apply_rotary_emb(pos_emb, t), qkv))

 # classes
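The two added lines make `apply_pos_emb` self-trimming: callers may now hand in a full-length rotary table, and it is sliced down to the actual sequence length of the tensors being rotated. A minimal sketch of the effect (the shapes below are illustrative assumptions, not taken from the commit):

```python
import torch

full_table = torch.randn(1, 1, 256, 64)   # rotary frequencies for a max length of 256
q = torch.randn(1, 8, 10, 64)             # (batch, heads, seq, dim_head), only 10 positions
n = q.shape[-2]
trimmed = full_table[..., :n, :]          # the slice apply_pos_emb now performs internally
assert trimmed.shape[-2] == q.shape[-2]   # frequencies line up with the query positions
```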
@@ -65,30 +65,24 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
     def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, device = *x.shape, self.heads, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
+        using_cache = exists(cache) and cache_key in cache

-        qkv_key = f'{cache_key}_qkv'
-        if exists(cache) and qkv_key in cache:
-            qkv = self.to_qkv(x).chunk(3, dim = -1)
-            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

-            if exists(rotary_pos_emb):
-                q, k, v = apply_pos_emb(rotary_pos_emb[..., n - 1:n, :], (q, k, v))  # FIXME: Fix rotary index here
+        if exists(rotary_pos_emb):
+            if using_cache:
+                rotary_pos_emb = rotary_pos_emb[..., n - 1:, :]  # FIXME: Fix rotary index here
+            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

-            q *= self.scale
+        q = q * self.scale

-            k_top, v_top = cache[qkv_key]
+        if using_cache:
+            k_top, v_top = cache[cache_key]
             k = torch.cat([k_top, k], dim = -2)
             v = torch.cat([v_top, v], dim = -2)
-        else:
-            qkv = self.to_qkv(x).chunk(3, dim = -1)
-            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
-
-            if exists(rotary_pos_emb):
-                q, k, v = apply_pos_emb(rotary_pos_emb[..., :n, :], (q, k, v))
-
-            q *= self.scale
         if exists(cache):
-            cache[qkv_key] = (k, v)
+            cache[cache_key] = k, v

         dots = q @ k.swapaxes(-1, -2)
         mask_value = max_neg_value(dots)
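To follow the change: the rewritten block computes q/k/v unconditionally, prepends any cached keys and values, and always refreshes the cache entry under the plain `cache_key`. A hedged usage sketch of the intended incremental-decoding loop (the layer construction, shapes, and the `'layer0'` key below are assumptions for illustration, not taken from the commit):

```python
import torch
from dalle_pytorch.attention import Attention  # assumed import path

attn = Attention(dim = 512, seq_len = 256, causal = True)
cache = {}                                      # shared dict, one entry per layer

x = torch.randn(1, 10, 512)                     # full prompt: (batch, seq, dim)
out = attn(x, cache = cache, cache_key = 'layer0')        # first call stores (k, v)

x_step = torch.randn(1, 1, 512)                 # afterwards, feed only the newest token
out_step = attn(x_step, cache = cache, cache_key = 'layer0')
# using_cache is now True: the cached k/v are concatenated in front of the fresh
# single-token k/v, and the explicit causal mask is skipped, since a single query
# attending to all previous keys is causal by construction.
```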
@@ -98,17 +92,16 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
             dots.masked_fill_(~mask, mask_value)
             del mask

-        # if self.causal:  # TODO:
-        #     i, j = dots.shape[-2:]
-        #     mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
-        #     dots.masked_fill_(mask, mask_value)
+        if self.causal and not using_cache:  # causality is naturally enforced if we run the cached inference
+            i, j = dots.shape[-2:]
+            mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+            dots.masked_fill_(mask, mask_value)

         attn = softmax(dots, dim = -1)

         out = attn @ v
         out = rearrange(out, 'b h n d -> b n (h d)')
         out = self.to_out(out)
-
         return out
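A side note on the restored mask construction: `triu_(j - i + 1)` marks, for each query row, the key positions that lie strictly in its future, which in the square case is everything above the diagonal. A quick illustration of what PyTorch produces here:

```python
import torch

i = j = 4
mask = torch.ones(i, j).triu_(j - i + 1).bool()
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
# True marks the future positions that get filled with mask_value.
```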

 # sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation
@@ -128,14 +121,14 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,

         self.stable = stable

-        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

-        self.to_out = Cached(nn.Sequential(
+        self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        ))
+        )

-    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None):
         b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax

@@ -152,7 +145,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

         # derive query / keys / values

-        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

         if exists(rotary_pos_emb):
@@ -229,7 +222,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         out = torch.cat((out_text, out_image), dim = 1)

         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
-        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out')
+        out = self.to_out(out)
         return out[:, :n]

 # sparse axial causal attention
@@ -248,14 +241,14 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head

         self.stable = stable

-        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

-        self.to_out = Cached(nn.Sequential(
+        self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        ))
+        )

-    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None):
         b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax

@@ -272,7 +265,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

         # derive queries / keys / values

-        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

         if exists(rotary_pos_emb):
@@ -284,15 +277,15 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

         # text attention

-        dots_text = q_text @ k_text.swapaxes(-1, -2)
+        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
         mask_value = max_neg_value(dots_text)

         i, j = dots_text.shape[-2:]
         text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
         dots_text.masked_fill_(text_causal_mask, mask_value)

         attn_text = softmax(dots_text, dim = -1)
-        out_text = attn_text @ v_text
+        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)

         # image attention

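The restored einsum forms are notation-only changes; they compute the same products as the matmul expressions they replace. A quick self-check (shapes are arbitrary, chosen only for illustration):

```python
import torch
from torch import einsum

q, k = torch.randn(2, 5, 8), torch.randn(2, 7, 8)
assert torch.allclose(
    einsum('b i d, b j d -> b i j', q, k),  # restored form
    q @ k.swapaxes(-1, -2),                 # replaced form
    atol = 1e-6
)
```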
@@ -305,8 +298,8 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

         # similarity

-        dots_image_to_image = q_img @ k_img.swapaxes(-1, -2)
-        dots_image_to_text = q_img @ k_text[:, None].swapaxes(-1, -2)
+        dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
+        dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)

         dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)

@@ -329,8 +322,8 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

         attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]

-        out_image_to_image = attn_image_to_image @ v_img
-        out_image_to_text = attn_image_to_text @ v_text[:, None]
+        out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
+        out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)

         out_image = out_image_to_image + out_image_to_text

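The image-to-text case also covers the broadcasting variant: the einsum contracts over the shared batch axis directly, where the matmul form needed an inserted singleton axis to broadcast. An illustrative check (shapes are arbitrary):

```python
import torch
from torch import einsum

attn_i2t = torch.randn(2, 4, 5, 3)   # (batch, image rows, queries, text keys)
v_text = torch.randn(2, 3, 8)        # (batch, text keys, dim)
assert torch.allclose(
    einsum('b x i j, b j d -> b x i d', attn_i2t, v_text),  # restored form
    attn_i2t @ v_text[:, None],                             # old form: broadcast over the row axis
    atol = 1e-6
)
```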
@@ -343,7 +336,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
         out = torch.cat((out_text, out_image), dim = 1)

         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
-        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out')
+        out = self.to_out(out)
         return out[:, :n]

 # microsoft sparse attention CUDA kernel