diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py
index 376845d9b40e..ad7376bde644 100644
--- a/vllm/attention/backends/mla/common.py
+++ b/vllm/attention/backends/mla/common.py
@@ -1245,9 +1245,9 @@ def _compute_prefill_context(
 
             attn_output, attn_softmax_lse = \
                 self._flash_attn_varlen_diff_headdims(
-                    q=q,
-                    k=k,
-                    v=v,
+                    q,
+                    k,
+                    v,
                     cu_seqlens_q=prefill_metadata.query_start_loc,
                     cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
                     max_seqlen_q=prefill_metadata.max_query_len,
@@ -1299,9 +1299,9 @@ def _forward_prefill(
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
         output = self._flash_attn_varlen_diff_headdims(
-            q=q,
-            k=k,
-            v=v,
+            q,
+            k,
+            v,
             cu_seqlens_q=prefill_metadata.query_start_loc,
             cu_seqlens_k=prefill_metadata.query_start_loc,
             max_seqlen_q=prefill_metadata.max_prefill_seq_len,
diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py
index 2984bc1dad64..3d728a0f8ccc 100644
--- a/vllm/attention/backends/rocm_aiter_mla.py
+++ b/vllm/attention/backends/rocm_aiter_mla.py
@@ -53,7 +53,7 @@ def get_state_cls() -> Type["AiterMLAState"]:
 
 @dataclass
 class AiterMLAMetadata(MLACommonMetadata):
-    # The following 4 tensors are for current version of AITER MLA
+    # The following 5 tensors are for current version of AITER MLA
     block_table_bound: Optional[torch.Tensor] = None
     # The indptr of the paged kv cache, shape: [batch_size + 1]
     paged_kv_indptr: Optional[torch.Tensor] = None
@@ -63,6 +63,10 @@ class AiterMLAMetadata(MLACommonMetadata):
     # the paged kv cache, shape: [batch_size]
     paged_kv_last_page_lens: Optional[torch.Tensor] = None
 
+    # This is for new AITER MLA API to work
+    # -- MTP support needs more changes.
+    qo_indptr: Optional[torch.Tensor] = None
+
     @property
     def prefill_metadata(self):
         prefill_metadata = super().prefill_metadata
@@ -74,6 +78,7 @@ def prefill_metadata(self):
             prefill_metadata\
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             prefill_metadata.block_table_bound = self.block_table_bound
+            prefill_metadata.qo_indptr = self.qo_indptr
 
             # update the cache
             self._cached_prefill_metadata = self.__class__(
@@ -93,6 +98,7 @@ def decode_metadata(self):
             decode_metadata\
                 .paged_kv_last_page_lens = self.paged_kv_last_page_lens
             decode_metadata.block_table_bound = self.block_table_bound
+            decode_metadata.qo_indptr = self.qo_indptr
 
             # update the cache
             self._cached_decode_metadata = self.__class__(
@@ -136,6 +142,7 @@ def prepare(self):
         self.paged_kv_indptr: list[int] = [0]
         self.paged_kv_last_page_lens: list[int] = []
         self.total_blocks = 0
+        self.qo_indptr: list[int] = [0]
 
     def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
                        prefix_cache_hit: bool):
@@ -210,6 +217,7 @@ def _update_paged_kv_tensors(self, block_table: list[int], seq_len: int):
         self.paged_kv_indices.extend(block_table[:block_table_bound])
         self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
                                     block_table_bound)
+        self.qo_indptr.append(self.qo_indptr[-1] + 1)
 
         last_page_len = seq_len % self.block_size
         if last_page_len == 0:
@@ -228,6 +236,8 @@ def build(self, seq_lens: list[int], query_lens: list[int],
             self.paged_kv_indptr.extend([last_paged_kv_indptr] *
                                         cuda_graph_pad_size)
             self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
+            last_qo_indptr = self.qo_indptr[-1]
+            self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)
 
         # For current version of AITER MLA
         if len(self.paged_kv_indptr) > 0:
@@ -247,16 +257,32 @@ def build(self, seq_lens: list[int], query_lens: list[int],
                                                    1,
                                                    device=device,
                                                    dtype=torch.int)
+
+            # This is hardcoded -- MTP is disabled
+            #num_draft_tokens = 1
+            #qo_indptr = torch.arange(
+            #0,
+            #(1 + batch_size) * num_draft_tokens,
+            #step = num_draft_tokens,
+            #dtype = torch.int,
+            #device = device
+            #)
+            qo_indptr = torch.tensor(self.qo_indptr,
+                                     device=device,
+                                     dtype=torch.int)
+
         else:
             paged_kv_indices_tensor = None
             paged_kv_indptr_tensor = None
             paged_kv_last_page_lens_tensor = None
             block_table_bound_tensor = None
+            qo_indptr = None
 
         metadata.paged_kv_indptr = paged_kv_indptr_tensor
         metadata.paged_kv_indices = paged_kv_indices_tensor
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
         metadata.block_table_bound = block_table_bound_tensor
+        metadata.qo_indptr = qo_indptr
 
         return metadata
@@ -265,7 +291,7 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
 
     @contextmanager
     def graph_capture(self, max_batch_size: int):
-        kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
+        kv_indices, kv_indptr, last_page_lens, qo_indptr = get_aiter_mla_metadata(
             max_batch_size=max_batch_size,
             block_size=self.runner.block_size,
             max_block_per_batch=self.runner.get_max_block_per_batch(),
@@ -273,6 +299,7 @@ def graph_capture(self, max_batch_size: int):
         self._paged_kv_indices_tensor = kv_indices
         self._paged_kv_indptr_tensor = kv_indptr
         self._paged_kv_last_page_lens_tensor = last_page_lens
+        self._qo_indptr_tensor = qo_indptr
 
         with super().graph_capture(max_batch_size):
             yield
@@ -280,6 +307,7 @@ def graph_capture(self, max_batch_size: int):
         del self._paged_kv_indices_tensor
         del self._paged_kv_indptr_tensor
         del self._paged_kv_last_page_lens_tensor
+        del self._qo_indptr_tensor
 
     def graph_capture_get_metadata_for_batch(
             self,
@@ -293,10 +321,12 @@ graph_capture_get_metadata_for_batch(
         paged_kv_indices = self._paged_kv_indices_tensor
         paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
                                                                        batch_size]
+        qo_indptr = self._qo_indptr_tensor[:batch_size + 1]
 
         metadata.paged_kv_indptr = paged_kv_indptr
         metadata.paged_kv_indices = paged_kv_indices
         metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
+        metadata.qo_indptr = qo_indptr
 
         return metadata
@@ -313,6 +343,7 @@ def get_graph_input_buffers(self,
             input_buffers[
                 "paged_kv_last_page_lens"] = attn_metadata.\
                     decode_metadata.paged_kv_last_page_lens
+            input_buffers['qo_indptr'] = attn_metadata.qo_indptr
 
         return input_buffers
@@ -332,6 +363,8 @@ def prepare_graph_input_buffers(self,
             input_buffers["paged_kv_last_page_lens"].copy_(
                 attn_metadata.decode_metadata.paged_kv_last_page_lens,
                 non_blocking=True)
+            input_buffers["qo_indptr"].copy_(
+                attn_metadata.decode_metadata.qo_indptr, non_blocking=True)
 
 
 class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
@@ -372,12 +405,10 @@ def _flash_attn_varlen_diff_headdims(
             softmax_scale: float,
             return_softmax_lse: bool,
             **kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
         output = self.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v,
-            softmax_scale=softmax_scale,
-            return_lse=return_softmax_lse,
-            **kwargs,
+            q,
+            k,
+            v,
+            **kwargs
         )
         return output
@@ -396,7 +427,7 @@ def _forward_decode(
         B = q_nope.shape[0]
 
         q = torch.cat([q_nope, q_pe], dim=-1)
-        o = torch.zeros(B,
+        o = torch.empty(B,
                         self.num_heads,
                         self.kv_lora_rank,
                         dtype=q.dtype,
@@ -404,9 +435,15 @@ def _forward_decode(
 
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
 
-        aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
-                             attn_metadata.paged_kv_indptr,
-                             attn_metadata.paged_kv_indices,
-                             attn_metadata.paged_kv_last_page_lens)
+        aiter_mla_decode_fwd(
+            q,
+            kv_buffer,
+            o,
+            self.scale,
+            attn_metadata.qo_indptr,
+            attn_metadata.max_query_len,
+            attn_metadata.paged_kv_indptr,
+            attn_metadata.paged_kv_indices,
+            attn_metadata.paged_kv_last_page_lens)
 
         return self._v_up_proj(o)
diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py
index 1c90f8c19b09..1a4cb6333bff 100644
--- a/vllm/attention/ops/rocm_aiter_mla.py
+++ b/vllm/attention/ops/rocm_aiter_mla.py
@@ -17,7 +17,10 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
     paged_kv_last_page_lens = torch.full((max_batch_size, ),
                                          block_size,
                                          dtype=torch.int32)
-    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens
+    qo_indptr = torch.zeros(max_batch_size + 1,
+                            dtype=torch.int,
+                            device=device)
+    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
 
 
 def aiter_mla_decode_fwd(
@@ -25,6 +28,8 @@ def aiter_mla_decode_fwd(
     kv_buffer: torch.Tensor,
     o: torch.Tensor,
     sm_scale: float,
+    qo_indptr: torch.Tensor,
+    max_seqlen_qo: int,
     kv_indptr: Optional[torch.Tensor] = None,
     kv_indices: Optional[torch.Tensor] = None,
     kv_last_page_lens: Optional[torch.Tensor] = None,
@@ -35,8 +40,10 @@ def aiter_mla_decode_fwd(
     mla_decode_fwd(q,
                    kv_buffer.view(-1, 1, 1, q.shape[-1]),
                    o,
+                   qo_indptr,
                    kv_indptr,
                    kv_indices,
                    kv_last_page_lens,
+                   max_seqlen_qo,
                    sm_scale=sm_scale,
                    logit_cap=logit_cap)
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index 2acb1bd69e43..d3d8245870b0 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -129,9 +129,10 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
 
     fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids,
                              sorted_weight_buf, sorted_expert_ids,
-                             num_valid_ids, topk, w1_scale.view(local_E, -1),
+                             num_valid_ids, topk, a1_scale.t().contiguous(),
+                             w1_scale.view(local_E, -1),
                              w2_scale.view(local_E, -1),
-                             a1_scale.t().contiguous(), *block_shape,
+                             *block_shape,
                              smooth_scale)
 
     return out_asm
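
Note (illustration only, not part of the patch): the new qo_indptr tensor follows the usual "indptr" convention, i.e. qo_indptr[i + 1] - qo_indptr[i] is the number of query tokens sequence i contributes. With MTP disabled every decode sequence contributes exactly one token, so the metadata builder above produces 0..batch_size, and CUDA-graph padding repeats the last offset so padded slots span zero tokens. A minimal sketch of that layout (build_qo_indptr is a hypothetical helper, not a vLLM or AITER API):

    import torch

    def build_qo_indptr(num_seqs: int, cuda_graph_pad_size: int = 0) -> torch.Tensor:
        # One query token per decode sequence -> the cumulative sum is 0..num_seqs.
        qo_indptr = [0]
        for _ in range(num_seqs):
            qo_indptr.append(qo_indptr[-1] + 1)
        # Padded slots reuse the last offset, so they span zero query tokens.
        qo_indptr.extend([qo_indptr[-1]] * cuda_graph_pad_size)
        return torch.tensor(qo_indptr, dtype=torch.int)

    print(build_qo_indptr(4))     # tensor([0, 1, 2, 3, 4], dtype=torch.int32)
    print(build_qo_indptr(4, 2))  # tensor([0, 1, 2, 3, 4, 4, 4], dtype=torch.int32)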