Commit b526478

gshtras and qli88 authored
Aiter mla cherrypick (#543)
* Compatible patch for latest AITER (05/07/2025)
* yapf

Signed-off-by: Qiang Li <qiang.li2@amd.com>
Co-authored-by: Qiang Li <qiang.li2@amd.com>
1 parent 166d0ef commit b526478

File tree: 4 files changed, +54 -23 lines changed

vllm/attention/backends/mla/common.py

Lines changed: 6 additions & 6 deletions
@@ -1245,9 +1245,9 @@ def _compute_prefill_context(

             attn_output, attn_softmax_lse = \
                 self._flash_attn_varlen_diff_headdims(
-                q=q,
-                k=k,
-                v=v,
+                q,
+                k,
+                v,
                 cu_seqlens_q=prefill_metadata.query_start_loc,
                 cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
                 max_seqlen_q=prefill_metadata.max_query_len,
@@ -1299,9 +1299,9 @@ def _forward_prefill(
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

         output = self._flash_attn_varlen_diff_headdims(
-            q=q,
-            k=k,
-            v=v,
+            q,
+            k,
+            v,
             cu_seqlens_q=prefill_metadata.query_start_loc,
             cu_seqlens_k=prefill_metadata.query_start_loc,
             max_seqlen_q=prefill_metadata.max_prefill_seq_len,
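
Note: the only change in this file is that q, k, and v are now passed positionally rather than as keywords. A minimal sketch of why that matters, assuming (as the patch implies) that the latest AITER exposes these tensors as positional-only parameters; the wrapper below is illustrative, not AITER's actual signature:

    import torch

    # Hypothetical stand-in for an AITER-style kernel wrapper whose tensor
    # inputs are positional-only ('/' marks the positional-only boundary).
    def flash_attn_varlen_func(q, k, v, /, **kwargs):
        assert q.shape[-1] == k.shape[-1]  # q and k head dims must match
        return torch.empty(q.shape[0], v.shape[-1])

    q = torch.randn(4, 64)
    k = torch.randn(4, 64)
    v = torch.randn(4, 32)

    out = flash_attn_varlen_func(q, k, v, softmax_scale=0.125)  # works
    # flash_attn_varlen_func(q=q, k=k, v=v)  # would raise TypeError here
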

vllm/attention/backends/rocm_aiter_mla.py

Lines changed: 37 additions & 12 deletions
@@ -53,7 +53,7 @@ def get_state_cls() -> Type["AiterMLAState"]:

 @dataclass
 class AiterMLAMetadata(MLACommonMetadata):
-    # The following 4 tensors are for current version of AITER MLA
+    # The following 5 tensors are for current version of AITER MLA
     block_table_bound: Optional[torch.Tensor] = None
     # The indptr of the paged kv cache, shape: [batch_size + 1]
     paged_kv_indptr: Optional[torch.Tensor] = None
@@ -63,6 +63,10 @@ class AiterMLAMetadata(MLACommonMetadata):
     # the paged kv cache, shape: [batch_size]
     paged_kv_last_page_lens: Optional[torch.Tensor] = None

+    # This is just to make new AITER MLA API work
+    # -- MTP support is not added yet.
+    qo_indptr: Optional[torch.Tensor] = None
+
     @property
     def prefill_metadata(self):
         prefill_metadata = super().prefill_metadata
@@ -74,6 +78,7 @@ def prefill_metadata(self):
             prefill_metadata\
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             prefill_metadata.block_table_bound = self.block_table_bound
+            prefill_metadata.qo_indptr = self.qo_indptr

             # update the cache
             self._cached_prefill_metadata = self.__class__(
@@ -93,6 +98,7 @@ def decode_metadata(self):
             decode_metadata\
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             decode_metadata.block_table_bound = self.block_table_bound
+            decode_metadata.qo_indptr = self.qo_indptr

             # update the cache
             self._cached_decode_metadata = self.__class__(
@@ -136,6 +142,7 @@ def prepare(self):
         self.paged_kv_indptr: list[int] = [0]
         self.paged_kv_last_page_lens: list[int] = []
         self.total_blocks = 0
+        self.qo_indptr: list[int] = [0]

     def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
                        prefix_cache_hit: bool):
@@ -210,6 +217,7 @@ def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
         self.paged_kv_indices.extend(block_table[:block_table_bound])
         self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                     block_table_bound)
+        self.qo_indptr.append(self.qo_indptr[-1] + 1)

         last_page_len = seq_len % self.block_size
         if last_page_len == 0:
@@ -228,6 +236,8 @@ def build(self, seq_lens: list[int], query_lens: list[int],
             self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                         cuda_graph_pad_size)
             self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
+            last_qo_indptr = self.qo_indptr[-1]
+            self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)

         # For current version of AITER MLA
         if len(self.paged_kv_indptr) > 0:
@@ -247,16 +257,22 @@ def build(self, seq_lens: list[int], query_lens: list[int],
                 1,
                 device=device,
                 dtype=torch.int)
+
+            qo_indptr = torch.tensor(self.qo_indptr,
+                                     device=device,
+                                     dtype=torch.int)
         else:
             paged_kv_indices_tensor = None
             paged_kv_indptr_tensor = None
             paged_kv_last_page_lens_tensor = None
             block_table_bound_tensor = None
+            qo_indptr = None

         metadata.paged_kv_indptr = paged_kv_indptr_tensor
         metadata.paged_kv_indices = paged_kv_indices_tensor
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
         metadata.block_table_bound = block_table_bound_tensor
+        metadata.qo_indptr = qo_indptr

         return metadata

@@ -265,21 +281,25 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):

     @contextmanager
     def graph_capture(self, max_batch_size: int):
-        kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
-            max_batch_size=max_batch_size,
-            block_size=self.runner.block_size,
-            max_block_per_batch=self.runner.get_max_block_per_batch(),
-            device=self.runner.device)
+        kv_indices, kv_indptr, last_page_lens, qo_indptr = \
+            get_aiter_mla_metadata(
+                max_batch_size=max_batch_size,
+                block_size=self.runner.block_size,
+                max_block_per_batch=\
+                    self.runner.get_max_block_per_batch(),
+                device=self.runner.device)
         self._paged_kv_indices_tensor = kv_indices
         self._paged_kv_indptr_tensor = kv_indptr
         self._paged_kv_last_page_lens_tensor = last_page_lens
+        self._qo_indptr_tensor = qo_indptr

         with super().graph_capture(max_batch_size):
             yield

         del self._paged_kv_indices_tensor
         del self._paged_kv_indptr_tensor
         del self._paged_kv_last_page_lens_tensor
+        del self._qo_indptr_tensor

     def graph_capture_get_metadata_for_batch(
         self,
@@ -293,10 +313,12 @@ def graph_capture_get_metadata_for_batch(
         paged_kv_indices = self._paged_kv_indices_tensor
         paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
                                                                        batch_size]
+        qo_indptr = self._qo_indptr_tensor[:batch_size + 1]

         metadata.paged_kv_indptr = paged_kv_indptr
         metadata.paged_kv_indices = paged_kv_indices
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
+        metadata.qo_indptr = qo_indptr

         return metadata

@@ -313,6 +335,7 @@ def get_graph_input_buffers(self,
             input_buffers[
                 "paged_kv_last_page_lens"] = attn_metadata.\
                     decode_metadata.paged_kv_last_page_lens
+            input_buffers['qo_indptr'] = attn_metadata.qo_indptr

         return input_buffers

@@ -332,6 +355,8 @@ def prepare_graph_input_buffers(self,
             input_buffers["paged_kv_last_page_lens"].copy_(
                 attn_metadata.decode_metadata.paged_kv_last_page_lens,
                 non_blocking=True)
+            input_buffers["qo_indptr"].copy_(
+                attn_metadata.decode_metadata.qo_indptr, non_blocking=True)


 class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
@@ -372,11 +397,9 @@ def _flash_attn_varlen_diff_headdims(
             softmax_scale: float, return_softmax_lse: bool,
             **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
         output = self.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v,
-            softmax_scale=softmax_scale,
-            return_lse=return_softmax_lse,
+            q,
+            k,
+            v,
             **kwargs,
         )

@@ -396,7 +419,7 @@ def _forward_decode(
         B = q_nope.shape[0]

         q = torch.cat([q_nope, q_pe], dim=-1)
-        o = torch.zeros(B,
+        o = torch.empty(B,
                         self.num_heads,
                         self.kv_lora_rank,
                         dtype=q.dtype,
@@ -405,6 +428,8 @@ def _forward_decode(
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

         aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
+                             attn_metadata.qo_indptr,
+                             attn_metadata.max_query_len,
                              attn_metadata.paged_kv_indptr,
                              attn_metadata.paged_kv_indices,
                              attn_metadata.paged_kv_last_page_lens)
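
Note: the thread running through these hunks is the new qo_indptr tensor, a prefix sum of query tokens per sequence, analogous to paged_kv_indptr but on the query side. Since MTP is not supported yet, each decode sequence contributes exactly one query token, and padded CUDA-graph slots repeat the last offset so they contribute none. A standalone sketch of how the builder's list evolves (batch sizes illustrative):

    import torch

    qo_indptr = [0]                      # same seed value as in prepare()
    num_seqs, cuda_graph_pad_size = 3, 2

    # _update_paged_kv_tensors: one query token per decode sequence.
    for _ in range(num_seqs):
        qo_indptr.append(qo_indptr[-1] + 1)

    # build(): padded CUDA-graph slots repeat the last offset -> zero queries.
    qo_indptr.extend([qo_indptr[-1]] * cuda_graph_pad_size)

    print(torch.tensor(qo_indptr, dtype=torch.int))
    # tensor([0, 1, 2, 3, 3, 3], dtype=torch.int32)
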

vllm/attention/ops/rocm_aiter_mla.py

Lines changed: 6 additions & 1 deletion
@@ -17,14 +17,17 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
     paged_kv_last_page_lens = torch.full((max_batch_size, ),
                                          block_size,
                                          dtype=torch.int32)
-    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens
+    qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
+    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr


 def aiter_mla_decode_fwd(
     q: torch.Tensor,
     kv_buffer: torch.Tensor,
     o: torch.Tensor,
     sm_scale: float,
+    qo_indptr: torch.Tensor,
+    max_seqlen_qo: int,
     kv_indptr: Optional[torch.Tensor] = None,
     kv_indices: Optional[torch.Tensor] = None,
     kv_last_page_lens: Optional[torch.Tensor] = None,
@@ -35,8 +38,10 @@ def aiter_mla_decode_fwd(
     mla_decode_fwd(q,
                    kv_buffer.view(-1, 1, 1, q.shape[-1]),
                    o,
+                   qo_indptr,
                    kv_indptr,
                    kv_indices,
                    kv_last_page_lens,
+                   max_seqlen_qo,
                    sm_scale=sm_scale,
                    logit_cap=logit_cap)
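
Note: callers now pass two extra arguments, qo_indptr and max_seqlen_qo, between sm_scale and the paged-KV tensors. A hedged usage sketch of the updated wrapper; the shapes, sizes, and "cuda" device are illustrative, and running it assumes an AITER-enabled ROCm build of vLLM:

    import torch
    from vllm.attention.ops.rocm_aiter_mla import (aiter_mla_decode_fwd,
                                                   get_aiter_mla_metadata)

    batch, num_heads, head_dim, block_size, max_blocks = 2, 16, 576, 1, 8

    q = torch.randn(batch, num_heads, head_dim, device="cuda")
    kv_buffer = torch.randn(max_blocks, block_size, 1, head_dim, device="cuda")
    o = torch.empty(batch, num_heads, 512, device="cuda")  # kv_lora_rank = 512

    kv_indices, kv_indptr, last_page_lens, qo_indptr = get_aiter_mla_metadata(
        max_batch_size=batch, block_size=block_size,
        max_block_per_batch=max_blocks, device="cuda")

    aiter_mla_decode_fwd(q, kv_buffer, o, 0.125,   # sm_scale
                         qo_indptr, 1,             # new: qo_indptr, max_seqlen_qo
                         kv_indptr, kv_indices, last_page_lens)
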

vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 5 additions & 4 deletions
@@ -129,10 +129,11 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(

     fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids,
                              sorted_weight_buf, sorted_expert_ids,
-                             num_valid_ids, topk, w1_scale.view(local_E, -1),
-                             w2_scale.view(local_E, -1),
-                             a1_scale.t().contiguous(), *block_shape,
-                             smooth_scale)
+                             num_valid_ids, topk,
+                             a1_scale.t().contiguous(),
+                             w1_scale.view(local_E, -1),
+                             w2_scale.view(local_E,
+                                           -1), *block_shape, smooth_scale)

     return out_asm
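
Note: the MoE fix reorders the scale arguments so the activation scale a1_scale now precedes the weight scales, tracking a signature change in AITER's fmoe_fp8_blockscale_g1u1. Because the call is positional, a reordered kernel signature silently misbinds tensors unless every call site moves with it. A toy illustration with hypothetical names:

    # Stand-ins for the old and new kernel signatures (names hypothetical).
    def kernel_old(num_valid_ids, topk, w1_scale, w2_scale, a1_scale):
        return {"w1": w1_scale, "w2": w2_scale, "a1": a1_scale}

    def kernel_new(num_valid_ids, topk, a1_scale, w1_scale, w2_scale):
        return {"w1": w1_scale, "w2": w2_scale, "a1": a1_scale}

    args = (8, 2, "w1_scale", "w2_scale", "a1_scale")  # old positional order
    print(kernel_old(*args))  # every scale lands where intended
    print(kernel_new(*args))  # same call now binds a1_scale to w1_scale, etc.
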
