Commit a1aafc8

[ROCm][FEAT] Enable Full Graph Mode in AITER MLA V1 Attn Backend (Decode Phase only) (#20254)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
1 parent 139508a

1 file changed

vllm/v1/attention/backends/mla/rocm_aiter_mla.py

Lines changed: 59 additions & 31 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any, ClassVar, Optional

 import torch

@@ -63,63 +63,91 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):


 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
+    full_cudagraph_supported: ClassVar[bool] = True  # decode only

     def __init__(self, runner, kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
         super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata)
         assert self.kv_cache_spec.block_size == 1, "AITER MLA " \
             "only supports block size 1."
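
Note: the new class-level `full_cudagraph_supported` flag is how a V1 metadata builder advertises that the engine may capture it in a full CUDA graph; the trailing comment records that only the decode phase is covered, so prefill still runs eagerly. A minimal sketch of how such a capability flag can be consulted, where `can_capture_full_graph` and `wants_full_cudagraph` are illustrative names rather than vLLM APIs:

import torch

# Illustrative sketch, not vLLM code: gate graph capture on the builder's
# class-level capability flag.
def can_capture_full_graph(builder_cls: type,
                           wants_full_cudagraph: bool) -> bool:
    # getattr with a default keeps builders that predate the flag safe.
    return wants_full_cudagraph and getattr(
        builder_cls, "full_cudagraph_supported", False)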

-    def _get_paged_kv_tensors(
-            self, block_table: torch.Tensor,
-            seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        # Preparing persistent buffers
+        if self.runner.full_cuda_graph:
+            device = self.runner.device
+            max_num_reqs = self.runner.max_num_reqs
+            self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
+                                               dtype=torch.int32,
+                                               device=device)
+            self.paged_kv_indices = torch.zeros(
+                block_table.get_device_tensor().numel(),  # max num pages possible
+                dtype=torch.int32,
+                device=device)
+            self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
+                                                      dtype=torch.int32,
+                                                      device=device)
+
+            self.qo_indptr = torch.arange(0,
+                                          max_num_reqs + 1,
+                                          dtype=torch.int32,
+                                          device=device)
+
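
These persistent buffers are the standard prerequisite for CUDA graph replay: a captured graph bakes in the addresses of every tensor it touches, so each step must write its metadata into the same allocation rather than building fresh tensors. A self-contained sketch of the pattern (illustrative, not vLLM code), assuming a fixed `max_num_reqs`:

import torch

# Persistent-buffer pattern: allocate once at maximum size, then update
# the live prefix in place each step so a captured CUDA graph always
# reads the same memory.
class PersistentIndptr:
    def __init__(self, max_num_reqs: int, device: str = "cpu"):
        self.buf = torch.zeros(max_num_reqs + 1, dtype=torch.int32,
                               device=device)

    def update(self, indptr: torch.Tensor) -> torch.Tensor:
        n = indptr.numel()
        self.buf[:n].copy_(indptr, non_blocking=True)
        # Pad the tail with the last offset so padded requests become
        # empty ranges during graph replay.
        self.buf[n:].fill_(indptr[-1])
        return self.buf[:n]

For example, `PersistentIndptr(4).update(torch.tensor([0, 2, 3], dtype=torch.int32))` returns a view of the first three slots while the rest of the buffer holds 3.
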
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
         page_size = self.kv_cache_spec.block_size
         block_table_bounds = (seq_lens + page_size - 1) // page_size
         device = self.runner.device

-        mask = (torch.arange(block_table.size(1),
-                             dtype=block_table.dtype,
+        mask = (torch.arange(block_table_tensor.size(1),
+                             dtype=block_table_tensor.dtype,
                              device=device).unsqueeze(0)
                 < block_table_bounds.unsqueeze(1))
-        paged_kv_indices = block_table[mask]
+        paged_kv_indices = block_table_tensor[mask]
+
+        paged_kv_last_page_len = seq_lens % page_size
+        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
+                                             page_size, paged_kv_last_page_len)

         paged_kv_indptr = torch.cat([
             torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
             block_table_bounds.cumsum(dim=0, dtype=torch.int32)
         ])

-        paged_kv_last_page_len = seq_lens % page_size
-        paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0,
-                                             page_size, paged_kv_last_page_len)
-        qo_indptr = torch.arange(0,
-                                 self._num_decodes + 1,
-                                 step=1,
-                                 dtype=torch.int32,
-                                 device=device)
-
-        return (
-            paged_kv_indices,
-            paged_kv_indptr,
-            paged_kv_last_page_len,
-            qo_indptr,
-        )
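
To make the paged-KV layout concrete, here is the same arithmetic run standalone on CPU with made-up values. `page_size = 4` is used only so the rounding is visible; this builder asserts `block_size == 1`, in which case every page holds a single token and `paged_kv_last_page_len` is all ones.

import torch

# Two decode requests with 5 and 3 tokens of KV cache, page_size = 4.
seq_lens = torch.tensor([5, 3])
page_size = 4
block_table = torch.tensor([[10, 11, 0, 0],
                            [20, 0, 0, 0]])  # page ids per request

block_table_bounds = (seq_lens + page_size - 1) // page_size  # [2, 1]
mask = (torch.arange(block_table.size(1)).unsqueeze(0)
        < block_table_bounds.unsqueeze(1))
paged_kv_indices = block_table[mask]                  # [10, 11, 20]
paged_kv_indptr = torch.cat([
    torch.zeros(1, dtype=torch.int64),
    block_table_bounds.cumsum(dim=0)])                # [0, 2, 3]
last = seq_lens % page_size
paged_kv_last_page_len = torch.where(last == 0, page_size, last)  # [1, 3]

Request i's pages are `paged_kv_indices[paged_kv_indptr[i]:paged_kv_indptr[i + 1]]`, so request 0 reads pages [10, 11] and request 1 reads page [20].
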
+        if self.runner.full_cuda_graph:
+            num_reqs = self._num_decodes

-    def _build_decode(self, block_table_tensor: torch.Tensor,
-                      seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
+            num_actual_pages = paged_kv_indices.size(0)
+
+            self.paged_kv_indices[:num_actual_pages].copy_(paged_kv_indices,
+                                                           non_blocking=True)
+            self.paged_kv_indices[num_actual_pages:].fill_(-1)
+            paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
+
+            self.paged_kv_indptr[:1 + num_reqs].copy_(paged_kv_indptr,
+                                                      non_blocking=True)
+            self.paged_kv_indptr[1 + num_reqs:].fill_(paged_kv_indptr[-1])
+            paged_kv_indptr = self.paged_kv_indptr[:1 + num_reqs]
+
+            self.paged_kv_last_page_len[:num_reqs].copy_(
+                paged_kv_last_page_len, non_blocking=True)
+            self.paged_kv_last_page_len[num_reqs:].fill_(1)
+            paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]
+
+            qo_indptr = self.qo_indptr[:1 + num_reqs]

-        (
-            paged_kv_indices,
-            paged_kv_indptr,
-            paged_last_page_len,
-            qo_indptr,
-        ) = self._get_paged_kv_tensors(block_table_tensor, seq_lens)
+        else:
+            qo_indptr = torch.arange(0,
+                                     self._num_decodes + 1,
+                                     step=1,
+                                     dtype=torch.int32,
+                                     device=device)
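
The tail fills are what keep the padded region of a captured graph benign: `paged_kv_indptr` is padded with its final offset so every padded request decodes to an empty page range, `paged_kv_last_page_len` is padded with 1 so a kernel reading a padded slot sees a valid length, and unused `paged_kv_indices` slots get a -1 sentinel. A small check of the indptr padding, using made-up values:

import torch

# Padding paged_kv_indptr with its last value makes padded requests
# empty: indptr[i + 1] - indptr[i] == 0 for every padded slot.
paged_kv_indptr = torch.tensor([0, 2, 3], dtype=torch.int32)  # 2 real requests
max_num_reqs = 4
buf = torch.zeros(max_num_reqs + 1, dtype=torch.int32)
buf[:3].copy_(paged_kv_indptr)
buf[3:].fill_(paged_kv_indptr[-1])  # buf == [0, 2, 3, 3, 3]
pages_per_req = buf[1:] - buf[:-1]  # [2, 1, 0, 0] -> padded reqs read nothing
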

         attn_metadata = AiterMLADecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens,
             paged_kv_indptr=paged_kv_indptr,
             paged_kv_indices=paged_kv_indices,
-            paged_kv_last_page_len=paged_last_page_len,
+            paged_kv_last_page_len=paged_kv_last_page_len,
             qo_indptr=qo_indptr)

         return attn_metadata
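
For context, the `self.runner.full_cuda_graph` checks above key off vLLM's compilation configuration. A hedged usage sketch, assuming the `full_cuda_graph` field is exposed through `compilation_config` as it was for other V1 backends in this period, and using a DeepSeek model since MLA applies to DeepSeek-style attention:

from vllm import LLM

# Hedged sketch: enabling full graph mode (decode-only capture for AITER
# MLA after this commit). The exact flag plumbing is an assumption here,
# not confirmed by the diff.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
          compilation_config={"full_cuda_graph": True})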

Comments (0)