@@ -79,7 +79,7 @@ def _nyi_attn(func_name, *args, **kwargs):  # pylint: disable=W0613

 def _flash_float32_compatibility_wrapper(input_idxs: Tuple, flash_func: Callable, *args, **kwargs):
     if gpc.config.model.dtype is torch.float32:
-        inputs = (args[idx] for idx in input_idxs)
+        inputs = [args[idx] for idx in input_idxs]
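+        # a list (not a generator) is required here: inputs is indexed and len()'d just below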
         input_dtype = inputs[0].dtype
         other_args = [args[idx] for idx in range(len(inputs), len(args))]

@@ -194,10 +194,35 @@ def _flash_fixedlen_qkvsplited_attn(q, k, v, dropout_p=0.0, softmax_scale=None,


 # npu flash attention operators
-# TODO: should we add _flash_float32_compatibility_wrapper support for npu.
+def _npu_varlen_qkvsplited_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,  # pylint: disable=W0613
+    max_seqlen_k,  # pylint: disable=W0613
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+):
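+    # q, k, v sit at positions (0, 1, 2) of *args; the fp32-compatibility wrapper handles them
+    # when gpc.config.model.dtype is torch.float32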
+    return _flash_float32_compatibility_wrapper(
+        (0, 1, 2),
+        _npu_varlen_qkvsplited_func,
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+    )


-def _npu_varlen_qkvsplited_attn(
+def _npu_varlen_qkvsplited_func(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
@@ -208,17 +233,32 @@ def _npu_varlen_qkvsplited_attn(
     dropout_p=0.0,
     softmax_scale=None,
     causal=False,
+    use_fixlen=False,
 ):
-    # TODO: support npu native varlen flash attention
+    """Support Huawei Ascend's torch_npu flash attention.
+    Tested version:
+    torch: 2.1.0+cpu
+    torch_npu: 2.1.0.post3+git7c4136d
+    cann: 8.0.RC1.alpha003
+    """
     packed_length = q.size(dim=1)
+    softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1])
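+    # default to flash-attention's 1/sqrt(head_dim) scaling when no softmax_scale is given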

-    q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
-    k = unpack_qkv_before_attn(k, cu_seqlens=cu_seqlens_k)
-    v = unpack_qkv_before_attn(v, cu_seqlens=cu_seqlens_k)
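+    # use_fixlen=True unpads the packed batch and reuses the fixed-length kernel;
+    # otherwise the fused varlen NPU kernel consumes the packed sequences directly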
+    if use_fixlen:

-    output = _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal)
+        q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
+        k = unpack_qkv_before_attn(k, cu_seqlens=cu_seqlens_k)
+        v = unpack_qkv_before_attn(v, cu_seqlens=cu_seqlens_k)

-    return pack_output_after_attn(output, cu_seqlens_q, packed_length)
+        output = _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal)
+
+        output = pack_output_after_attn(output, cu_seqlens_q, packed_length)
+    else:
+        output = _npu_fused_varlen_qkvsplited_attn(
+            q, k, v, dropout_p, softmax_scale, causal, max_seqlen_q, max_seqlen_k, cu_seqlens_q, cu_seqlens_k
+        )
+
+    return output


 def _npu_fixedlen_qkvsplited_attn(
@@ -236,6 +276,7 @@ def _npu_fixedlen_qkvsplited_attn(
     q, k, v = q.squeeze(dim=2), k.squeeze(dim=2), v.squeeze(dim=2)

     _, seqlen, n_head, _ = q.shape
+    sparse_mode = 0
     attention_mask = torch.triu(torch.ones(seqlen, seqlen, device=get_current_device()), 1).bool()

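+    # the underlying fusion-attention op returns a tuple; only its first element (the attention output) is used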
     return _origin_npu_fixedlen_qkvsplited_func(
@@ -247,25 +288,71 @@ def _npu_fixedlen_qkvsplited_attn(
         pse=None,
         atten_mask=attention_mask,
         scale=softmax_scale,
-        sparse_mode=0,  # If necessary, expose the interface
+        sparse_mode=sparse_mode,  # If necessary, expose the interface
         pre_tockens=seqlen,  # Used for sparse calculations, representing the left boundary of the slides window
         next_tockens=0,  # If necessary, expose the interface
         keep_prob=1 - dropout_p,
         inner_precise=0,  # If necessary, expose the interface
-    )
+    )[0]


-def _npu_varlen_qkvpacked_attn(
-    qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False  # pylint: disable=W0613
+def _npu_fused_varlen_qkvsplited_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    dropout_p: float,
+    softmax_scale=None,
+    causal=False,
+    max_seqlen_q: int = None,
+    max_seqlen_k: int = None,
+    cu_seqlens_q=None,
+    cu_seqlens_kv=None,
+    deterministic=False,
 ):
-    # TODO: support npu native varlen flash attention
-    packed_length = qkv.size(dim=1)
+    assert causal is True
+    assert q.dtype in (torch.bfloat16, torch.float16)

-    qkv = unpack_qkv_before_attn(qkv, cu_seqlens=cu_seqlens)
+    if len(q.shape) == 4:  # [1, packedseqlen, n_head, headdim]
+        q, k, v = q.squeeze(dim=0), k.squeeze(dim=0), v.squeeze(dim=0)
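+        # the "TND" layout below expects [total_tokens, n_head, head_dim], so drop the leading batch dim of 1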

-    output = _npu_fixedlen_qkvpacked_attn(qkv, dropout_p, softmax_scale, causal)
+    S, N = max(max_seqlen_q, max_seqlen_k), q.shape[1]
+    device = get_current_device()
+    sparse_mode = 0

-    return pack_output_after_attn(output, cu_seqlens, packed_length)
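+    # for long sequences, sparse_mode 2 lets the kernel infer causal masking from a fixed 2048 x 2048
+    # triangular mask instead of a full max_seqlen x max_seqlen one (per torch_npu's fusion attention interface)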
+    if max_seqlen_k > 2048 and max_seqlen_q > 2048:
+        sparse_mode = 2
+        max_seqlen_k = 2048
+        max_seqlen_q = 2048
+
+    attention_mask = torch.triu(torch.ones(max_seqlen_q, max_seqlen_k, device=device), 1).bool()
+    cu_seqlens_q = cu_seqlens_q[1:].tolist()
+    cu_seqlens_kv = cu_seqlens_kv[1:].tolist()
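+    # actual_seq_qlen / actual_seq_kvlen take cumulative sequence lengths without the leading zero, hence the [1:]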
+
+    return _origin_npu_fixedlen_qkvsplited_func(
+        query=q,
+        key=k,
+        value=v,
+        head_num=N,
+        input_layout="TND",
+        pse=None,
+        atten_mask=attention_mask,
+        scale=softmax_scale,
+        sparse_mode=sparse_mode,
+        pre_tockens=S,  # Used for sparse calculations, representing the left boundary of the slides window
+        next_tockens=0,
+        keep_prob=1 - dropout_p,
+        inner_precise=0 if not deterministic else 2,
+        actual_seq_kvlen=cu_seqlens_kv,
+        actual_seq_qlen=cu_seqlens_q,
+    )[0].unsqueeze(dim=0)
+
+
+def _npu_varlen_qkvpacked_attn(
+    qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False  # pylint: disable=W0613
+):
+    # TODO: support npu native varlen flash attention
+    q, k, v = qkv.unbind(dim=2)
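+    # packed qkv shares a single cu_seqlens / max_seqlen, so pass them for both q and k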
+    return _npu_varlen_qkvsplited_attn(
+        q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p, softmax_scale, causal
+    )


 def _npu_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout_p: float, softmax_scale=None, causal=False):
@@ -285,14 +372,20 @@ def _npu_varlen_kvpacked_attn(
     causal=False,
 ):
     # TODO: support npu native varlen flash attention
-    packed_length = q.size(dim=1)
-
-    q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
-    kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k)
-
-    output = _npu_fixedlen_kvpacked_attn(q, kv, dropout_p, softmax_scale, causal)
-
-    return pack_output_after_attn(output, cu_seqlens_q, packed_length)
+    k, v = kv.unbind(dim=2)
+    k, v = k.squeeze(dim=2), v.squeeze(dim=2)
+    return _npu_varlen_qkvsplited_attn(
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+    )


 def _npu_fixedlen_kvpacked_attn(q: torch.Tensor, kv: torch.Tensor, dropout_p: float, softmax_scale=None, causal=False):
@@ -335,12 +428,6 @@ def _deeplink_fixedlen_qkvsplited_attn(*args, **kwargs):


 # torch attention operators
-
-
-def _torch_varlen_qkvpacked_attn(*args, **kwargs):
-    _nyi_attn("_torch_varlen_qkvpacked_attn", *args, **kwargs)
-
-
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/v2.2.1/flash_attn/modules/mha.py
 def _torch_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None):
     batch_size, seqlen = qkv.shape[0], qkv.shape[1]
@@ -369,10 +456,6 @@ def _torch_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout, softmax_scale=Non
     return output


-def _torch_varlen_kvpacked_attn(*args, **kwargs):
-    _nyi_attn("_torch_varlen_kvpacked_attn", *args, **kwargs)
-
-
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/v2.2.1/flash_attn/modules/mha.py
 def _torch_fixedlen_kvpacked_attn(
     q: torch.Tensor, kv: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None
@@ -407,17 +490,78 @@ def _torch_fixedlen_kvpacked_attn(
     return output


-def _torch_varlen_qkvsplited_attn(*args, **kwargs):
-    _nyi_attn("_torch_varlen_qkvsplited_attn", *args, **kwargs)
-
-
 def _torch_fixedlen_qkvsplited_attn(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None
 ):
     kv = torch.stack([k, v], dim=2)
     return _torch_fixedlen_kvpacked_attn(q, kv, dropout, softmax_scale, causal, key_padding_mask)


+def _torch_varlen_qkvsplited_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,  # pylint: disable=W0613
+    max_seqlen_k,  # pylint: disable=W0613
+    dropout,
+    softmax_scale=None,
+    causal=False,
+    key_padding_mask=None,
+):
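+    # the torch fallback has no native varlen kernel: unpad the packed batch, run the fixed-length
+    # attention, then re-pack the output into the original packed layout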
+    kv = torch.stack([k, v], dim=2)
+    packed_length = q.size(dim=1)
+
+    q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
+    kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k)
+
+    output = _torch_fixedlen_kvpacked_attn(q, kv, dropout, softmax_scale, causal, key_padding_mask)
+
+    return pack_output_after_attn(output, cu_seqlens_q, packed_length)
+
+
+def _torch_varlen_qkvpacked_attn(
+    qkv: torch.Tensor,
+    cu_seqlens,
+    max_seqlen,  # pylint: disable=W0613
+    dropout,
+    softmax_scale=None,
+    causal=False,
+    key_padding_mask=None,
+):
+
+    packed_length = qkv.size(dim=1)
+    qkv = unpack_qkv_before_attn(qkv, cu_seqlens=cu_seqlens)
+
+    output = _torch_fixedlen_qkvpacked_attn(qkv, dropout, softmax_scale, causal, key_padding_mask)
+
+    return pack_output_after_attn(output, cu_seqlens, packed_length)
+
+
+def _torch_varlen_kvpacked_attn(
+    q: torch.Tensor,
+    kv: torch.Tensor,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,  # pylint: disable=W0613
+    max_seqlen_k,  # pylint: disable=W0613
+    dropout,
+    softmax_scale=None,
+    causal=False,
+    key_padding_mask=None,
+):
+
+    packed_length = q.size(dim=1)
+
+    q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
+    kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k)
+
+    output = _torch_fixedlen_kvpacked_attn(q, kv, dropout, softmax_scale, causal, key_padding_mask)
+
+    return pack_output_after_attn(output, cu_seqlens_q, packed_length)
+
+
 @auto_wrap_distributed_attention
 class SelfAttention(nn.Module):
     """Implements scaled dot-product attention with optional softmax scaling.