# nor does it submit to any jurisdiction.


+from __future__ import annotations
+
import logging
+import math
from typing import Optional

import einops
+import torch
+from packaging import version
from torch import Tensor
from torch import nn
from torch.distributed.distributed_c10d import ProcessGroup

-try:
-    from flash_attn import flash_attn_func as attn_func
-except ImportError:
-    from torch.nn.functional import scaled_dot_product_attention as attn_func
-
-    _FLASH_ATTENTION_AVAILABLE = False
-else:
-    _FLASH_ATTENTION_AVAILABLE = True
-
from anemoi.models.distributed.transformer import shard_heads
from anemoi.models.distributed.transformer import shard_sequence
from anemoi.utils.config import DotDict


class MultiHeadSelfAttention(nn.Module):
36- """Multi Head Self Attention Pytorch Layer."""
32+ """Multi Head Self Attention Pytorch Layer
33+
34+ allows for three different attention implementations:
35+ - scaled dot product attention, see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
36+ - flash attention, see https://github.com/Dao-AILab/flash-attention
37+ """

    def __init__(
        self,
@@ -44,32 +45,89 @@ def __init__(
        is_causal: bool = False,
        window_size: Optional[int] = None,
        dropout_p: float = 0.0,
+        attention_implementation: str = "flash_attention",
+        softcap: Optional[float] = None,
+        use_alibi_slopes: bool = False,
    ):
52+ """Initialize MultiHeadSelfAttention.
53+
54+ For the flash attention implementation, two additional parameters are available: softcap, use_alibi_slopes
55+
56+ softcap: Softcapping prevents the logits from growing excessively large
57+
58+ use_alibi_slopes: Adds bias of `(-alibi_slope * |i + seqlen_k - seqlen_q - j|)` to the attention score of
59+ query i and key j, where alibi_slope is calculated using get_alibi_slopes
60+
61+ Parameters
62+ ----------
63+ num_heads : int
64+ number of heads
65+ embed_dim : int
66+ embedding dimension
67+ bias : bool, optional
68+ bias, by default False
69+ is_causal : bool, optional
70+ apply causal attention mask, by default False
71+ window_size : Optional[int], optional
72+ window_size, by default None
73+ dropout_p : float, optional
74+ dropout probability, by default 0.0
75+ attention_implementation: str, optional
76+ A predefined string which selects which underlying attention
77+ implementation, by default "flash_attention"
78+ softcap : float, optional
79+ Anything > 0 activates softcapping attention, by default None
80+ use_alibi_slopes : bool, optional
81+ Adds bias
82+ """
        super().__init__()

        assert (
            embed_dim % num_heads == 0
        ), f"Embedding dimension ({embed_dim}) must be divisible by number of heads ({num_heads})"

+        self.attention_implementation = attention_implementation
+        self.use_alibi_slopes = use_alibi_slopes
+
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads  # q k v
-        self.window_size = (window_size, window_size)  # flash attention
+        self.window_size = window_size
        self.dropout_p = dropout_p
        self.is_causal = is_causal
+        self.softcap = softcap
+
+        self.set_attention_function()
+
+        if self.use_alibi_slopes:
+            self.alibi_slopes = get_alibi_slopes(num_heads)
+            assert self.alibi_slopes.shape[0] == num_heads, "Error: Number of alibi_slopes must match number of heads"
+        else:
+            self.alibi_slopes = None

        linear = layer_kernels["Linear"]
        self.lin_qkv = linear(embed_dim, 3 * embed_dim, bias=bias)
-        self.attention = attn_func
-
-        if not _FLASH_ATTENTION_AVAILABLE:
-            LOGGER.warning("Flash attention not available, falling back to pytorch scaled_dot_product_attention")

        self.projection = linear(embed_dim, embed_dim, bias=True)

+    def set_attention_function(self):
+        attn_funcs = {
+            "flash_attention": FlashAttentionWrapper,
+            "scaled_dot_product_attention": SDPAAttentionWrapper,
+        }
+        assert (
+            self.attention_implementation in attn_funcs
+        ), f"{self.attention_implementation} not supported. \
+        Please change model.processor.attention_implementation to one of: {attn_funcs.keys()}"
+        LOGGER.info(f"Using {self.attention_implementation}")
+
+        # initialise the attention function here
+        self.attention = attn_funcs[self.attention_implementation]()
+
    def forward(
        self, x: Tensor, shapes: list, batch_size: int, model_comm_group: Optional[ProcessGroup] = None
    ) -> Tensor:
+
        query, key, value = self.lin_qkv(x).chunk(3, -1)

        if model_comm_group:
@@ -92,24 +150,151 @@ def forward(
        value = shard_heads(value, shapes=shapes, mgroup=model_comm_group)
        dropout_p = self.dropout_p if self.training else 0.0

-        if _FLASH_ATTENTION_AVAILABLE:
-            query, key, value = (
-                einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
+        out = self.attention(
+            query,
+            key,
+            value,
+            batch_size,
+            causal=False,
+            window_size=self.window_size,
+            dropout_p=dropout_p,
+            softcap=self.softcap,
+            alibi_slopes=self.alibi_slopes,
+        )
+
+        out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
+        out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")
+
+        out = self.projection(out)
+
+        return out
+
+
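# Illustrative sketch (not part of the commit above): constructing the layer and
# choosing the attention backend via `attention_implementation`. The leading
# constructor parameters are hidden by the hunk header, so the keyword names used
# below (num_heads, embed_dim, layer_kernels) are assumptions taken from the
# docstring and from the `layer_kernels["Linear"]` lookup in __init__.
def _demo_build_layer() -> "MultiHeadSelfAttention":  # hypothetical helper
    return MultiHeadSelfAttention(
        num_heads=16,
        embed_dim=1024,
        layer_kernels={"Linear": nn.Linear},  # assumed shape of the layer_kernels mapping
        attention_implementation="scaled_dot_product_attention",  # avoids the flash-attn dependency
        window_size=512,
        dropout_p=0.0,
    )
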
+class SDPAAttentionWrapper(nn.Module):
+    """Wrapper for PyTorch scaled dot product attention."""
+
+    def __init__(self):
+        super().__init__()
+
+        from torch.nn.functional import scaled_dot_product_attention
+
+        self.attention = scaled_dot_product_attention
+        self.mask = None
+        self.window_size = None
+
+    def update_mask(self, seq_len, window_size: int, device: str):
+
+        self.mask = (
+            torch.abs(
+                torch.arange(seq_len, device=device).unsqueeze(0) - torch.arange(seq_len, device=device).unsqueeze(1)
            )
-            out = self.attention(query, key, value, causal=False, window_size=self.window_size, dropout_p=dropout_p)
-            out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
-        else:
+            <= window_size
+        )
+
+    def forward(
+        self,
+        query,
+        key,
+        value,
+        batch_size: int,
+        causal=False,
+        window_size=None,
+        dropout_p=0.0,
+        softcap=None,
+        alibi_slopes=None,
+    ):
+        if softcap is not None:
+            raise NotImplementedError(
+                "Softcap not supported by PyTorch's SDPA. Please switch to flash attention or disable softcap."
+            )
+        if alibi_slopes is not None:
+            raise NotImplementedError(
+                "Alibi slopes not supported by PyTorch's SDPA. Please switch to flash attention or disable alibi slopes."
+            )
+
+        sequence_len = query.shape[-2]
+
+        if window_size is not None and (self.mask is None or tuple(self.mask.shape) != (sequence_len, sequence_len)):
+            self.update_mask(sequence_len, window_size=window_size, device=query.device)
+
+        with torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.MATH]):
            out = self.attention(
                query,
                key,
                value,
-                is_causal=False,
+                attn_mask=self.mask,
+                is_causal=causal,
                dropout_p=dropout_p,
-            )  # expects (batch heads grid variable) format
+            )

-        out = shard_sequence(out, shapes=shapes, mgroup=model_comm_group)
-        out = einops.rearrange(out, "batch heads grid vars -> (batch grid) (heads vars)")
+        return out

-        out = self.projection(out)

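# Illustrative sketch (not part of the commit above): the banded Boolean mask built
# by update_mask keeps, for every query position i, only the key positions j with
# |i - j| <= window_size. `_demo_sliding_window_mask` is a hypothetical helper that
# reproduces the construction on CPU for a small sequence.
def _demo_sliding_window_mask(seq_len: int = 5, window_size: int = 1) -> Tensor:
    positions = torch.arange(seq_len)
    # broadcasting (1, seq_len) against (seq_len, 1) yields the |i - j| distance matrix
    mask = torch.abs(positions.unsqueeze(0) - positions.unsqueeze(1)) <= window_size
    # for seq_len=5, window_size=1 this is the tridiagonal band:
    # [[ True,  True, False, False, False],
    #  [ True,  True,  True, False, False],
    #  [False,  True,  True,  True, False],
    #  [False, False,  True,  True,  True],
    #  [False, False, False,  True,  True]]
    return mask
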
+class FlashAttentionWrapper(nn.Module):
+    """Wrapper for flash attention."""
+
+    def __init__(self):
+        super().__init__()
+        try:
+            import flash_attn
+        except ImportError:
+            raise ImportError("Error: Flash-attn not installed. Please install flash-attn to use Flash Attention")
+
+        if version.parse(flash_attn.__version__) < version.parse("2.6.0"):
+            raise RuntimeError("Error: Flash-attn version is too low. Update to 2.6.0 or higher.")
+        else:
+            self.attention = flash_attn.flash_attn_func
+
+    def forward(
+        self,
+        query,
+        key,
+        value,
+        batch_size: int,
+        causal: bool = False,
+        window_size: Optional[int] = None,
+        dropout_p: float = 0.0,
+        softcap: Optional[float] = None,
+        alibi_slopes: Optional[torch.Tensor] = None,
+    ):
+        query, key, value = (
+            einops.rearrange(t, "batch heads grid vars -> batch grid heads vars") for t in (query, key, value)
+        )
+
+        alibi_slopes = alibi_slopes.repeat(batch_size, 1).to(query.device) if alibi_slopes is not None else None
+
+        out = self.attention(
+            query,
+            key,
+            value,
+            causal=causal,
+            window_size=(window_size, window_size),
+            dropout_p=dropout_p,
+            softcap=softcap,
+            alibi_slopes=alibi_slopes,
+        )
+        out = einops.rearrange(out, "batch grid heads vars -> batch heads grid vars")
        return out
+
+
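# Illustrative sketch (not part of the commit above): flash_attn_func consumes
# tensors in (batch, seqlen, nheads, headdim) layout, which is why the wrapper
# transposes from the module's "batch heads grid vars" layout and back again.
# `_demo_flash_layout` is a hypothetical helper that only illustrates the shapes.
def _demo_flash_layout(batch: int = 2, heads: int = 4, grid: int = 16, head_dim: int = 32) -> tuple:
    q = torch.randn(batch, heads, grid, head_dim)
    q_flash = einops.rearrange(q, "batch heads grid vars -> batch grid heads vars")
    return tuple(q.shape), tuple(q_flash.shape)  # (2, 4, 16, 32) -> (2, 16, 4, 32)
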
+def get_alibi_slopes(num_heads: int) -> Tensor:
+    """Calculates the ALiBi (Attention with Linear Biases) slope for each attention head.
+
+    Parameters
+    ----------
+    num_heads : int
+        number of attention heads
+
+    Returns
+    -------
+    Tensor
+        ALiBi slopes
+    """
+    n = 2 ** math.floor(math.log2(num_heads))
+    slope_0 = 2 ** (-8 / n)
+    alibi_slopes = torch.pow(slope_0, torch.arange(1, 1 + n))
+    if n < num_heads:
+        slope_hat_0 = 2 ** (-4 / n)
+        alibi_slopes_hat = torch.pow(slope_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2))
+        alibi_slopes = torch.cat([alibi_slopes, alibi_slopes_hat])
+    return alibi_slopes
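
# Illustrative check (not part of the commit above): the slopes form the geometric
# sequence 2^(-8k/n) for k = 1..n, with n the largest power of two <= num_heads;
# when num_heads is not a power of two, the remaining heads take every other value
# of the steeper 2^(-4k/n) sequence.
if __name__ == "__main__":
    print(get_alibi_slopes(8))
    # tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])
    print(get_alibi_slopes(6))
    # tensor([0.2500, 0.0625, 0.0156, 0.0039, 0.5000, 0.1250])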