Commit 0cf349a

[Misc] Code clean up

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Parent: 493768e

File tree: 6 files changed (+349 -437 lines)
New file (unit tests for AttentionMaskBuilder): 107 additions & 0 deletions

@@ -0,0 +1,107 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from tests.ut.base import TestBase
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder


class TestAttentionMaskBuilder(TestBase):

    def test_init_attention_mask_builder(self):
        # generate attention_mask_builder with float16
        attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
                                                      dtype=torch.float16)
        self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
        self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
                         torch.float16)
        self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
                         (1024, 1024))
        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
                         torch.tensor(float("-inf"), dtype=torch.float16))

        # generate attention_mask_builder with int8
        attention_mask_builder = AttentionMaskBuilder(max_seq_len=512,
                                                      dtype=torch.int8)
        self.assertEqual(attention_mask_builder._seq_len_cached, 512)
        self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
                         torch.int8)
        self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
                         (512, 512))
        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
                         torch.tensor(1, dtype=torch.int8))

    def test_get_attn_mask(self):
        # if the len is less than max_seq_len, the attn_mask_cache will not be updated
        attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
                                                      dtype=torch.float16)
        attn_mask = attention_mask_builder.get_attn_mask(
            max_seq_len=512, dtype=torch.float16, device=torch.device("cpu"))
        self.assertEqual(attn_mask.shape, (512, 512))
        self.assertEqual(attn_mask[0][-1],
                         torch.tensor(float("-inf"), dtype=torch.float16))
        self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
                         (1024, 1024))
        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
                         torch.tensor(float("-inf"), dtype=torch.float16))

        # if the len is greater than max_seq_len, the attn_mask_cache will be updated
        attn_mask = attention_mask_builder.get_attn_mask(
            max_seq_len=2048, dtype=torch.float16, device=torch.device("cpu"))
        self.assertEqual(attn_mask.shape, (2048, 2048))
        self.assertEqual(attn_mask[0][-1],
                         torch.tensor(float("-inf"), dtype=torch.float16))
        self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
                         (2048, 2048))
        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
                         torch.tensor(float("-inf"), dtype=torch.float16))

    def test_get_splitfuse_attn_mask(self):
        attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
                                                      dtype=torch.float16)
        attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
            seq_lens=[512],
            query_lens=[512],
            position=torch.tensor([0]),
            dtype=torch.float16,
            device=torch.device("cpu"),
        )
        self.assertEqual(attn_mask.shape, (1, 512))
        self.assertEqual(attention_mask_builder._seq_len_cached, 1024)

        attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
            seq_lens=[2048],
            query_lens=[1024],
            position=torch.tensor([0]),
            dtype=torch.float16,
            device=torch.device("cpu"),
        )
        self.assertEqual(attn_mask.shape, (1024, 2048))

        attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
                                                      dtype=torch.int8)
        attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
            seq_lens=[512],
            query_lens=[512],
            position=torch.tensor([0]),
            dtype=torch.int8,
            device=torch.device("cpu"),
        )
        self.assertEqual(attn_mask.shape, (1, 512))
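For context, a minimal usage sketch of the class under test (not part of the commit; the CPU device, lengths, and positions below are illustrative assumptions, and real callers pass the NPU device instead):

import torch

from vllm_ascend.attention.attention_mask import AttentionMaskBuilder

# Build the cached lower-triangular mask once; later calls slice views from it.
builder = AttentionMaskBuilder(max_seq_len=1024, dtype=torch.float16)

# Plain causal mask for a 256-token prefill: a (256, 256) slice of the cache.
prefill_mask = builder.get_attn_mask(max_seq_len=256,
                                     dtype=torch.float16,
                                     device=torch.device("cpu"))

# Chunked-prefill ("splitfuse") mask: rows are selected by token position and
# columns cover the whole sequence, giving a (128, 256) tensor here.
splitfuse_mask = builder.get_splitfuse_attn_mask(
    seq_lens=[256],
    query_lens=[128],
    position=torch.arange(128, 256),
    dtype=torch.float16,
    device=torch.device("cpu"),
)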

vllm_ascend/attention/attention.py: 2 additions & 103 deletions

@@ -35,6 +35,7 @@
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.ops.cache import concat_and_cache_mla
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
                                enable_custom_op, is_310p, nd_to_nz_2d)
@@ -44,108 +45,6 @@
 _ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
 
 
-def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
-    # Construct lower triangle matrix.
-    mask_flag = torch.tril(
-        torch.ones((max_seq_len, max_seq_len),
-                   dtype=torch.bool)).view(max_seq_len, max_seq_len)
-    # Create upper triangle matrix used to mark mask positions.
-    mask_flag = ~mask_flag
-    # Currently for fp16 dtype, the mask value should be set to -inf.
-    # TODO: Eliminate this part in the future.
-    if mask_value is None:
-        if dtype == torch.float16:
-            mask_value = torch.finfo(torch.float32).min
-        else:
-            mask_value = 1
-    attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
-                                  mask_flag, mask_value).to(dtype)
-    return attn_mask
-
-
-class AttentionMaskBuilder:
-
-    def __init__(self, attn_mask: torch.Tensor):
-        self._seq_len_cached = attn_mask.shape[0]
-        self.attn_mask_cache = attn_mask
-        self.splitfuse_mask_value = -10000
-
-    @classmethod
-    def initialize_from_len(cls,
-                            max_seq_len: int,
-                            dtype: torch.dtype = torch.float16,
-                            mask_value: Optional[int] = None):
-        return cls(generate_attn_mask(max_seq_len, dtype, mask_value))
-
-    def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
-                          device: torch.device):
-        if seqlen > self._seq_len_cached or self.attn_mask_cache.dtype != dtype:
-            self._seq_len_cached = seqlen
-            self.attn_mask_cache = generate_attn_mask(seqlen, dtype)
-        if self.attn_mask_cache.device != device:
-            self.attn_mask_cache = self.attn_mask_cache.to(device)
-
-    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
-                      device: torch.device):
-        self.update_attn_cache(max_seq_len, dtype, device)
-        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
-
-    def get_decode_attn_mask(
-        self,
-        input_lengths: torch.tensor,
-        max_s: int,
-        dtype: torch.dtype,
-        device: torch.device,
-    ):
-        self.update_attn_cache(max_s, dtype, device)
-        return (self.attn_mask_cache.index_select(
-            0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())
-
-    def get_splitfuse_attn_mask(
-        self,
-        seq_lens,
-        query_lens,
-        position,
-        dtype,
-        device,
-    ) -> torch.Tensor:
-        max_seq_len = max(seq_lens, default=0)
-        if max_seq_len <= self._seq_len_cached:
-            self.update_attn_cache(max_seq_len, dtype, device)
-            # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
-            # is not the same. Fix this in the future when kernel is ready.
-            if self.attn_mask_cache.numel(
-            ) > 1 and self.attn_mask_cache[0][1] > 0:
-                attn_mask = self.get_attn_mask(  # type: ignore
-                    max_seq_len, dtype, device)
-                attn_mask *= -10000
-            else:
-                attn_mask = self.attn_mask_cache
-            return torch.index_select(attn_mask, dim=0,
-                                      index=position)[:, :max_seq_len]
-        total_q_len = sum(query_lens)
-        attn_mask = torch.zeros((total_q_len, max_seq_len),
-                                dtype=dtype,
-                                device="cpu")
-
-        current_row = 0
-        for i in range(len(query_lens)):
-            seq_len = seq_lens[i]
-            q_len = query_lens[i]
-            context_len = seq_len - q_len
-
-            assert context_len >= 0
-            attn_mask[current_row:current_row + q_len,
-                      context_len:] = self.splitfuse_mask_value
-            right_tensor = attn_mask[current_row:current_row + q_len,
-                                     context_len:seq_len]
-            right_tensor.masked_fill_(
-                right_tensor.tril() == self.splitfuse_mask_value, 0)
-            current_row += q_len
-
-        return attn_mask.to(device, non_blocking=True)
-
-
 class AscendAttentionBackend(AttentionBackend):
 
     @staticmethod
@@ -524,7 +423,7 @@ def __init__(self, input_builder: "ModelInputForNPUBuilder"):
         self.compress_mask = None
         self.chunk_mask = None
         if AscendMetadataBuilder._attn_mask_builder is None:
-            AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
+            AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder(
                 128, self.input_builder.runner.model_config.dtype)
 
     def _add_seq_group(
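The call-site change above is mechanical: the removed initialize_from_len classmethod is replaced by the new constructor, which now builds the cached mask itself. A minimal before/after sketch (model_dtype is a stand-in for self.input_builder.runner.model_config.dtype from the hunk above):

# Before this commit, via the classmethod deleted above:
# mask_builder = AttentionMaskBuilder.initialize_from_len(128, model_dtype)

# After this commit, the constructor generates the mask cache directly:
mask_builder = AttentionMaskBuilder(128, model_dtype)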
New file (imported above as vllm_ascend.attention.attention_mask): 103 additions & 0 deletions

@@ -0,0 +1,103 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def _generate_attn_mask(max_seq_len, dtype):
    # Construct lower triangle matrix.
    mask_flag = torch.tril(
        torch.ones((max_seq_len, max_seq_len),
                   dtype=torch.bool)).view(max_seq_len, max_seq_len)
    # Create upper triangle matrix used to mark mask positions.
    mask_flag = ~mask_flag
    # Currently for fp16 dtype, the mask value should be set to -inf.
    # TODO: Eliminate this part in the future.
    if dtype == torch.float16:
        mask_value = torch.finfo(torch.float32).min
    else:
        mask_value = 1
    attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
                                  mask_flag, mask_value).to(dtype)
    return attn_mask


class AttentionMaskBuilder:

    def __init__(
        self,
        max_seq_len: int,
        dtype: torch.dtype,
    ):
        attn_mask = _generate_attn_mask(max_seq_len, dtype)

        self._seq_len_cached = attn_mask.shape[0]
        self.attn_mask_cache = attn_mask
        self.splitfuse_mask_value = -10000

    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
                      device: torch.device):
        self._update_attn_cache(max_seq_len, dtype, device)
        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()

    def get_splitfuse_attn_mask(
        self,
        seq_lens,
        query_lens,
        position,
        dtype,
        device,
    ) -> torch.Tensor:
        max_seq_len = max(seq_lens, default=0)
        if max_seq_len <= self._seq_len_cached:
            self._update_attn_cache(max_seq_len, dtype, device)
            # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
            # is not the same. Fix this in the future when kernel is ready.
            if self.attn_mask_cache.numel(
            ) > 1 and self.attn_mask_cache[0][1] > 0:
                attn_mask = self.get_attn_mask(  # type: ignore
                    max_seq_len, dtype, device)
                attn_mask *= -10000
            else:
                attn_mask = self.attn_mask_cache
            return torch.index_select(attn_mask, dim=0,
                                      index=position)[:, :max_seq_len]
        total_q_len = sum(query_lens)
        attn_mask = torch.zeros((total_q_len, max_seq_len),
                                dtype=dtype,
                                device="cpu")
        current_row = 0
        for i in range(len(query_lens)):
            seq_len = seq_lens[i]
            q_len = query_lens[i]
            context_len = seq_len - q_len

            assert context_len >= 0
            attn_mask[current_row:current_row + q_len,
                      context_len:] = self.splitfuse_mask_value
            right_tensor = attn_mask[current_row:current_row + q_len,
                                     context_len:seq_len]
            right_tensor.masked_fill_(
                right_tensor.tril() == self.splitfuse_mask_value, 0)
            current_row += q_len

        return attn_mask.to(device, non_blocking=True)

    def _update_attn_cache(self, seqlen: int, dtype: torch.dtype,
                           device: torch.device):
        if seqlen > self._seq_len_cached:
            self._seq_len_cached = seqlen
            self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
        if self.attn_mask_cache.device != device:
            self.attn_mask_cache = self.attn_mask_cache.to(device)
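To make the mask convention concrete, a small sketch of what the new module produces (not part of the commit; the printed tensor in the comment is approximate formatting):

import torch

from vllm_ascend.attention.attention_mask import AttentionMaskBuilder

builder = AttentionMaskBuilder(max_seq_len=4, dtype=torch.float16)
print(builder.attn_mask_cache)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]], dtype=torch.float16)
#
# float16 masks hold -inf above the diagonal (torch.finfo(torch.float32).min
# overflows to -inf when cast to float16); other dtypes such as int8 store 1 in
# the masked positions, which is why get_splitfuse_attn_mask rescales such
# masks by -10000 in the chunked-prefill path.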
