
Commit c845131

[Misc] Code clean up
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent dd22ac3 commit c845131

5 files changed, +220 -245 lines changed

vllm_ascend/attention/attention.py

Lines changed: 1 addition & 102 deletions

@@ -35,6 +35,7 @@
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.ops.cache import concat_and_cache_mla
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
                                enable_custom_op, is_310p, nd_to_nz_2d)
@@ -44,108 +45,6 @@
 _ALLOWED_NUM_QUERIES_PER_KV = [32, 64, 128]
 
 
-def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
-    # Construct lower triangle matrix.
-    mask_flag = torch.tril(
-        torch.ones((max_seq_len, max_seq_len),
-                   dtype=torch.bool)).view(max_seq_len, max_seq_len)
-    # Create upper triangle matrix used to mark mask positions.
-    mask_flag = ~mask_flag
-    # Currently for fp16 dtype, the mask value should be set to -inf.
-    # TODO: Eliminate this part in the future.
-    if mask_value is None:
-        if dtype == torch.float16:
-            mask_value = torch.finfo(torch.float32).min
-        else:
-            mask_value = 1
-    attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
-                                  mask_flag, mask_value).to(dtype)
-    return attn_mask
-
-
-class AttentionMaskBuilder:
-
-    def __init__(self, attn_mask: torch.Tensor):
-        self._seq_len_cached = attn_mask.shape[0]
-        self.attn_mask_cache = attn_mask
-        self.splitfuse_mask_value = -10000
-
-    @classmethod
-    def initialize_from_len(cls,
-                            max_seq_len: int,
-                            dtype: torch.dtype = torch.float16,
-                            mask_value: Optional[int] = None):
-        return cls(generate_attn_mask(max_seq_len, dtype, mask_value))
-
-    def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
-                          device: torch.device):
-        if seqlen > self._seq_len_cached or self.attn_mask_cache.dtype != dtype:
-            self._seq_len_cached = seqlen
-            self.attn_mask_cache = generate_attn_mask(seqlen, dtype)
-        if self.attn_mask_cache.device != device:
-            self.attn_mask_cache = self.attn_mask_cache.to(device)
-
-    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
-                      device: torch.device):
-        self.update_attn_cache(max_seq_len, dtype, device)
-        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
-
-    def get_decode_attn_mask(
-        self,
-        input_lengths: torch.tensor,
-        max_s: int,
-        dtype: torch.dtype,
-        device: torch.device,
-    ):
-        self.update_attn_cache(max_s, dtype, device)
-        return (self.attn_mask_cache.index_select(
-            0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())
-
-    def get_splitfuse_attn_mask(
-        self,
-        seq_lens,
-        query_lens,
-        position,
-        dtype,
-        device,
-    ) -> torch.Tensor:
-        max_seq_len = max(seq_lens, default=0)
-        if max_seq_len <= self._seq_len_cached:
-            self.update_attn_cache(max_seq_len, dtype, device)
-            # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
-            # is not the same. Fix this in the future when kernel is ready.
-            if self.attn_mask_cache.numel(
-            ) > 1 and self.attn_mask_cache[0][1] > 0:
-                attn_mask = self.get_attn_mask(  # type: ignore
-                    max_seq_len, dtype, device)
-                attn_mask *= -10000
-            else:
-                attn_mask = self.attn_mask_cache
-            return torch.index_select(attn_mask, dim=0,
-                                      index=position)[:, :max_seq_len]
-        total_q_len = sum(query_lens)
-        attn_mask = torch.zeros((total_q_len, max_seq_len),
-                                dtype=dtype,
-                                device="cpu")
-
-        current_row = 0
-        for i in range(len(query_lens)):
-            seq_len = seq_lens[i]
-            q_len = query_lens[i]
-            context_len = seq_len - q_len
-
-            assert context_len >= 0
-            attn_mask[current_row:current_row + q_len,
-                      context_len:] = self.splitfuse_mask_value
-            right_tensor = attn_mask[current_row:current_row + q_len,
-                                     context_len:seq_len]
-            right_tensor.masked_fill_(
-                right_tensor.tril() == self.splitfuse_mask_value, 0)
-            current_row += q_len
-
-        return attn_mask.to(device, non_blocking=True)
-
-
 class AscendAttentionBackend(AttentionBackend):
 
     @staticmethod
vllm_ascend/attention/attention_mask.py

Lines changed: 105 additions & 0 deletions

@@ -0,0 +1,105 @@
+from typing import Optional
+
+import torch
+
+
+def generate_attn_mask(max_seq_len: int, dtype=torch.float16, mask_value=None):
+    # Construct lower triangle matrix.
+    mask_flag = torch.tril(
+        torch.ones((max_seq_len, max_seq_len),
+                   dtype=torch.bool)).view(max_seq_len, max_seq_len)
+    # Create upper triangle matrix used to mark mask positions.
+    mask_flag = ~mask_flag
+    # Currently for fp16 dtype, the mask value should be set to -inf.
+    # TODO: Eliminate this part in the future.
+    if mask_value is None:
+        if dtype == torch.float16:
+            mask_value = torch.finfo(torch.float32).min
+        else:
+            mask_value = 1
+    attn_mask = torch.masked_fill(torch.zeros(size=(max_seq_len, max_seq_len)),
+                                  mask_flag, mask_value).to(dtype)
+    return attn_mask
+
+
+class AttentionMaskBuilder:
+
+    def __init__(self, attn_mask: torch.Tensor):
+        self._seq_len_cached = attn_mask.shape[0]
+        self.attn_mask_cache = attn_mask
+        self.splitfuse_mask_value = -10000
+
+    @classmethod
+    def initialize_from_len(cls,
+                            max_seq_len: int,
+                            dtype: torch.dtype = torch.float16,
+                            mask_value: Optional[int] = None):
+        return cls(generate_attn_mask(max_seq_len, dtype, mask_value))
+
+    def update_attn_cache(self, seqlen: int, dtype: torch.dtype,
+                          device: torch.device):
+        if seqlen > self._seq_len_cached or self.attn_mask_cache.dtype != dtype:
+            self._seq_len_cached = seqlen
+            self.attn_mask_cache = generate_attn_mask(seqlen, dtype)
+        if self.attn_mask_cache.device != device:
+            self.attn_mask_cache = self.attn_mask_cache.to(device)
+
+    def get_attn_mask(self, max_seq_len: int, dtype: torch.dtype,
+                      device: torch.device):
+        self.update_attn_cache(max_seq_len, dtype, device)
+        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous()
+
+    def get_decode_attn_mask(
+        self,
+        input_lengths: torch.tensor,
+        max_s: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ):
+        self.update_attn_cache(max_s, dtype, device)
+        return (self.attn_mask_cache.index_select(
+            0, input_lengths)[:, :max_s].view(-1, 1, max_s).contiguous())
+
+    def get_splitfuse_attn_mask(
+        self,
+        seq_lens,
+        query_lens,
+        position,
+        dtype,
+        device,
+    ) -> torch.Tensor:
+        max_seq_len = max(seq_lens, default=0)
+        if max_seq_len <= self._seq_len_cached:
+            self.update_attn_cache(max_seq_len, dtype, device)
+            # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
+            # is not the same. Fix this in the future when kernel is ready.
+            if self.attn_mask_cache.numel(
+            ) > 1 and self.attn_mask_cache[0][1] > 0:
+                attn_mask = self.get_attn_mask(  # type: ignore
+                    max_seq_len, dtype, device)
+                attn_mask *= -10000
+            else:
+                attn_mask = self.attn_mask_cache
+            return torch.index_select(attn_mask, dim=0,
+                                      index=position)[:, :max_seq_len]
+        total_q_len = sum(query_lens)
+        attn_mask = torch.zeros((total_q_len, max_seq_len),
+                                dtype=dtype,
+                                device="cpu")
+
+        current_row = 0
+        for i in range(len(query_lens)):
+            seq_len = seq_lens[i]
+            q_len = query_lens[i]
+            context_len = seq_len - q_len
+
+            assert context_len >= 0
+            attn_mask[current_row:current_row + q_len,
+                      context_len:] = self.splitfuse_mask_value
+            right_tensor = attn_mask[current_row:current_row + q_len,
+                                     context_len:seq_len]
+            right_tensor.masked_fill_(
+                right_tensor.tril() == self.splitfuse_mask_value, 0)
+            current_row += q_len
+
+        return attn_mask.to(device, non_blocking=True)

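For reference, here is a minimal usage sketch of the relocated helper, assuming torch and vllm_ascend are importable; the sequence length, dtype, and device values below are arbitrary examples, not values taken from this commit:

# Illustrative sketch only; parameter values are arbitrary examples.
import torch

from vllm_ascend.attention.attention_mask import AttentionMaskBuilder

# Build and cache a 1024 x 1024 causal mask (fp16 masks are filled with a
# large negative value above the diagonal; other dtypes use 1, as in
# generate_attn_mask above).
mask_builder = AttentionMaskBuilder.initialize_from_len(max_seq_len=1024,
                                                        dtype=torch.float16)

# Slice a smaller mask out of the cache; the cache is regenerated only when a
# longer length or a different dtype is requested, and moved if the device differs.
attn_mask = mask_builder.get_attn_mask(128, torch.float16,
                                       torch.device("cpu"))
print(attn_mask.shape)  # torch.Size([128, 128])

Callers touched by this commit, such as attention.py above and eagle_proposer_v1.py below, now import AttentionMaskBuilder from vllm_ascend.attention.attention_mask instead of vllm_ascend.attention.attention.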
vllm_ascend/worker/eagle_proposer_v1.py

Lines changed: 1 addition & 1 deletion

@@ -14,7 +14,7 @@
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.sample.metadata import SamplingMetadata
 
-from vllm_ascend.attention.attention import AttentionMaskBuilder
+from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 
 logger = init_logger(__name__)
