@@ -1,5 +1,4 @@
- # coding=utf-8
- # Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+ # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,23 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

- import os
+ """This module provides utilities for flash attention in Transformers models."""
+

+ import os
import inspect
- import mindspore
from typing import Optional, Tuple
+ import mindspore
from mindnlp.core import ops
from ..utils import logging
-
- logger = logging.get_logger(__name__)
- flash_attn_func = None
-
from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input
from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func


- if flash_attn_func:
+ logger = logging.get_logger(__name__)
+
+
+ if flash_attn_func is not None:
    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
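
The reordered setup above also guards the signature probe behind `flash_attn_func is not None`, so the module imports cleanly even when the NPU kernel cannot be loaded. A minimal sketch of that feature-detection pattern, assuming a stand-in `fake_attn` function in place of the real `npu_flash_attn_func` binding:

```python
import inspect

# Stand-in for an optionally available attention kernel. In the real module this is
# `npu_flash_attn_func` imported as `flash_attn_func`, or left as None when unavailable.
def fake_attn(q, k, v, dropout_p=0.0, causal=False, window_size=(-1, -1)):
    return q  # placeholder body, illustration only

flash_attn_func = fake_attn  # would be None if the import had failed

# Probe the signature once at import time, mirroring the guarded check in the diff:
# only touch `inspect.signature` when the function actually exists.
_flash_supports_window_size = (
    flash_attn_func is not None
    and "window_size" in inspect.signature(flash_attn_func).parameters
)
print(_flash_supports_window_size)  # True for this stand-in
```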
@@ -285,7 +285,7 @@ def _flash_attention_forward(
    else:
        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
        causal = is_causal and query_length != 1
-
+
    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
    use_sliding_windows = (
        _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
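
For reference, when `use_sliding_windows` comes out true the kernel is typically handed a symmetric attention window through its `window_size` keyword. A rough sketch with illustrative values (the `(sliding_window, sliding_window)` tuple follows the upstream flash-attention convention and is not quoted from the patched file):

```python
# Illustrative values; in the real function these come from the model config and inputs.
_flash_supports_window_size = True
sliding_window = 4096   # local-attention span, e.g. a Mistral-style config value
key_seq_len = 8192      # plays the role of key_states.shape[1] in the diff above

use_sliding_windows = (
    _flash_supports_window_size and sliding_window is not None and key_seq_len > sliding_window
)

flash_kwargs = {}
if use_sliding_windows:
    # (left, right) window: each query attends only to the previous `sliding_window`
    # keys; (-1, -1) would mean no restriction.
    flash_kwargs["window_size"] = (sliding_window, sliding_window)

print(flash_kwargs)  # {'window_size': (4096, 4096)}
```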
@@ -299,7 +299,7 @@ def _flash_attention_forward(

    if softcap is not None:
        flash_kwargs["softcap"] = softcap
-
+
    # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
    query_states, key_states, value_states = fa_peft_integration_check(
        query_states, key_states, value_states, target_dtype
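
The `fa_peft_integration_check` call handles the case where PEFT silently upcasts activations to float32, which the fused kernels generally refuse. A minimal sketch of that idea, with a hypothetical `_recast_if_needed` helper standing in for the real function (not its actual signature):

```python
import mindspore
from mindspore import ops

def _recast_if_needed(*tensors, target_dtype=mindspore.float16):
    """Hypothetical stand-in: if an upstream wrapper upcast a tensor to float32,
    cast it back to the dtype the attention kernel expects; otherwise pass through."""
    return tuple(
        t.astype(target_dtype) if t.dtype == mindspore.float32 else t for t in tensors
    )

# (batch, seq, heads, head_dim) dummies; pretend PEFT upcast the queries.
q = ops.ones((2, 8, 4, 64), mindspore.float32)
k = ops.ones((2, 8, 4, 64), mindspore.float16)
v = ops.ones((2, 8, 4, 64), mindspore.float16)

q, k, v = _recast_if_needed(q, k, v, target_dtype=mindspore.float16)
print(q.dtype, k.dtype, v.dtype)  # all float16 afterwards
```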
@@ -312,7 +312,7 @@ def _flash_attention_forward(
        )
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
+
        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
@@ -327,7 +327,7 @@ def _flash_attention_forward(
            **flash_kwargs,
        )
        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
-
+
    # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
    # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
    # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
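
The varlen path shown above consumes cumulative sequence lengths (`cu_seqlens_q`, `cu_seqlens_k`) and the flattened token indices produced while unpadding. A small self-contained sketch of how those quantities can be derived from a padding mask, in plain NumPy with names chosen for illustration (the real module obtains them through `unpad_input` / `index_first_axis`):

```python
import numpy as np

# A (batch, seq_len) padding mask: 1 marks a real token, 0 marks padding.
attention_mask = np.array([
    [1, 1, 1, 0, 0],
    [1, 1, 1, 1, 1],
    [1, 0, 0, 0, 0],
], dtype=np.int32)

seqlens_in_batch = attention_mask.sum(axis=-1)        # [3, 5, 1]
indices = np.nonzero(attention_mask.flatten())[0]     # positions of the 9 real tokens
max_seqlen_in_batch = int(seqlens_in_batch.max())     # 5

# Cumulative lengths with a leading zero, as varlen kernels expect: example i owns
# rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed (total_tokens, heads, dim) tensor.
cu_seqlens = np.concatenate([[0], np.cumsum(seqlens_in_batch)]).astype(np.int32)

print(indices)             # [ 0  1  2  5  6  7  8  9 10]
print(cu_seqlens)          # [0 3 8 9]
print(max_seqlen_in_batch) # 5
```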