 import torch.nn.functional as F
 from einops import rearrange, repeat
 
+from dalle_pytorch.cache import Cached
+
 from rotary_embedding_torch import apply_rotary_emb
 
 # helpers
@@ -102,14 +104,14 @@ def __init__(self, dim, seq_len, image_size = 32, kernel_size = 5, dilation = 1,
 
         self.stable = stable
 
-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))
 
-        self.to_out = nn.Sequential(
+        self.to_out = Cached(nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        )
+        ))
 
-    def forward(self, x, mask = None, rotary_pos_emb = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, img_size, kernel_size, dilation, seq_len, device = *x.shape, self.heads, self.image_size, self.kernel_size, self.dilation, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
 
@@ -126,7 +128,7 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
 
         # derive query / keys / values
 
-        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
         if exists(rotary_pos_emb):
@@ -203,11 +205,13 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
         out = torch.cat((out_text, out_image), dim = 1)
 
         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
-        out = self.to_out(out)
+        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out')
         return out[:, :n]
 
 # sparse axial causal attention
 
+from time import time
+
 class SparseAxialCausalAttention(nn.Module):
     def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head = 64, dropout = 0., stable = False, **kwargs):
         super().__init__()
@@ -222,14 +226,14 @@ def __init__(self, dim, seq_len, image_size = 32, axis = 0, heads = 8, dim_head
 
         self.stable = stable
 
-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        self.to_qkv = Cached(nn.Linear(dim, inner_dim * 3, bias = False))
 
-        self.to_out = nn.Sequential(
+        self.to_out = Cached(nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        )
+        ))
 
-    def forward(self, x, mask = None, rotary_pos_emb = None):
+    def forward(self, x, mask = None, rotary_pos_emb = None, cache = None, cache_key = None):
         b, n, _, h, img_size, axis, seq_len, device = *x.shape, self.heads, self.image_size, self.axis, self.seq_len, x.device
         softmax = torch.softmax if not self.stable else stable_softmax
 
@@ -246,7 +250,10 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
 
         # derive queries / keys / values
 
-        qkv = self.to_qkv(x).chunk(3, dim = -1)
+        t = time()
+        qkv = self.to_qkv(x, cache = cache, cache_key = f'{cache_key}_qkv').chunk(3, dim = -1)
+        print(f'Time 1: {time() - t:.5f} sec')
+        t = time()
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
 
         if exists(rotary_pos_emb):
@@ -317,7 +324,10 @@ def forward(self, x, mask = None, rotary_pos_emb = None):
         out = torch.cat((out_text, out_image), dim = 1)
 
         out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
-        out = self.to_out(out)
+        print(f'Time 2: {time() - t:.5f} sec')
+        t = time()
+        out = self.to_out(out, cache = cache, cache_key = f'{cache_key}_out')
+        print(f'Time 3: {time() - t:.5f} sec\n')
         return out[:, :n]
 
 # microsoft sparse attention CUDA kernel
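
For readability, here is a minimal sketch of the caching idea the diff relies on. The real Cached class lives in dalle_pytorch/cache.py and is not shown here, so the class below (CachedSketch) is an assumption inferred only from the call signature visible in the diff (module(x, cache = cache, cache_key = ...)): it reuses previously computed outputs for positions already seen and runs the wrapped module only on the new suffix, which is valid because nn.Linear and the pointwise to_out head act on each position independently.

    import torch
    import torch.nn as nn

    class CachedSketch(nn.Module):
        # Hypothetical stand-in for dalle_pytorch.cache.Cached (an assumption, not the real class).
        # Only valid for modules that act on every sequence position independently
        # (nn.Linear, the pointwise to_out head), so outputs for old positions can be reused.
        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, x, *, cache = None, cache_key = None):
            if cache is None:
                return self.fn(x)                              # caching disabled: plain call

            prev = cache.get(cache_key)                        # (b, n_prev, d) or None
            offset = 0 if prev is None else prev.shape[1]
            out_new = self.fn(x[:, offset:])                   # run only on the new tokens
            out = out_new if prev is None else torch.cat((prev, out_new), dim = 1)
            cache[cache_key] = out                             # remember for the next step
            return out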
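
And a hedged usage sketch built on the class above (all names and sizes here are illustrative, not part of the diff): one cache dict is threaded through every decoding step, and each wrapped projection gets its own cache_key so the qkv and output entries do not collide.

    import torch
    import torch.nn as nn

    to_qkv = CachedSketch(nn.Linear(64, 64 * 3, bias = False))
    to_out = CachedSketch(nn.Sequential(nn.Linear(64, 64), nn.Dropout(0.)))

    cache = {}
    x = torch.randn(1, 1, 64)                                   # start with a single token
    for _ in range(3):                                          # grow the sequence step by step
        qkv = to_qkv(x, cache = cache, cache_key = 'attn_qkv')  # only new positions are projected
        out = to_out(qkv[..., :64], cache = cache, cache_key = 'attn_out')
        x = torch.cat((x, torch.randn(1, 1, 64)), dim = 1)      # append the next token embedding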