@@ -49,26 +49,35 @@ def __init__(self, dim, seq_len, causal = True, heads = 8, dim_head = 64, dropou
        self.stable = stable
        self.causal = causal

-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
-        self.to_out = nn.Sequential(
+        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))
+        self.to_out = Cached(nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
-        )
+        ))

-    def forward(self, x, mask = None, rotary_pos_emb = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
        b, n, _, h, device = *x.shape, self.heads, x.device
        softmax = torch.softmax if not self.stable else stable_softmax

-        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        if exists(rotary_pos_emb):
            q, k, v = apply_pos_emb(rotary_pos_emb, (q, k, v))

        q = q * self.scale

-        dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
-        mask_value = max_neg_value(dots)
+        mask_value = max_neg_value(q)
+        dots_key = f'{cache_key}_dots'
+        if exists(cache) and dots_key in cache:
+            topleft = cache[dots_key]
+            top = F.pad(topleft, (0, 1), value = mask_value)
+            bottom = q[..., n - 1:, :] @ k.swapaxes(-1, -2)
+            dots = torch.cat([top, bottom], dim = -2)
+        else:
+            dots = q @ k.swapaxes(-1, -2)
+        if exists(cache):
+            cache[dots_key] = dots

        if exists(mask):
            mask = rearrange(mask, 'b j -> b () () j')
@@ -82,9 +91,21 @@ def forward(self, x, mask = None, rotary_pos_emb = None):

        attn = softmax(dots, dim = -1)

-        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
+        out_key = f'{cache_key}_out'
+        if exists(cache) and out_key in cache:
+            top = cache[out_key]
+            assert top.shape[-2] == n - 1
+
+            bottom = attn[..., n - 1:n, :] @ v
+
+            out = torch.cat([top, bottom], dim = -2)
+        else:
+            out = attn @ v
+        if exists(cache):
+            cache[out_key] = out
+
        out = rearrange(out, 'b h n d -> b n (h d)')
-        out = self.to_out(out)
+        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out_proj')
        return out

# sparse attention with convolutional pattern, as mentioned in the blog post. customizable kernel size and dilation
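Note (editorial, not part of the commit): the two hunks above swap the einsum calls for plain matrix products and add an incremental cache, so that during autoregressive decoding only the newest query row is scored against the keys while the previously computed (n-1) x (n-1) score block is reused. The standalone sketch below illustrates that idea; the sizes and the finfo-based mask value are placeholders, and the causal mask is applied explicitly here because the corresponding context lines fall outside the hunk.

import torch
import torch.nn.functional as F

b, h, n, d = 2, 8, 7, 64                    # placeholder sizes: batch, heads, current length, head dim
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
mask_value = -torch.finfo(q.dtype).max      # plays the role of max_neg_value in the diff

# what the cache would hold after the previous step: scores among the first n - 1 tokens
topleft = q[..., :n - 1, :] @ k[..., :n - 1, :].swapaxes(-1, -2)    # (b, h, n-1, n-1)

# incremental path from the diff: pad the cached block on the right, score only the new query row
top = F.pad(topleft, (0, 1), value = mask_value)                    # (b, h, n-1, n)
bottom = q[..., n - 1:, :] @ k.swapaxes(-1, -2)                     # (b, h, 1, n)
dots_cached = torch.cat([top, bottom], dim = -2)                    # (b, h, n, n)

# reference path: recompute every score from scratch
dots_full = q @ k.swapaxes(-1, -2)

# the padded column only covers positions the causal mask removes, so after masking and
# softmax both paths yield the same attention weights
causal_mask = torch.ones(n, n).triu_(1).bool()
attn_cached = dots_cached.masked_fill(causal_mask, mask_value).softmax(dim = -1)
attn_full = dots_full.masked_fill(causal_mask, mask_value).softmax(dim = -1)
assert torch.allclose(attn_cached, attn_full, atol = 1e-5)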
@@ -265,15 +286,17 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

        # text attention

-        dots_text = einsum('b i d, b j d -> b i j', q_text, k_text)
+        print('shapes 1:', q_text.shape, k_text.swapaxes(-1, -2).shape)
+        dots_text = q_text @ k_text.swapaxes(-1, -2)
        mask_value = max_neg_value(dots_text)

        i, j = dots_text.shape[-2:]
        text_causal_mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
        dots_text.masked_fill_(text_causal_mask, mask_value)

        attn_text = softmax(dots_text, dim = -1)
-        out_text = einsum('b i j, b j d -> b i d', attn_text, v_text)
+        print('shapes 2:', attn_text.shape, v_text.shape)
+        out_text = attn_text @ v_text

        # image attention

@@ -286,8 +309,10 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

        # similarity

-        dots_image_to_image = einsum('b x i d, b x j d -> b x i j', q_img, k_img)
-        dots_image_to_text = einsum('b x i d, b j d -> b x i j', q_img, k_text)
+        print('shapes 3:', q_img.shape, k_img.swapaxes(-1, -2).shape)
+        dots_image_to_image = q_img @ k_img.swapaxes(-1, -2)
+        print('shapes 4:', q_img.shape, k_text[:, None].swapaxes(-1, -2).shape)
+        dots_image_to_text = q_img @ k_text[:, None].swapaxes(-1, -2)

        dots = torch.cat((dots_image_to_text, dots_image_to_image), dim = -1)

@@ -310,8 +335,10 @@ def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key

        attn_image_to_text, attn_image_to_image = attn[..., :text_len], attn[..., text_len:]

-        out_image_to_image = einsum('b x i j, b x j d -> b x i d', attn_image_to_image, v_img)
-        out_image_to_text = einsum('b x i j, b j d -> b x i d', attn_image_to_text, v_text)
+        print('shapes 5:', attn_image_to_image.shape, v_img.shape)
+        out_image_to_image = attn_image_to_image @ v_img
+        print('shapes 6:', attn_image_to_text.shape, v_text[:, None].shape)
+        out_image_to_text = attn_image_to_text @ v_text[:, None]

        out_image = out_image_to_image + out_image_to_text

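Note (editorial, not part of the commit): the last three hunks rewrite einsum contractions as batched matrix products; for the image-to-text terms the text keys and values gain a singleton block axis (k_text[:, None], v_text[:, None]) so they broadcast over the image-block dimension x. A minimal sketch with placeholder shapes, checking the rewrite against the original einsum:

import torch

b, x, i, j, d = 2, 4, 16, 10, 64     # placeholder sizes: batch, image blocks, queries per block, text tokens, head dim
q_img = torch.randn(b, x, i, d)
k_text = torch.randn(b, j, d)
v_text = torch.randn(b, j, d)

# image-to-text scores: 'b x i d, b j d -> b x i j'
dots_einsum = torch.einsum('b x i d, b j d -> b x i j', q_img, k_text)
dots_matmul = q_img @ k_text[:, None].swapaxes(-1, -2)     # (b, 1, d, j) broadcasts over the x axis
assert torch.allclose(dots_einsum, dots_matmul, atol = 1e-5)

# image-to-text aggregation: 'b x i j, b j d -> b x i d'
attn = dots_matmul.softmax(dim = -1)
out_einsum = torch.einsum('b x i j, b j d -> b x i d', attn, v_text)
out_matmul = attn @ v_text[:, None]                        # (b, 1, j, d) broadcasts over the x axis
assert torch.allclose(out_einsum, out_matmul, atol = 1e-5)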