Commit 57b7cd5

feat(op): support varlen npu flash attention (#209)
1 parent 7cd091c commit 57b7cd5

3 files changed: 363 additions & 160 deletions

internlm/core/parallel/comm/isp.py

Lines changed: 6 additions & 1 deletion
@@ -900,7 +900,12 @@ def auto_wrap_distributed_attention(cls: nn.Module) -> Callable[[bool, Any, floa
     def _attetion_constructor(
         local_attn_cls: type, causal=False, softmax_scale=None, attention_dropout=0.0
     ) -> nn.Module:
-        if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp":
+        try:
+            tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp")
+        except AttributeError:
+            tp_mode = "mtp"
+
+        if tp_mode != "isp":
             return local_attn_cls(causal, softmax_scale, attention_dropout)
         else:
             return DistributedAttention(
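
The try/except above makes the tensor-parallel mode lookup tolerant of configs where `parallel["tensor"]` is not a dict (and therefore has no `.get`), falling back to the default "mtp" mode. A minimal sketch of that fallback, using a hypothetical `cfg` dict in place of `gpc.config.parallel`:

# Hypothetical stand-in for gpc.config.parallel; only the fallback logic is illustrated.
def resolve_tp_mode(cfg) -> str:
    try:
        return cfg["tensor"].get("mode", "mtp")  # dict-style config: {"tensor": {"size": 2, "mode": "isp"}}
    except AttributeError:
        return "mtp"  # e.g. a plain-int config such as {"tensor": 2} has no .get()

assert resolve_tp_mode({"tensor": {"size": 2, "mode": "isp"}}) == "isp"
assert resolve_tp_mode({"tensor": 2}) == "mtp"
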

internlm/model/ops/attention.py

Lines changed: 184 additions & 40 deletions
@@ -79,7 +79,7 @@ def _nyi_attn(func_name, *args, **kwargs):  # pylint: disable=W0613
 
 def _flash_float32_compatibility_wrapper(input_idxs: Tuple, flash_func: Callable, *args, **kwargs):
     if gpc.config.model.dtype is torch.float32:
-        inputs = (args[idx] for idx in input_idxs)
+        inputs = [args[idx] for idx in input_idxs]
         input_dtype = inputs[0].dtype
         other_args = [args[idx] for idx in range(len(inputs), len(args))]
 
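
The generator-to-list change above matters because the wrapper immediately evaluates `inputs[0].dtype` and `len(inputs)`; a generator expression supports neither operation, so the float32 path would raise `TypeError` before ever reaching the flash kernel. A small standalone illustration of the difference:

import torch

args = (torch.randn(2, 4), torch.randn(2, 4), 0.1)

inputs_list = [args[idx] for idx in (0, 1)]   # list: indexing and len() both work
print(inputs_list[0].dtype, len(inputs_list))

inputs_gen = (args[idx] for idx in (0, 1))    # generator: neither works
try:
    inputs_gen[0]
except TypeError as err:
    print("generator is not subscriptable:", err)
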
@@ -194,10 +194,35 @@ def _flash_fixedlen_qkvsplited_attn(q, k, v, dropout_p=0.0, softmax_scale=None,
 
 
 # npu flash attention operators
-# TODO: should we add _flash_float32_compatibility_wrapper support for npu.
+def _npu_varlen_qkvsplited_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,  # pylint: disable=W0613
+    max_seqlen_k,  # pylint: disable=W0613
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+):
+    return _flash_float32_compatibility_wrapper(
+        (0, 1, 2),
+        _npu_varlen_qkvsplited_func,
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+    )
 
 
-def _npu_varlen_qkvsplited_attn(
+def _npu_varlen_qkvsplited_func(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
@@ -208,17 +233,32 @@ def _npu_varlen_qkvsplited_attn(
     dropout_p=0.0,
     softmax_scale=None,
     causal=False,
+    use_fixlen=False,
 ):
-    # TODO: support npu native varlen flash attention
+    """Support Huawei Ascend's torch_npu flash attention.
+    Tested version:
+    torch: 2.1.0+cpu
+    torch_npu: 2.1.0.post3+git7c4136d
+    cann: 8.0.RC1.alpha003
+    """
     packed_length = q.size(dim=1)
+    softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1])
 
-    q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
-    k = unpack_qkv_before_attn(k, cu_seqlens=cu_seqlens_k)
-    v = unpack_qkv_before_attn(v, cu_seqlens=cu_seqlens_k)
+    if use_fixlen:
 
-    output = _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal)
+        q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
+        k = unpack_qkv_before_attn(k, cu_seqlens=cu_seqlens_k)
+        v = unpack_qkv_before_attn(v, cu_seqlens=cu_seqlens_k)
 
-    return pack_output_after_attn(output, cu_seqlens_q, packed_length)
+        output = _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal)
+
+        output = pack_output_after_attn(output, cu_seqlens_q, packed_length)
+    else:
+        output = _npu_fused_varlen_qkvsplited_attn(
+            q, k, v, dropout_p, softmax_scale, causal, max_seqlen_q, max_seqlen_k, cu_seqlens_q, cu_seqlens_k
+        )
+
+    return output
 
 
 def _npu_fixedlen_qkvsplited_attn(
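
Taken together, the two hunks above split the old function in two: `_npu_varlen_qkvsplited_attn` is now a thin entry point that routes q, k and v (argument positions 0-2) through `_flash_float32_compatibility_wrapper`, while `_npu_varlen_qkvsplited_func` does the work. Its `use_fixlen` flag selects between the legacy route (unpack the packed batch into padded sequences, run the fixed-length kernel, re-pack) and the new default, which hands the packed data directly to the fused varlen kernel. A toy unpack, shown only to illustrate what the padded layout looks like; the real `unpack_qkv_before_attn`/`pack_output_after_attn` live elsewhere in the repo:

import torch

def toy_unpack(x_packed: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
    """Illustration only: packed [1, packed_len, ...] -> padded [batch, max_len, ...]."""
    lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    out = x_packed.new_zeros(len(lens), max(lens), *x_packed.shape[2:])
    for i, (start, length) in enumerate(zip(cu_seqlens[:-1].tolist(), lens)):
        out[i, :length] = x_packed[0, start : start + length]
    return out

cu_seqlens = torch.tensor([0, 3, 7])              # two sequences, lengths 3 and 4
x = torch.arange(7.0).reshape(1, 7, 1)            # packed layout
print(toy_unpack(x, cu_seqlens).shape)            # torch.Size([2, 4, 1])
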
@@ -236,6 +276,7 @@ def _npu_fixedlen_qkvsplited_attn(
     q, k, v = q.squeeze(dim=2), k.squeeze(dim=2), v.squeeze(dim=2)
 
     _, seqlen, n_head, _ = q.shape
+    sparse_mode = 0
     attention_mask = torch.triu(torch.ones(seqlen, seqlen, device=get_current_device()), 1).bool()
 
     return _origin_npu_fixedlen_qkvsplited_func(
@@ -247,25 +288,71 @@ def _npu_fixedlen_qkvsplited_attn(
         pse=None,
         atten_mask=attention_mask,
         scale=softmax_scale,
-        sparse_mode=0,  # If necessary, expose the interface
+        sparse_mode=sparse_mode,  # If necessary, expose the interface
         pre_tockens=seqlen,  # Used for sparse calculations, representing the left boundary of the slides window
         next_tockens=0,  # If necessary, expose the interface
         keep_prob=1 - dropout_p,
         inner_precise=0,  # If necessary, expose the interface
-    )
+    )[0]
 
 
-def _npu_varlen_qkvpacked_attn(
-    qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False  # pylint: disable=W0613
+def _npu_fused_varlen_qkvsplited_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    dropout_p: float,
+    softmax_scale=None,
+    causal=False,
+    max_seqlen_q: int = None,
+    max_seqlen_k: int = None,
+    cu_seqlens_q=None,
+    cu_seqlens_kv=None,
+    deterministic=False,
 ):
-    # TODO: support npu native varlen flash attention
-    packed_length = qkv.size(dim=1)
+    assert causal is True
+    assert q.dtype in (torch.bfloat16, torch.float16)
 
-    qkv = unpack_qkv_before_attn(qkv, cu_seqlens=cu_seqlens)
+    if len(q.shape) == 4:  # [1, packedseqlen, n_head, headdim]
+        q, k, v = q.squeeze(dim=0), k.squeeze(dim=0), v.squeeze(dim=0)
 
-    output = _npu_fixedlen_qkvpacked_attn(qkv, dropout_p, softmax_scale, causal)
+    S, N = max(max_seqlen_q, max_seqlen_k), q.shape[1]
+    device = get_current_device()
+    sparse_mode = 0
 
-    return pack_output_after_attn(output, cu_seqlens, packed_length)
+    if max_seqlen_k > 2048 and max_seqlen_q > 2048:
+        sparse_mode = 2
+        max_seqlen_k = 2048
+        max_seqlen_q = 2048
+
+    attention_mask = torch.triu(torch.ones(max_seqlen_q, max_seqlen_k, device=device), 1).bool()
+    cu_seqlens_q = cu_seqlens_q[1:].tolist()
+    cu_seqlens_kv = cu_seqlens_kv[1:].tolist()
+
+    return _origin_npu_fixedlen_qkvsplited_func(
+        query=q,
+        key=k,
+        value=v,
+        head_num=N,
+        input_layout="TND",
+        pse=None,
+        atten_mask=attention_mask,
+        scale=softmax_scale,
+        sparse_mode=sparse_mode,
+        pre_tockens=S,  # Used for sparse calculations, representing the left boundary of the slides window
+        next_tockens=0,
+        keep_prob=1 - dropout_p,
+        inner_precise=0 if not deterministic else 2,
+        actual_seq_kvlen=cu_seqlens_kv,
+        actual_seq_qlen=cu_seqlens_q,
+    )[0].unsqueeze(dim=0)
+
+
+def _npu_varlen_qkvpacked_attn(
+    qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False  # pylint: disable=W0613
+):
+    # TODO: support npu native varlen flash attention
+    q, k, v = qkv.unbind(dim=2)
+    return _npu_varlen_qkvsplited_attn(q, k, v, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal)
 
 
 def _npu_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout_p: float, softmax_scale=None, causal=False):
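
`_npu_fused_varlen_qkvsplited_attn` is the heart of this commit: it feeds the packed tokens to the fused kernel in "TND" layout (total tokens, heads, head dim), passes the cumulative sequence ends as plain Python lists via `actual_seq_qlen`/`actual_seq_kvlen`, and switches to `sparse_mode=2` with a causal mask capped at 2048 x 2048 when both max sequence lengths exceed 2048. `_origin_npu_fixedlen_qkvsplited_func` is presumably bound to torch_npu's fused attention kernel elsewhere in this file, hence the `[0]` indexing to pick the attention output out of the returned tuple. A device-free sketch of just the argument preparation, mirroring the lines above (the sequence lengths are illustrative assumptions; the actual kernel call needs torch_npu on an Ascend device):

import torch

# Sketch of the argument preparation done in _npu_fused_varlen_qkvsplited_attn.
seqlens = [1024, 3000, 512]                                    # assumed per-sample lengths
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(torch.tensor(seqlens), dim=0)

max_seqlen_q = max_seqlen_k = max(seqlens)
sparse_mode = 0
if max_seqlen_k > 2048 and max_seqlen_q > 2048:
    sparse_mode = 2                                            # banded causal-mask mode
    max_seqlen_q = max_seqlen_k = 2048                         # mask size is capped

attention_mask = torch.triu(torch.ones(max_seqlen_q, max_seqlen_k), 1).bool()
actual_seq_qlen = cu_seqlens[1:].tolist()                      # cumulative end offsets: [1024, 4024, 4536]
actual_seq_kvlen = actual_seq_qlen

print(sparse_mode, attention_mask.shape, actual_seq_qlen)
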
@@ -285,14 +372,20 @@ def _npu_varlen_kvpacked_attn(
     causal=False,
 ):
     # TODO: support npu native varlen flash attention
-    packed_length = q.size(dim=1)
-
-    q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
-    kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k)
-
-    output = _npu_fixedlen_kvpacked_attn(q, kv, dropout_p, softmax_scale, causal)
-
-    return pack_output_after_attn(output, cu_seqlens_q, packed_length)
+    k, v = kv.unbind(dim=2)
+    k, v = k.squeeze(dim=2), v.squeeze(dim=2)
+    return _npu_varlen_qkvsplited_attn(
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+    )
 
 
 def _npu_fixedlen_kvpacked_attn(q: torch.Tensor, kv: torch.Tensor, dropout_p: float, softmax_scale=None, causal=False):
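
The kv-packed varlen path likewise stops unpacking by itself: it splits the packed kv tensor and defers to the qkv-splited entry point above. A small shape illustration of the `unbind` step; the [1, packed_len, 2, n_head, head_dim] layout here is an assumption for the sketch rather than something shown in this hunk (the extra `squeeze(dim=2)` in the diff suggests the repo's kv tensor may carry one more singleton dimension):

import torch

# Assumed kv layout for illustration: [1, packed_len, 2, n_head, head_dim]
kv = torch.randn(1, 448, 2, 8, 64)
k, v = kv.unbind(dim=2)            # two views, each [1, 448, 8, 64]
print(k.shape, v.shape)
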
@@ -335,12 +428,6 @@ def _deeplink_fixedlen_qkvsplited_attn(*args, **kwargs):
 
 
 # torch attention operators
-
-
-def _torch_varlen_qkvpacked_attn(*args, **kwargs):
-    _nyi_attn("_torch_varlen_qkvpacked_attn", *args, **kwargs)
-
-
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/v2.2.1/flash_attn/modules/mha.py
 def _torch_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None):
     batch_size, seqlen = qkv.shape[0], qkv.shape[1]
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
@@ -369,10 +456,6 @@ def _torch_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout, softmax_scale=Non
     return output
 
 
-def _torch_varlen_kvpacked_attn(*args, **kwargs):
-    _nyi_attn("_torch_varlen_kvpacked_attn", *args, **kwargs)
-
-
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/v2.2.1/flash_attn/modules/mha.py
 def _torch_fixedlen_kvpacked_attn(
     q: torch.Tensor, kv: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None
@@ -407,17 +490,78 @@ def _torch_fixedlen_kvpacked_attn(
     return output
 
 
-def _torch_varlen_qkvsplited_attn(*args, **kwargs):
-    _nyi_attn("_torch_varlen_qkvsplited_attn", *args, **kwargs)
-
-
 def _torch_fixedlen_qkvsplited_attn(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None
 ):
     kv = torch.stack([k, v], dim=2)
     return _torch_fixedlen_kvpacked_attn(q, kv, dropout, softmax_scale, causal, key_padding_mask)
 
 
+def _torch_varlen_qkvsplited_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,  # pylint: disable=W0613
+    max_seqlen_k,  # pylint: disable=W0613
+    dropout,
+    softmax_scale=None,
+    causal=False,
+    key_padding_mask=None,
+):
+    kv = torch.stack([k, v], dim=2)
+    packed_length = q.size(dim=1)
+
+    q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
+    kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k)
+
+    output = _torch_fixedlen_kvpacked_attn(q, kv, dropout, softmax_scale, causal, key_padding_mask)
+
+    return pack_output_after_attn(output, cu_seqlens_q, packed_length)
+
+
+def _torch_varlen_qkvpacked_attn(
+    qkv: torch.Tensor,
+    cu_seqlens,
+    max_seqlen,  # pylint: disable=W0613
+    dropout,
+    softmax_scale=None,
+    causal=False,
+    key_padding_mask=None,
+):
+
+    packed_length = qkv.size(dim=1)
+    qkv = unpack_qkv_before_attn(qkv, cu_seqlens=cu_seqlens)
+
+    output = _torch_fixedlen_qkvpacked_attn(qkv, dropout, softmax_scale, causal, key_padding_mask)
+
+    return pack_output_after_attn(output, cu_seqlens, packed_length)
+
+
+def _torch_varlen_kvpacked_attn(
+    q: torch.Tensor,
+    kv: torch.Tensor,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,  # pylint: disable=W0613
+    max_seqlen_k,  # pylint: disable=W0613
+    dropout,
+    softmax_scale=None,
+    causal=False,
+    key_padding_mask=None,
+):
+
+    packed_length = q.size(dim=1)
+
+    q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
+    kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k)
+
+    output = _torch_fixedlen_kvpacked_attn(q, kv, dropout, softmax_scale, causal, key_padding_mask)
+
+    return pack_output_after_attn(output, cu_seqlens_q, packed_length)
+
+
 @auto_wrap_distributed_attention
 class SelfAttention(nn.Module):
     """Implements scaled dot-product attention with optional softmax scaling.
