
Commit d619af8

[Misc] Code clean up
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent 493768e commit d619af8

6 files changed: +269 -437 lines changed
Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
+
+
+class TestAttentionMaskBuilder(TestBase):
+
+    def test_init_attention_mask_builder(self):
+        attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024, dtype=torch.float16)
+        self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
+        self.assertEqual(attention_mask_builder.attn_mask_cache.dtype, torch.float16)
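The added test only checks construction of the builder. As a hedged illustration (not part of this commit), a follow-up test exercising get_attn_mask could look like the sketch below; the test-class name and the expected mask values are assumptions based on the code added in this diff.

import torch

from tests.ut.base import TestBase
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder


class TestAttentionMaskBuilderGetMask(TestBase):
    # Illustrative sketch only; not part of commit d619af8.

    def test_get_attn_mask(self):
        builder = AttentionMaskBuilder(max_seq_len=8, dtype=torch.float16)
        mask = builder.get_attn_mask(4, torch.float16, torch.device("cpu"))
        # The cached 8x8 mask is sliced down to the requested 4x4 shape.
        self.assertEqual(mask.shape, (4, 4))
        # On and below the diagonal the mask is 0 (tokens may attend).
        self.assertEqual(mask[3][0].item(), 0.0)
        # Above the diagonal the mask holds a large negative fill value.
        self.assertLess(mask[0][3].item(), 0.0)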

vllm_ascend/attention/attention.py

Lines changed: 2 additions & 103 deletions

@@ -35,6 +35,7 @@
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad

 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.ops.cache import concat_and_cache_mla
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
                                enable_custom_op, is_310p, nd_to_nz_2d)
@@ -44,108 +45,6 @@
 _ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]


-def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
-    # Construct lower triangle matrix.
-    mask_flag = torch.tril(
-        torch.ones((max_seq_len, max_seq_len),
-                   dtype=torch.bool)).view(max_seq_len, max_seq_len)
-    # Create upper triangle matrix used to mark mask positions.
-    mask_flag = ~mask_flag
-    # Currently for fp16 dtype, the mask value should be set to -inf.
-    # TODO: Eliminate this part in the future.
-    if mask_value is None:
-        if dtype == torch.float16:
-            mask_value = torch.finfo(torch.float32).min
-        else:
-            mask_value = 1
-    attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
-                                  mask_flag, mask_value).to(dtype)
-    return attn_mask
-
-
-class AttentionMaskBuilder:
-
-    def __init__(self, attn_mask: torch.Tensor):
-        self._seq_len_cached = attn_mask.shape[0]
-        self.attn_mask_cache = attn_mask
-        self.splitfuse_mask_value = -10000
-
-    @classmethod
-    def initialize_from_len(cls,
-                            max_seq_len: int,
-                            dtype: torch.dtype = torch.float16,
-                            mask_value: Optional[int] = None):
-        return cls(generate_attn_mask(max_seq_len, dtype, mask_value))
-
-    def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
-                          device: torch.device):
-        if seqlen > self._seq_len_cached or self.attn_mask_cache.dtype != dtype:
-            self._seq_len_cached = seqlen
-            self.attn_mask_cache = generate_attn_mask(seqlen, dtype)
-        if self.attn_mask_cache.device != device:
-            self.attn_mask_cache = self.attn_mask_cache.to(device)
-
-    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
-                      device: torch.device):
-        self.update_attn_cache(max_seq_len, dtype, device)
-        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
-
-    def get_decode_attn_mask(
-        self,
-        input_lengths: torch.tensor,
-        max_s: int,
-        dtype: torch.dtype,
-        device: torch.device,
-    ):
-        self.update_attn_cache(max_s, dtype, device)
-        return (self.attn_mask_cache.index_select(
-            0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())
-
-    def get_splitfuse_attn_mask(
-        self,
-        seq_lens,
-        query_lens,
-        position,
-        dtype,
-        device,
-    ) -> torch.Tensor:
-        max_seq_len = max(seq_lens, default=0)
-        if max_seq_len <= self._seq_len_cached:
-            self.update_attn_cache(max_seq_len, dtype, device)
-            # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
-            # is not the same. Fix this in the future when kernel is ready.
-            if self.attn_mask_cache.numel(
-            ) > 1 and self.attn_mask_cache[0][1] > 0:
-                attn_mask = self.get_attn_mask(  # type: ignore
-                    max_seq_len, dtype, device)
-                attn_mask *= -10000
-            else:
-                attn_mask = self.attn_mask_cache
-            return torch.index_select(attn_mask, dim=0,
-                                      index=position)[:, :max_seq_len]
-        total_q_len = sum(query_lens)
-        attn_mask = torch.zeros((total_q_len, max_seq_len),
-                                dtype=dtype,
-                                device="cpu")
-
-        current_row = 0
-        for i in range(len(query_lens)):
-            seq_len = seq_lens[i]
-            q_len = query_lens[i]
-            context_len = seq_len - q_len
-
-            assert context_len >= 0
-            attn_mask[current_row:current_row + q_len,
-                      context_len:] = self.splitfuse_mask_value
-            right_tensor = attn_mask[current_row:current_row + q_len,
-                                     context_len:seq_len]
-            right_tensor.masked_fill_(
-                right_tensor.tril() == self.splitfuse_mask_value, 0)
-            current_row += q_len
-
-        return attn_mask.to(device, non_blocking=True)
-
-
 class AscendAttentionBackend(AttentionBackend):

     @staticmethod
@@ -524,7 +423,7 @@ def __init__(self, input_builder: "ModelInputForNPUBuilder"):
         self.compress_mask = None
         self.chunk_mask = None
         if AscendMetadataBuilder._attn_mask_builder is None:
-            AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
+            AscendMetadataBuilder._attn_mask_builder = AttentionMaskBuilder(
                 128, self.input_builder.runner.model_config.dtype)

     def _add_seq_group(
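For other call sites, the change in this file shows the whole migration pattern: the removed initialize_from_len classmethod is replaced by the new two-argument constructor. A minimal before/after sketch, with an illustrative dtype (real call sites pass the model config dtype):

import torch

from vllm_ascend.attention.attention_mask import AttentionMaskBuilder

dtype = torch.float16  # illustrative only

# Before this commit (classmethod, now removed):
#     builder = AttentionMaskBuilder.initialize_from_len(128, dtype)
# After this commit (plain constructor):
builder = AttentionMaskBuilder(128, dtype)
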
Lines changed: 103 additions & 0 deletions

@@ -0,0 +1,103 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+
+def _generate_attn_mask(max_seq_len, dtype):
+    # Construct lower triangle matrix.
+    mask_flag = torch.tril(
+        torch.ones((max_seq_len, max_seq_len),
+                   dtype=torch.bool)).view(max_seq_len, max_seq_len)
+    # Create upper triangle matrix used to mark mask positions.
+    mask_flag = ~mask_flag
+    # Currently for fp16 dtype, the mask value should be set to -inf.
+    # TODO: Eliminate this part in the future.
+    if dtype == torch.float16:
+        mask_value = torch.finfo(torch.float32).min
+    else:
+        mask_value = 1
+    attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
+                                  mask_flag, mask_value).to(dtype)
+    return attn_mask
+
+
+class AttentionMaskBuilder:
+
+    def __init__(
+        self,
+        max_seq_len: int,
+        dtype: torch.dtype,
+    ):
+        attn_mask = _generate_attn_mask(max_seq_len, dtype)
+
+        self._seq_len_cached = attn_mask.shape[0]
+        self.attn_mask_cache = attn_mask
+        self.splitfuse_mask_value = -10000
+
+    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
+                      device: torch.device):
+        self._update_attn_cache(max_seq_len, dtype, device)
+        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
+
+    def get_splitfuse_attn_mask(
+        self,
+        seq_lens,
+        query_lens,
+        position,
+        dtype,
+        device,
+    ) -> torch.Tensor:
+        max_seq_len = max(seq_lens, default=0)
+        if max_seq_len <= self._seq_len_cached:
+            self._update_attn_cache(max_seq_len, dtype, device)
+            # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
+            # is not the same. Fix this in the future when kernel is ready.
+            if self.attn_mask_cache.numel(
+            ) > 1 and self.attn_mask_cache[0][1] > 0:
+                attn_mask = self.get_attn_mask(  # type: ignore
+                    max_seq_len, dtype, device)
+                attn_mask *= -10000
+            else:
+                attn_mask = self.attn_mask_cache
+            return torch.index_select(attn_mask, dim=0,
+                                      index=position)[:, :max_seq_len]
+        total_q_len = sum(query_lens)
+        attn_mask = torch.zeros((total_q_len, max_seq_len),
+                                dtype=dtype,
+                                device="cpu")
+        current_row = 0
+        for i in range(len(query_lens)):
+            seq_len = seq_lens[i]
+            q_len = query_lens[i]
+            context_len = seq_len - q_len
+
+            assert context_len >= 0
+            attn_mask[current_row:current_row + q_len,
+                      context_len:] = self.splitfuse_mask_value
+            right_tensor = attn_mask[current_row:current_row + q_len,
+                                     context_len:seq_len]
+            right_tensor.masked_fill_(
+                right_tensor.tril() == self.splitfuse_mask_value, 0)
+            current_row += q_len
+
+        return attn_mask.to(device, non_blocking=True)
+
+    def _update_attn_cache(self, seqlen: int, dtype: torch.dtype,
+                           device: torch.device):
+        if seqlen > self._seq_len_cached or self.attn_mask_cache.dtype != dtype:
+            self._seq_len_cached = seqlen
+            self.attn_mask_cache = _generate_attn_mask(seqlen, dtype)
+        if self.attn_mask_cache.device != device:
+            self.attn_mask_cache = self.attn_mask_cache.to(device)
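The new module above (imported elsewhere in this diff as vllm_ascend.attention.attention_mask) keeps one cached causal mask and re-slices or regenerates it on demand. A minimal usage sketch under the assumptions visible in this diff (sizes and device are illustrative):

import torch

from vllm_ascend.attention.attention_mask import AttentionMaskBuilder

builder = AttentionMaskBuilder(max_seq_len=1024, dtype=torch.float16)

# Requests within the cached size just slice the cached mask.
mask = builder.get_attn_mask(16, torch.float16, torch.device("cpu"))
assert mask.shape == (16, 16)

# A longer request (or a dtype change) regenerates the cache.
builder.get_attn_mask(2048, torch.float16, torch.device("cpu"))
assert builder._seq_len_cached == 2048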

vllm_ascend/worker/eagle_proposer_v1.py

Lines changed: 3 additions & 46 deletions
@@ -14,7 +14,7 @@
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.sample.metadata import SamplingMetadata

-from vllm_ascend.attention.attention import AttentionMaskBuilder
+from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState

 logger = init_logger(__name__)
@@ -74,8 +74,8 @@ def __init__(self,
         mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
         self.attn_mask_len = min(self.model_config.max_model_len,
                                  int(mask_len))
-        self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
-            self.attn_mask_len, self.dtype)
+        self.attn_mask_builder = AttentionMaskBuilder(self.attn_mask_len,
+                                                      self.dtype)

     def _make_attention_mask(
         self,
@@ -384,46 +384,3 @@ def prepare_eagle_input_sequential(out_tensor: torch.Tensor,
             (target_indices < end_pos) & \
             (offset_tensor < num_tokens)
         out_tensor[target_indices[mask]] = values_to_store[mask]
-
-
-# NOTE(woosuk): Currently, the below code is not used and we always use argmax
-# to sample the draft tokens. We will use this after we find a way to manage
-# the draft prob tensor.
-# Refer to https://github.com/vllm-project/vllm/pull/16899 for the details.
-# FIXME(woosuk): The logic here is duplicated with the main sampling code.
-# We should refactor this to reuse the same sampling implementation.
-def compute_probs_and_sample_next_token(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    if sampling_metadata.all_greedy:
-        # For greedy requests, draft_probs is not used in rejection sampling.
-        # Therefore, we can just return the logits.
-        probs = logits
-        next_token_ids = logits.argmax(dim=-1)
-        return next_token_ids, probs
-
-    is_greedy = sampling_metadata.temperature == -1
-    temperature = torch.where(is_greedy, 1.0, sampling_metadata.temperature)
-    logits.div_(temperature.view(-1, 1))
-    probs = logits.softmax(dim=-1, dtype=torch.float32)
-
-    # NOTE(woosuk): Currently, we ignore most of the sampling parameters in
-    # generating the draft tokens. We only use the temperature. While this
-    # could degrade the acceptance rate, it does not affect the distribution
-    # of the generated tokens after rejection sampling.
-
-    # TODO(woosuk): Consider seeds.
-    q = torch.empty_like(probs)
-    q.exponential_()
-    # NOTE(woosuk): We shouldn't use `probs.div_(q)` because the draft_probs
-    # will be used later for rejection sampling.
-    next_token_ids = probs.div(q).argmax(dim=-1).view(-1)
-    if not sampling_metadata.all_random:
-        greedy_token_ids = probs.argmax(dim=-1)
-        next_token_ids = torch.where(
-            is_greedy,
-            greedy_token_ids,
-            next_token_ids,
-        )
-    return next_token_ids, probs
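The deleted compute_probs_and_sample_next_token helper was unused: as its own NOTE states, draft tokens are currently sampled with a plain argmax over the logits. A minimal sketch of that argmax path (tensor shapes are illustrative, not the proposer's actual code):

import torch

# Logits for the draft tokens: [num_tokens, vocab_size] (illustrative sizes).
logits = torch.randn(4, 32000)

# Current behaviour per the removed helper's NOTE: always take the argmax.
next_token_ids = logits.argmax(dim=-1)
assert next_token_ids.shape == (4,)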
