@@ -25,6 +25,7 @@ def __init__(
         num_heads,
         dim_head_proj=None,
         dropout_rate=0.0,
+        with_residual=True,
         with_qk_lnorm=True,
         with_flash=True,
         norm_type="LayerNorm",
@@ -38,6 +39,7 @@ def __init__(
         self.num_heads = num_heads
         self.dropout_rate = dropout_rate
         self.with_flash = with_flash
+        self.with_residual = with_residual
         self.softcap = softcap
 
         assert dim_embed % num_heads == 0
@@ -50,8 +52,6 @@ def __init__(
 
         if dim_aux is not None:
             self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps)
-        else:
-            self.lnorm = norm(dim_embed, eps=norm_eps)
         self.proj_heads_q = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False)
         self.proj_heads_k = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False)
         self.proj_heads_v = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False)
@@ -71,7 +71,7 @@ def __init__(
     #########################################
     def forward(self, x, x_lens, ada_ln_aux=None):
         x_in = x
-        x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)
+        x = x if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)
 
         ## project onto heads and q,k,v and
         # ensure these are 4D tensors as required for flash attention
@@ -94,8 +94,12 @@ def forward(self, x, x_lens, ada_ln_aux=None):
             dropout_p=self.dropout_rate,
         )
 
-        # return x_in + self.dropout( self.proj_out( outs.flatten( -2, -1)) )
-        return x_in + self.proj_out(outs.flatten(-2, -1))
+        x = self.proj_out(outs.flatten(-2, -1))
+
+        if self.with_residual:
+            x = x_in + x
+
+        return x
 
 
 ####################################################################################################
@@ -107,6 +111,7 @@ def __init__(
         num_heads,
         dim_head_proj=None,
         dropout_rate=0.0,
+        with_residual=True,
         with_qk_lnorm=True,
         with_flash=True,
         norm_type="LayerNorm",
@@ -167,7 +172,11 @@ def forward(self, x, x_lens=None):
 
         outs = self.compiled_flex_attention(qs, ks, vs).transpose(1, 2).squeeze()
 
-        return x_in + self.dropout(self.proj_out(outs.flatten(-2, -1)))
+        x = self.proj_out(outs.flatten(-2, -1))
+        if self.with_residual:
+            x = x_in + x
+
+        return x
 
 
 ####################################################################################################
@@ -284,9 +293,6 @@ def __init__(
 
         if dim_aux is not None:
             self.lnorm_in_q = AdaLayerNorm(dim_embed_q, dim_aux, norm_eps=norm_eps)
-        else:
-            self.lnorm_in_q = norm(dim_embed_q, eps=norm_eps)
-            self.lnorm_in_kv = norm(dim_embed_kv, eps=norm_eps)
 
         self.proj_heads_q = torch.nn.Linear(dim_embed_q, num_heads * self.dim_head_proj, bias=False)
         self.proj_heads_k = torch.nn.Linear(
@@ -309,11 +315,10 @@ def __init__(
         assert with_flash, "Only flash attention supported at the moment"
 
     #########################################
-    def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None):
+    def forward(self, x_q, x_kv, x_lens=None, x_kv_lens=None, ada_ln_aux=None):
         if self.with_residual:
             x_q_in = x_q
-        x_q = self.lnorm_in_q(x_q) if ada_ln_aux is None else self.lnorm_in_q(x_q, ada_ln_aux)
-        x_kv = self.lnorm_in_kv(x_kv)
+        x_q = x_q if ada_ln_aux is None else self.lnorm_in_q(x_q, ada_ln_aux)
 
         ## project onto heads and q,k,v and
         # ensure these are 4D tensors as required for flash attention
@@ -324,15 +329,15 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None):
         vs = self.proj_heads_v(x_kv).reshape(s)
 
         if x_kv_lens is not None:
-            cum_x_q_lens = torch.cumsum(x_q_lens, 0, dtype=torch.int32)
+            cum_x_q_lens = torch.cumsum(x_lens, 0, dtype=torch.int32)
             cum_x_kv_lens = torch.cumsum(x_kv_lens, 0, dtype=torch.int32)
             outs = flash_attn_varlen_func(
                 qs,
                 ks,
                 vs,
                 cum_x_q_lens,
                 cum_x_kv_lens,
-                x_q_lens.max(),
+                x_lens.max(),
                 x_kv_lens.max(),
                 softcap=self.softcap,
                 dropout_p=self.dropout_rate,
@@ -454,14 +459,13 @@ def forward(self, x_q, x_kv, x_q_lens=None, x_kv_lens=None, ada_ln_aux=None):
                 vs,
                 cum_x_q_lens,
                 cum_x_kv_lens,
-                x_q_lens.max(),
-                x_kv_lens.max(),
+                x_q_lens.max().item(),
+                x_kv_lens.max().item(),
                 softcap=self.softcap,
                 dropout_p=self.dropout_rate,
             )
         ]
 
-        # outs = self.dropout( self.proj_out( torch.stack(outs).transpose(1,0).flatten( -2, -1)) )
         outs = self.proj_out(torch.stack(outs).transpose(1, 0).flatten(-2, -1))
         if self.with_residual:
             outs = x_q_in + outs.reshape(x_q_in.shape)
@@ -479,7 +483,9 @@ def __init__(
         dim_head_proj=None,
         dropout_rate=0.0,
         with_qk_lnorm=True,
+        with_residual=True,
         with_flash=True,
+        softcap=0.0,
         norm_type="LayerNorm",
         dim_aux=None,
         norm_eps=1e-5,
@@ -490,7 +496,10 @@ def __init__(
         self.num_heads = num_heads
         self.with_flash = with_flash
         self.dropout_rate = dropout_rate
+        self.with_residual = with_residual
+        self.softcap = softcap
 
+        assert with_flash, "You have to use flash attention"
         assert dim_embed % num_heads == 0
         self.dim_head_proj = dim_embed // num_heads if dim_head_proj is None else dim_head_proj
 
@@ -502,57 +511,136 @@ def __init__(
         if dim_aux is not None:
             self.lnorm = AdaLayerNorm(dim_embed, dim_aux, norm_eps=norm_eps)
         else:
-            self.lnorm = norm(dim_embed, eps=norm_eps)
+            self.lnorm = norm(dim_embed)
         self.proj_heads_q = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False)
         self.proj_heads_k = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False)
         self.proj_heads_v = torch.nn.Linear(dim_embed, num_heads * self.dim_head_proj, bias=False)
         self.proj_out = torch.nn.Linear(dim_embed, dim_embed, bias=False)
-        self.dropout = (
-            torch.nn.Dropout(p=dropout_rate) if dropout_rate > 0.0 else torch.nn.Identity()
-        )
 
         lnorm = norm if with_qk_lnorm else torch.nn.Identity
-        self.lnorm_q = lnorm(self.dim_head_proj, eps=norm_eps)
-        self.lnorm_k = lnorm(self.dim_head_proj, eps=norm_eps)
+        self.lnorm_q = lnorm(self.dim_head_proj)
+        self.lnorm_k = lnorm(self.dim_head_proj)
 
         self.dtype = attention_dtype
-        if with_flash:
-            self.att = torch.nn.functional.scaled_dot_product_attention
-        else:
-            self.att = self.attention
-            self.softmax = torch.nn.Softmax(dim=-1)
 
     #########################################
     def forward(self, x, ada_ln_aux=None):
         x_in = x
-        # x = self.lnorm( x)
-        x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)
+        # x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)
 
         ## project onto heads and q,k,v and
         # ensure these are 4D tensors as required for flash attention
-        s = [*([x.shape[0], 1] if len(x.shape) == 2 else x.shape[:-1]), self.num_heads, -1]
-        qs = self.lnorm_q(self.proj_heads_q(x).reshape(s)).to(self.dtype)
-        ks = self.lnorm_k(self.proj_heads_k(x).reshape(s)).to(self.dtype)
-        vs = self.proj_heads_v(x).reshape(s).to(self.dtype)
+        q_shape = [*([x.shape[0], 1] if len(x.shape) == 2 else x.shape[:-1]), self.num_heads, -1]
+        kv_shape = [
+            *([x.shape[0], 1] if len(x.shape) == 2 else x.shape[:-1]),
+            self.num_heads,
+            -1,
+        ]
+        qs = self.lnorm_q(self.proj_heads_q(x).reshape(q_shape)).to(self.dtype)
+        ks = self.lnorm_k(self.proj_heads_k(x).reshape(kv_shape)).to(self.dtype)
+        vs = self.proj_heads_v(x).reshape(kv_shape).to(self.dtype)
 
         # ordering of tensors (seq, heads, embed) (which differs from torch's flash attention implt)
-        outs = flash_attn_func(qs, ks, vs, dropout_p=self.dropout_rate)
+        outs = flash_attn_func(qs, ks, vs, softcap=self.softcap, dropout_p=self.dropout_rate)
+
+        if self.with_residual:
+            x = x_in + self.proj_out(outs.flatten(-2, -1))
+        else:
+            x = self.proj_out(outs.flatten(-2, -1))
+
+        return x
 
-        # return x_in + self.dropout( self.proj_out( outs.flatten( -2, -1)) )
-        return x_in + self.proj_out(outs.flatten(-2, -1))
 
+####################################################################################################
+class MultiCrossAttentionHead(torch.nn.Module):
     #########################################
-    def attention(self, q, k, v):
-        scaling = 1.0 / torch.sqrt(torch.tensor(q.shape[-1]))
-        return torch.matmul(self.softmax(scaling * self.score(q, k)), v)
+    def __init__(
+        self,
+        dim_embed_q,
+        dim_embed_kv,
+        num_heads,
+        dim_head_proj=None,
+        dropout_rate=0.0,
+        with_qk_lnorm=True,
+        with_residual=True,
+        with_flash=True,
+        softcap=0.0,
+        norm_type="LayerNorm",
+        dim_aux=None,
+        norm_eps=1e-5,
+        attention_dtype=torch.bfloat16,
+    ):
+        super(MultiCrossAttentionHead, self).__init__()
+
+        self.num_heads = num_heads
+        self.with_flash = with_flash
+        self.dropout_rate = dropout_rate
+        self.with_residual = with_residual
+        self.softcap = softcap
+
+        assert with_flash, "You have to use flash attention"
+        assert dim_embed_kv % num_heads == 0
+        self.dim_head_proj_kv = (
+            dim_embed_kv // num_heads if dim_head_proj is None else dim_head_proj
+        )
+        self.dim_head_proj_q = dim_embed_q // num_heads if dim_head_proj is None else dim_head_proj
+
+        if norm_type == "LayerNorm":
+            norm = partial(torch.nn.LayerNorm, elementwise_affine=False, eps=norm_eps)
+        else:
+            norm = RMSNorm
+
+        if dim_aux is not None:
+            self.lnorm = AdaLayerNorm(dim_embed_kv, dim_aux, norm_eps=norm_eps)
+        else:
+            self.lnorm = norm(dim_embed_kv)
+        self.proj_heads_q = torch.nn.Linear(
+            dim_embed_q, num_heads * self.dim_head_proj_q, bias=False
+        )
+        self.proj_heads_k = torch.nn.Linear(
+            dim_embed_kv, num_heads * self.dim_head_proj_kv, bias=False
+        )
+        self.proj_heads_v = torch.nn.Linear(
+            dim_embed_kv, num_heads * self.dim_head_proj_kv, bias=False
+        )
+        self.proj_out = torch.nn.Linear(dim_embed_kv, dim_embed_kv, bias=False)
+
+        lnorm = norm if with_qk_lnorm else torch.nn.Identity
+        self.lnorm_q = lnorm(self.dim_head_proj_q)
+        self.lnorm_k = lnorm(self.dim_head_proj_kv)
+
+        self.dtype = attention_dtype
 
     #########################################
-    def score(self, q, k):
-        return torch.matmul(q, torch.transpose(k, -2, -1))
+    def forward(self, q, x, ada_ln_aux=None):
+        x_in = x
+        # x = self.lnorm(x) if ada_ln_aux is None else self.lnorm(x, ada_ln_aux)
+
+        ## project onto heads and q,k,v and
+        # ensure these are 4D tensors as required for flash attention
+        q_shape = [*([x.shape[0], 1] if len(x.shape) == 2 else x.shape[:-1]), self.num_heads, -1]
+        kv_shape = [
+            *([x.shape[0], 1] if len(x.shape) == 2 else x.shape[:-1]),
+            self.num_heads,
+            -1,
+        ]
+        qs = self.lnorm_q(self.proj_heads_q(x).reshape(q_shape)).to(self.dtype)
+        ks = self.lnorm_k(self.proj_heads_k(x).reshape(kv_shape)).to(self.dtype)
+        vs = self.proj_heads_v(x).reshape(kv_shape).to(self.dtype)
+
+        # ordering of tensors (seq, heads, embed) (which differs from torch's flash attention implt)
+        outs = flash_attn_func(qs, ks, vs, softcap=self.softcap, dropout_p=self.dropout_rate)
+
+        if self.with_residual:
+            x = x_in + self.proj_out(outs.flatten(-2, -1))
+        else:
+            x = self.proj_out(outs.flatten(-2, -1))
+
+        return x
 
 
 ####################################################################################################
-class MultiCrossAttentionHead(torch.nn.Module):
+class MultiCrossAttentionHeadSPDA(torch.nn.Module):
     #########################################
     def __init__(
         self,
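
A minimal, self-contained sketch of the optional-residual pattern this change threads through the attention heads (a toy module for illustration only; it is not the repository's class, and all names below are assumptions):

    import torch

    class ToyAttentionBlock(torch.nn.Module):
        # Stand-in for an attention head whose skip connection can be disabled.
        def __init__(self, dim_embed, with_residual=True):
            super().__init__()
            self.with_residual = with_residual
            self.proj_out = torch.nn.Linear(dim_embed, dim_embed, bias=False)

        def forward(self, x):
            x_in = x
            x = self.proj_out(x)      # stand-in for the attention computation and output projection
            if self.with_residual:    # the residual connection is now a constructor flag
                x = x_in + x
            return x

    block = ToyAttentionBlock(dim_embed=64, with_residual=False)
    out = block(torch.randn(16, 64))  # returns only the projected output, no residual added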