
Commit c29a748

[Feat] Add WhisperFlashAttention2

1 parent 895e5c0

File tree: 6 files changed, +748 -2 lines


mindnlp/core/ops/other.py

Lines changed: 7 additions & 0 deletions
@@ -556,6 +556,13 @@ def einsum(equation, *operands):
     return result
 
 
+# expand_dims
+has_expand_dims = hasattr(mindspore.mint, 'expand_dims')
+def expand_dims(input, axis):
+    if use_pyboost() and has_expand_dims:
+        return mindspore.mint.expand_dims(input, axis)
+    return ops.expand_dims(input, axis)
+
 
 # flatten
 has_flatten = hasattr(mindspore.mint, 'flatten')
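For reference, the new helper follows the same dispatch pattern as the other wrappers in mindnlp/core/ops/other.py: use the mindspore.mint kernel when pyboost is available and the op exists, otherwise fall back to mindspore.ops. A minimal usage sketch (the wrapper is reached through mindnlp.core.ops, as the flash-attention utilities added below do):

    import numpy as np
    import mindspore
    from mindnlp.core import ops

    x = mindspore.Tensor(np.ones((2, 3)), mindspore.float32)  # shape (2, 3)
    y = ops.expand_dims(x, -1)                                 # insert a trailing axis -> (2, 3, 1)
    print(y.shape)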

mindnlp/transformers/configuration_utils.py

Lines changed: 1 addition & 0 deletions
@@ -342,6 +342,7 @@ def __init__(self, **kwargs):
 
         # Attention implementation to use, if relevant.
         self._attn_implementation_internal = kwargs.pop("attn_implementation", None)
+        self._attn_implementation_autoset = False
 
         # Drop the transformers version info
         self.transformers_version = kwargs.pop("transformers_version", None)
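For illustration, a minimal sketch of how the popped attn_implementation kwarg and the new flag relate (WhisperConfig as the example class is an assumption; any PretrainedConfig subclass forwards **kwargs to this __init__):

    from mindnlp.transformers import WhisperConfig

    # `attn_implementation` is consumed in PretrainedConfig.__init__ (line 344 above);
    # the new `_attn_implementation_autoset` flag starts as False, so it can later record
    # whether a backend was picked automatically rather than requested explicitly like this.
    config = WhisperConfig(attn_implementation="flash_attention_2")
    assert config._attn_implementation_autoset is False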
Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import math
+import mindspore
+from mindspore.ops import flash_attention_score
+from mindspore import nn
+from typing import Optional, Tuple
+from mindnlp.core import ops
+
+
+# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
+# Set environment variable `NPU_FA2_SPARSE_MODE` to 2 when using top-left aligned causal mask.
+TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2
+DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE = 3
+
+SPARSE_MODE = int(os.getenv("NPU_FA2_SPARSE_MODE", default=DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE))
+if SPARSE_MODE not in [TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE]:
+    raise ValueError(
+        "Environment variable `NPU_FA2_SPARSE_MODE` can only be set as 2 (top-left aligned causal mask) "
+        "or 3 (down-right aligned causal mask)."
+    )
+
+
+def is_npu_fa2_top_left_aligned_causal_mask():
+    return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE
+
+
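Since SPARSE_MODE is resolved once at import time, the environment variable has to be set before this module is imported. A minimal sketch:

    import os

    # Opt into the top-left aligned causal mask (sparse mode 2) before importing the module;
    # with the variable unset, the default is the down-right aligned mask (sparse mode 3).
    os.environ["NPU_FA2_SPARSE_MODE"] = "2"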
+class IndexFirstAxis(nn.Cell):
+    def __init__(self):
+        super(IndexFirstAxis, self).__init__()
+
+    def construct(self, input: mindspore.Tensor, indices: mindspore.Tensor):
+        assert input.ndim >= 2
+        first_axis_dim, other_shape = input.shape[0], input.shape[1:]
+        input_flat = input.reshape(first_axis_dim, -1)
+        indices_expanded = ops.expand_dims(indices, -1)
+        indices_expanded = ops.broadcast_to(indices_expanded, (-1, input_flat.shape[1]))
+        output_flat = ops.gather(input_flat, 0, indices_expanded)
+        output = output_flat.reshape(-1, *other_shape)
+        return output
+
+    def bprop(self, input, indices, out, dout):
+        assert dout.ndim >= 2
+        other_shape = dout.shape[1:]
+        grad_output = dout
+
+        grad_flat = grad_output.reshape(grad_output.shape[0], -1)
+        grad_shape = (input.shape[0], grad_flat.shape[1])
+        grad_input = ops.zeros(grad_shape, grad_flat.dtype)
+
+        indices_expanded = ops.expand_dims(indices, -1)
+        indices_expanded = ops.broadcast_to(indices_expanded, (-1, grad_flat.shape[1]))
+        grad_input.scatter_(0, indices_expanded, grad_flat)
+
+        return grad_input.reshape(input.shape[0], *other_shape), None
+
+
+index_first_axis = IndexFirstAxis()
+
+
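For illustration, a quick sketch of what index_first_axis computes (select rows along the first axis by index while keeping the remaining shape), which is how unpad_input below gathers the non-padded tokens:

    import numpy as np
    import mindspore

    x = mindspore.Tensor(np.arange(12).reshape(4, 3), mindspore.float32)  # (4, 3)
    idx = mindspore.Tensor(np.array([0, 2]), mindspore.int32)
    rows = index_first_axis(x, idx)                                       # rows 0 and 2 -> (2, 3)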
+class IndexPutFirstAxis(nn.Cell):
+    def __init__(self):
+        super(IndexPutFirstAxis, self).__init__()
+
+    def construct(self, values: mindspore.Tensor, indices: mindspore.Tensor, first_axis_dim: int):
+        assert indices.ndim == 1
+        assert values.ndim >= 2
+        output = ops.zeros(
+            (first_axis_dim, *values.shape[1:]),
+            values.dtype
+        )
+        output[indices] = values
+        return output
+
+    def bprop(self, values, indices, first_axis_dim, out, dout):
+        grad_values = dout[indices]
+        return grad_values, None, None
+
+
+index_put_first_axis = IndexPutFirstAxis()
+
+
+def pad_input(
+    hidden_states: mindspore.Tensor,
+    indices: mindspore.Tensor,
+    batch: int,
+    seqlen: int
+):
+    """
+    Arguments:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
+        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
+        batch: int, batch size for the padded sequence.
+        seqlen: int, maximum sequence length for the padded sequence.
+    Return:
+        hidden_states: (batch, seqlen, ...)
+    """
+    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
+    return output.reshape(batch, seqlen, *hidden_states.shape[1:])
+
+
+def unpad_input(
+    hidden_states: mindspore.Tensor,
+    attention_mask: mindspore.Tensor,
+    unused_mask: Optional[mindspore.Tensor] = None,
+):
+    """
+    Arguments:
+        hidden_states: (batch, seqlen, ...)
+        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
+        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
+    Return:
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
+        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
+        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
+        max_seqlen_in_batch: int
+        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
+    """
+    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
+    seqlens_in_batch = all_masks.sum(dim=-1, dtype=mindspore.int32)
+    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=mindspore.int32)
+    indices = ops.nonzero(all_masks.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = ops.pad(ops.cumsum(seqlens_in_batch, dim=0, dtype=mindspore.int32), (1, 0))
+
+    hidden_states_flat = hidden_states.reshape(-1, *hidden_states.shape[2:])
+    hidden_states = index_first_axis(hidden_states_flat, indices)
+    return (
+        hidden_states,
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+        used_seqlens_in_batch,
+    )
+
+
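For illustration, a minimal round trip through these two helpers under the shapes described in the docstrings (the values are arbitrary):

    import numpy as np
    import mindspore

    batch, seqlen, dim = 2, 4, 8
    hidden = mindspore.Tensor(np.random.randn(batch, seqlen, dim), mindspore.float16)
    mask = mindspore.Tensor(np.array([[1, 1, 1, 0],
                                      [1, 1, 0, 0]]), mindspore.int32)

    # Drop padded positions: unpadded has shape (5, 8); cu_seqlens is [0, 3, 5].
    unpadded, indices, cu_seqlens, max_len, seqused = unpad_input(hidden, mask)

    # Scatter the tokens back into their padded (batch, seqlen, dim) layout.
    repadded = pad_input(unpadded, indices, batch, seqlen)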
+def create_attn_mask(causal: bool, sparse_mode: int) -> Tuple[int, mindspore.Tensor]:
+    """
+    Create a causal mask for the attention scores.
+
+    Args:
+        causal (`bool`):
+            If `True`, the mask will be causal.
+        sparse_mode (`int`):
+            If set to the top-left aligned mode, the mask will be top-left
+            aligned, otherwise it will be bottom-right aligned.
+    Returns:
+        `Tuple[int, mindspore.Tensor]`:
+            A tuple containing sparse_mode and the mask tensor.
+    """
+    if not causal:
+        sparse_mode = 0
+        attn_mask = None
+    else:
+        if sparse_mode == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE:
+            attn_mask = ops.tril(ops.ones((2048, 2048)), diagonal=-1).bool()
+        else:
+            attn_mask = ops.triu(ops.ones((2048, 2048)), diagonal=1).bool()
+    return sparse_mode, attn_mask
+
+
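For illustration, what the helper returns in its two branches (with the module-level default of down-right alignment):

    # Causal: a (2048, 2048) bool mask covering the strictly upper triangle (future positions),
    # paired with the active sparse mode.
    sparse_mode, attn_mask = create_attn_mask(causal=True, sparse_mode=SPARSE_MODE)

    # Non-causal: sparse_mode collapses to 0 and no mask tensor is needed.
    sparse_mode, attn_mask = create_attn_mask(causal=False, sparse_mode=SPARSE_MODE)
    assert attn_mask is None and sparse_mode == 0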
+def npu_flash_attn_func(
+    q: mindspore.Tensor,
+    k: mindspore.Tensor,
+    v: mindspore.Tensor,
+    dropout_p: float = 0.0,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    **kwargs,
+):
+    head_num = q.shape[2]
+    sparse_mode, attn_mask = create_attn_mask(causal, SPARSE_MODE)
+    if softmax_scale is None:
+        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
+    output = flash_attention_score(
+        q,
+        k,
+        v,
+        head_num,
+        keep_prob=1.0 - dropout_p,
+        scalar_value=softmax_scale,
+        attn_mask=attn_mask,
+        input_layout="BSND",
+        sparse_mode=sparse_mode,
+        prefix=None,
+    )
+
+    return output
+
+
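For illustration, calling the padded (BSND layout: batch, seqlen, heads, head_dim) entry point on an Ascend NPU; the shapes are arbitrary:

    import numpy as np
    import mindspore

    batch, seqlen, heads, head_dim = 2, 16, 8, 64
    q = mindspore.Tensor(np.random.randn(batch, seqlen, heads, head_dim), mindspore.float16)
    k = mindspore.Tensor(np.random.randn(batch, seqlen, heads, head_dim), mindspore.float16)
    v = mindspore.Tensor(np.random.randn(batch, seqlen, heads, head_dim), mindspore.float16)

    # Non-causal (encoder-style) attention; softmax scale defaults to 1/sqrt(head_dim).
    out = npu_flash_attn_func(q, k, v, dropout_p=0.0, causal=False)  # (batch, seqlen, heads, head_dim)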
+def npu_flash_attn_varlen_func(
+    q: mindspore.Tensor,
+    k: mindspore.Tensor,
+    v: mindspore.Tensor,
+    cu_seqlens_q: Optional[mindspore.Tensor] = None,
+    cu_seqlens_k: Optional[mindspore.Tensor] = None,
+    dropout_p: float = 0.0,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    **kwargs,
+):
+    head_num = q.shape[1]
+    sparse_mode, attn_mask = create_attn_mask(causal, SPARSE_MODE)
+    if softmax_scale is None:
+        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
+
+    output = flash_attention_score(
+        q,
+        k,
+        v,
+        head_num,
+        keep_prob=1.0 - dropout_p,
+        scalar_value=softmax_scale,
+        attn_mask=attn_mask,
+        input_layout="TND",
+        actual_seq_qlen=cu_seqlens_q[1:].asnumpy().tolist(),
+        actual_seq_kvlen=cu_seqlens_k[1:].asnumpy().tolist(),
+        sparse_mode=sparse_mode,
+        prefix=None,
+    )
+
+    return output
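For illustration, the varlen path is the one meant to consume unpad_input's outputs: tokens are flattened to the TND layout (total tokens, heads, head_dim) and the cumulative lengths tell the kernel where each sequence ends. A hedged sketch, assuming q_unpad, k_unpad, v_unpad, indices_q and the cu_seqlens tensors were produced by unpad_input as above:

    # q_unpad, k_unpad, v_unpad: (total_nnz, heads, head_dim) from unpad_input;
    # cu_seqlens_q / cu_seqlens_k: (batch + 1,) int32 cumulative lengths, e.g. [0, 3, 5].
    out_unpad = npu_flash_attn_varlen_func(
        q_unpad, k_unpad, v_unpad,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        causal=True,
    )
    # Scatter the attention output back to the padded (batch, seqlen, heads, head_dim) layout.
    out = pad_input(out_unpad, indices_q, batch, seqlen)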
