Commit c412e3b

fix pylint-check

1 parent c29a748

3 files changed: +27 additions, -19 deletions

3 files changed

+27
-19
lines changed

mindnlp/transformers/integrations/npu_flash_attention.py

13 additions, 5 deletions
@@ -1,3 +1,5 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -10,13 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
+Set environment variable `NPU_FA2_SPARSE_MODE` to 2 when using top-left aligned causal mask.
+"""
+
+
 import os
 
 import math
+from typing import Optional, Tuple
 import mindspore
 from mindspore.ops import flash_attention_score
 from mindspore import nn
-from typing import Optional, Tuple
 from mindnlp.core import ops
 
 
@@ -25,7 +33,7 @@
 TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2
 DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE = 3
 
-SPARSE_MODE = int(os.getenv("NPU_FA2_SPARSE_MODE", default=DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE))
+SPARSE_MODE = int(os.getenv("NPU_FA2_SPARSE_MODE", default=str(DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE)))
 if SPARSE_MODE not in [TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE]:
     raise ValueError(
         "Environment variable `NPU_FA2_SPARSE_MODE` can only be set as 2 (top-left aligned causal mask) "
@@ -55,15 +63,15 @@ def bprop(self, input, indices, out, dout):
         assert dout.ndim >= 2
         other_shape = dout.shape[1:]
         grad_output = dout
-
+
         grad_flat = grad_output.reshape(grad_output.shape[0], -1)
         grad_shape = (input.shape[0], grad_flat.shape[1])
         grad_input = ops.zeros(grad_shape, grad_flat.dtype)
-
+
         indices_expanded = ops.expand_dims(indices, -1)
         indices_expanded = ops.broadcast_to(indices_expanded, (-1, grad_flat.shape[1]))
         grad_input.scatter_(0, indices_expanded, grad_flat)
-
+
         return grad_input.reshape(input.shape[0], *other_shape), None
 
 
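The remaining edits in this file only strip trailing whitespace inside `bprop`, but the surrounding context shows what that method does: it is the backward pass of an index-gather, scattering the incoming gradient rows back to the positions selected by `indices` and leaving every other row at zero. A rough numpy sketch of the same idea, purely illustrative (the module itself works on mindspore tensors via `mindnlp.core.ops` and an in-place `scatter_`):

import numpy as np

inp = np.arange(12, dtype=np.float32).reshape(4, 3)  # stand-in for `input` (4 rows)
indices = np.array([2, 0])                           # rows gathered in the forward pass
dout = np.ones((2, 3), dtype=np.float32)             # gradient w.r.t. the gathered rows

grad_flat = dout.reshape(dout.shape[0], -1)                     # flatten trailing dims
grad_input = np.zeros((inp.shape[0], grad_flat.shape[1]),
                      dtype=grad_flat.dtype)                    # zeros shaped like the input
grad_input[indices] = grad_flat                                 # scatter back along axis 0
grad_input = grad_input.reshape(inp.shape[0], *dout.shape[1:])  # restore the original shape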
mindnlp/transformers/modeling_flash_attention_utils.py

13 additions, 13 deletions
@@ -1,5 +1,4 @@
-# coding=utf-8
-# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,23 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
+"""This module provides utilities for flash attention in Transformers models."""
+
 
+import os
 import inspect
-import mindspore
 from typing import Optional, Tuple
+import mindspore
 from mindnlp.core import ops
 from ..utils import logging
-
-logger = logging.get_logger(__name__)
-flash_attn_func = None
-
 from .integrations.npu_flash_attention import index_first_axis, pad_input, unpad_input
 from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
 from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
 
 
-if flash_attn_func:
+logger = logging.get_logger(__name__)
+
+
+if flash_attn_func is not None:
     _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
 
 
@@ -285,7 +285,7 @@ def _flash_attention_forward(
     else:
         # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
        causal = is_causal and query_length != 1
-
+
     # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
     use_sliding_windows = (
         _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
@@ -299,7 +299,7 @@
 
     if softcap is not None:
         flash_kwargs["softcap"] = softcap
-
+
     # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
     query_states, key_states, value_states = fa_peft_integration_check(
         query_states, key_states, value_states, target_dtype
@@ -312,7 +312,7 @@
         )
         cu_seqlens_q, cu_seqlens_k = cu_seq_lens
         max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
+
         attn_output_unpad = flash_attn_varlen_func(
             query_states,
             key_states,
@@ -327,7 +327,7 @@
             **flash_kwargs,
         )
         attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
-
+
     # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
     # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
     # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
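Beyond the module docstring and import reordering, the guard in front of the signature probe changes from a truthiness test to an explicit `is not None` comparison, which makes the intent clearer now that the `flash_attn_func = None` placeholder is gone and the name is always bound by the import. A small self-contained sketch of that pattern (the function names here are hypothetical stand-ins, not the module's own):

import inspect
from typing import Callable, Optional

def supports_window_size(attn_func: Optional[Callable]) -> bool:
    # Only inspect the signature once a real callable is bound; attn_func may
    # legitimately be None when the backend kernels are unavailable.
    if attn_func is not None:
        return "window_size" in inspect.signature(attn_func).parameters
    return False

def fake_flash_attn(q, k, v, window_size=None):  # hypothetical stand-in backend
    return q

assert supports_window_size(fake_flash_attn)
assert not supports_window_size(None)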

mindnlp/transformers/models/whisper/modeling_whisper.py

1 addition, 1 deletion
@@ -460,7 +460,7 @@ def forward(
         causal_mask = attention_mask
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, : key_states.shape[-2]]
-
+
         # In PEFT, usually we cast the layer norms in float32 for training stability reasons
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in the correct dtype just to be sure everything works as expected.
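The Whisper edit itself only strips a trailing-whitespace line, but the context around it shows the mask handling: the precomputed causal mask is sliced down to the current key/value length before it is applied. An illustrative numpy sketch with stand-in shapes (not the model's actual tensors):

import numpy as np

max_len, kv_len = 8, 5
attention_mask = np.tril(np.ones((max_len, max_len)))  # full-length causal mask
key_states = np.zeros((1, 4, kv_len, 16))              # stand-in (batch, heads, kv_len, head_dim)

# Keep only as many key columns as there are cached key/value positions.
causal_mask = attention_mask[:, : key_states.shape[-2]]
assert causal_mask.shape == (max_len, kv_len)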
