@@ -72,7 +72,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
        if exists(cache) and dots_key in cache:
            topleft = cache[dots_key]
            top = F.pad(topleft, (0, 1), value = mask_value)
-            bottom = q[..., n - 1:, :] @ k.swapaxes(-1, -2)
+            bottom = q[..., n - 1:n, :] @ k.swapaxes(-1, -2)
            dots = torch.cat([top, bottom], dim = -2)
        else:
            dots = q @ k.swapaxes(-1, -2)
@@ -94,10 +94,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
        out_key = f'{cache_key}_out'
        if exists(cache) and out_key in cache:
            top = cache[out_key]
-            assert top.shape[-2] == n - 1
-
            bottom = attn[..., n - 1:n, :] @ v
-
            out = torch.cat([top, bottom], dim = -2)
        else:
            out = attn @ v
@@ -231,8 +228,6 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

# sparse axial causal attention

-from time import time
-
class SparseAxialCausalAttention(nn.Module):
    def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
        super().__init__()
@@ -271,10 +266,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

        # derive queries / keys / values

-        t = time()
        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
-        print(f'Time 1: {time() - t:.5f} sec')
-        t = time()
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

        if exists(rotary_pos_emb):
@@ -286,7 +278,6 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

        # text attention

-        print('shapes 1:', q_text.shape, k_text.swapaxes(-1, -2).shape)
        dots_text = q_text @ k_text.swapaxes(-1, -2)
        mask_value = max_neg_value(dots_text)

@@ -295,7 +286,6 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
        dots_text.masked_fill_(text_causal_mask, mask_value)

        attn_text = softmax(dots_text, dim = -1)
-        print('shapes 2:', attn_text.shape, v_text.shape)
        out_text = attn_text @ v_text

        # image attention
@@ -309,9 +299,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

        # similarity

-        print('shapes 3:', q_img.shape, k_img.swapaxes(-1, -2).shape)
        dots_image_to_image = q_img @ k_img.swapaxes(-1, -2)
-        print('shapes 4:', q_img.shape, k_text[:, None].swapaxes(-1, -2).shape)
        dots_image_to_text = q_img @ k_text[:, None].swapaxes(-1, -2)

        dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)
@@ -335,9 +323,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

        attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]

-        print('shapes 5:', attn_image_to_image.shape, v_img.shape)
        out_image_to_image = attn_image_to_image @ v_img
-        print('shapes 6:', attn_image_to_text.shape, v_text[:, None].shape)
        out_image_to_text = attn_image_to_text @ v_text[:, None]

        out_image = out_image_to_image + out_image_to_text
@@ -351,10 +337,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key
        out = torch.cat((out_text, out_image), dim = 1)

        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
-        print(f'Time 2: {time() - t:.5f} sec')
-        t = time()
        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out')
-        print(f'Time 3: {time() - t:.5f} sec\n')
        return out[:, :n]

# microsoft sparse attention CUDA kernel
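For context, a minimal self-contained sketch (not part of the diff; tensor shapes and names are illustrative) of the cached-logits pattern the first two hunks rely on: the previous step's (n-1) x (n-1) attention logits are right-padded with one masked column for the newest key, and only the newest query row is computed and concatenated on. Slicing with q[..., n - 1:n, :] keeps the row dimension and, unlike the removed n - 1:, is guaranteed to select exactly one row even if q carries positions beyond n.

# Hypothetical standalone sketch of the incremental-logits trick (assumes the
# cache holds last step's logits); not taken from the repository.
import torch
import torch.nn.functional as F

n, d = 5, 8
q = torch.randn(1, 1, n, d)                              # (batch, heads, seq, dim_head)
k = torch.randn(1, 1, n, d)
mask_value = -torch.finfo(q.dtype).max

dots_full = q @ k.swapaxes(-1, -2)                       # full (n x n) recomputation

# what the cache would hold after the previous step: (n-1) x (n-1) logits
topleft = q[..., :n - 1, :] @ k[..., :n - 1, :].swapaxes(-1, -2)

top = F.pad(topleft, (0, 1), value = mask_value)         # add a column for the new key, masked for old rows
bottom = q[..., n - 1:n, :] @ k.swapaxes(-1, -2)         # only the newest query row
dots_cached = torch.cat([top, bottom], dim = -2)         # (n x n) again

# the newest row and the previously cached block agree with full recomputation;
# the padded column only differs in positions a causal mask would suppress anyway
assert torch.allclose(dots_full[..., n - 1:n, :], dots_cached[..., n - 1:n, :])
assert torch.allclose(dots_full[..., :n - 1, :n - 1], dots_cached[..., :n - 1, :n - 1])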