From 0bba64d3c45aaefeec5bc9549db2e9e15f74e69c Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Sun, 15 Jun 2025 21:49:47 -0700 Subject: [PATCH 01/13] [V1] Perf optimization for layers with KV reuse Signed-off-by: Yong Hoon Shin --- tests/v1/e2e/test_kv_sharing_skip_prefill.py | 235 +++++++++++++++++++ vllm/envs.py | 4 + vllm/forward_context.py | 3 + vllm/model_executor/models/qwen2.py | 4 +- vllm/v1/attention/backends/cpu_attn.py | 19 +- vllm/v1/attention/backends/flash_attn.py | 29 ++- vllm/v1/attention/backends/flashinfer.py | 13 +- vllm/v1/attention/backends/flex_attention.py | 13 +- vllm/v1/attention/backends/mla/common.py | 12 +- vllm/v1/attention/backends/utils.py | 61 ++++- vllm/v1/worker/gpu_model_runner.py | 87 ++++--- 11 files changed, 432 insertions(+), 48 deletions(-) create mode 100644 tests/v1/e2e/test_kv_sharing_skip_prefill.py diff --git a/tests/v1/e2e/test_kv_sharing_skip_prefill.py b/tests/v1/e2e/test_kv_sharing_skip_prefill.py new file mode 100644 index 00000000000..8948627950c --- /dev/null +++ b/tests/v1/e2e/test_kv_sharing_skip_prefill.py @@ -0,0 +1,235 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import gc +from collections.abc import Iterable +from typing import Optional, Union + +import pytest +import torch +from torch import nn +from transformers import Qwen2Config + +from vllm import LLM, SamplingParams +from vllm.config import CacheConfig, VllmConfig +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.qwen2 import (Qwen2Attention, Qwen2MLP, + Qwen2Model) +from vllm.model_executor.models.registry import ModelRegistry +from vllm.model_executor.models.utils import (AutoWeightsLoader, + extract_layer_index, + maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from ...utils import fork_new_process_for_each_test + +START_KV_SHARING_LAYER = 10 + + +class Qwen2DecoderLayerWithKVSharing(nn.Module): + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + attn_prefix = f"{prefix}.self_attn" + layer_idx = extract_layer_index(prefix) + kv_sharing_target_layer_name = None + + if layer_idx >= START_KV_SHARING_LAYER: + # re-use KV cache from first 5 layers + target_layer_idx = layer_idx % 5 + kv_sharing_target_layer_name = f"{attn_prefix}.attn".replace( + str(layer_idx), str(target_layer_idx)) + self.self_attn = Qwen2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=attn_prefix, + kv_sharing_target_layer_name=kv_sharing_target_layer_name, + ) + + self.mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = 
RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Qwen2ModelWithKVSharing(Qwen2Model): + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + + decode_indices = get_forward_context().decode_indices + if decode_indices is None: + decode_indices = torch.arange(positions.size(0), + device=positions.device) + + # Forward with full inputs up to the first layer that shares KV cache + for layer in self.layers[self.start_layer:START_KV_SHARING_LAYER]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if decode_indices is not None: + decode_hidden_states = hidden_states[decode_indices] + decode_positions = positions[decode_indices] + decode_residual = (residual[decode_indices] + if residual is not None else None) + else: + decode_hidden_states = hidden_states + decode_positions = positions + decode_residual = residual + + # Optimization: forward with partial inputs only for last N layers + for layer in self.layers[START_KV_SHARING_LAYER:self.end_layer]: + decode_hidden_states, decode_residual = layer( + decode_positions, + decode_hidden_states, + decode_residual, + ) + + # Merge results back + if decode_hidden_states is not None: + hidden_states[decode_indices] = decode_hidden_states + if residual is not None: + residual[decode_indices] = decode_residual + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class TestQwen2ForCausalLM(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2ModelWithKVSharing( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + decoder_layer_type=Qwen2DecoderLayerWithKVSharing) + self.lm_head = self.model.embed_tokens + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + 
sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + +# TODO: make it work with torch.compile +@fork_new_process_for_each_test +@pytest.mark.parametrize("enforce_eager", [True]) +def test_kv_sharing_skip_prefill(monkeypatch, enforce_eager): + prompt = "What is the capital of France?" + ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM) + sampling_params = SamplingParams(temperature=0.0, max_tokens=40) + single_prompt = [prompt] + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", + enforce_eager=enforce_eager) + responses = llm.generate(single_prompt, sampling_params) + ref_output = responses[0].outputs[0].text + + del llm + gc.collect() + torch.cuda.empty_cache() + + m.setenv("VLLM_V1_KV_SHARING_SKIP_PREFILL", "1") + + llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", + enforce_eager=enforce_eager) + responses = llm.generate(single_prompt, sampling_params) + output = responses[0].outputs[0].text + assert output == ref_output diff --git a/vllm/envs.py b/vllm/envs.py index 7bff6ade815..ab0184174e4 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -139,6 +139,7 @@ VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120 + VLLM_V1_KV_SHARING_SKIP_PREFILL: bool = False def get_default_cache_root(): @@ -964,6 +965,9 @@ def get_vllm_port() -> Optional[int]: # If set to 1, use the TRTLLM Decode Attention backend in flashinfer. "VLLM_USE_TRTLLM_DECODE_ATTENTION": lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None), + + "VLLM_V1_KV_SHARING_SKIP_PREFILL": + lambda: os.environ.get("VLLM_V1_KV_SHARING_SKIP_PREFILL", "0") == "1", } # --8<-- [end:env-vars-definition] diff --git a/vllm/forward_context.py b/vllm/forward_context.py index dd55b19feea..22a751f6cc8 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -95,6 +95,7 @@ class ForwardContext: # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None skip_cuda_graphs: bool = False + decode_indices: Optional[torch.Tensor] = None _forward_context: Optional[ForwardContext] = None @@ -116,6 +117,7 @@ def set_forward_context( num_tokens: Optional[int] = None, num_tokens_across_dp: Optional[torch.Tensor] = None, skip_cuda_graphs: bool = False, + decode_indices: Optional[torch.Tensor] = None, ): """A context manager that stores the current forward context, can be attention metadata, etc. 
@@ -141,6 +143,7 @@ def set_forward_context( attn_metadata=attn_metadata, dp_metadata=dp_metadata, skip_cuda_graphs=skip_cuda_graphs, + decode_indices=decode_indices, ) try: diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 7ef9d248da4..e4b5d674ff6 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -110,6 +110,7 @@ def __init__( prefix: str = "", attn_type: str = AttentionType.DECODER, dual_chunk_attention_config: Optional[dict[str, Any]] = None, + **attn_kwargs, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -171,7 +172,8 @@ def __init__( **{ "layer_idx": extract_layer_index(prefix), "dual_chunk_attention_config": dual_chunk_attention_config, - } if dual_chunk_attention_config else {}) + } if dual_chunk_attention_config else {}, + **attn_kwargs) def forward( self, diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index d6270fbf319..d89f115eeca 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -374,11 +374,22 @@ def reorder_batch(self, input_batch: InputBatch, return True - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + decode_only_common_attn_metadata: Optional[ + CommonAttentionMetadata] = None, + ): + if decode_only_common_attn_metadata is not None: + raise NotImplementedError( + "CPU backend does not support decode-only attention yet.") num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len + query_start_loc_np = (common_attn_metadata.query_start_loc_np + if common_attn_metadata.query_start_loc_np + is not None else self.runner.query_start_loc_np) runner = self.runner block_table = self.block_table @@ -390,8 +401,8 @@ def build(self, common_prefix_len: int, ) if num_prompt_req < num_reqs else 0 self.seq_start_loc_np[0] = 0 np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1]) - num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item() - num_decode_tokens = runner.query_start_loc_np[num_reqs].item( + num_prefill_tokens = query_start_loc_np[num_prompt_req].item() + num_decode_tokens = query_start_loc_np[num_reqs].item( ) - num_prefill_tokens slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long() block_table_tensor = block_table.get_device_tensor() diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index fbc13c06c65..97a7a71f679 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -142,6 +142,8 @@ class LocalAttentionMetadata: local_attn_metadata: Optional[LocalAttentionMetadata] = None + decode_only_attn_metadata: Optional["FlashAttentionMetadata"] = None + def _get_sliding_window_configs( vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]: @@ -208,9 +210,19 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, self.aot_sliding_window: Optional[tuple[int, int]] = None def build( - self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + decode_only_common_attn_metadata: Optional[ + CommonAttentionMetadata] = None, ) -> FlashAttentionMetadata: + decode_only_attn_metadata = None + if 
decode_only_common_attn_metadata is not None: + decode_only_attn_metadata = self.build( + common_prefix_len=0, # disable cascade attention + common_attn_metadata=decode_only_common_attn_metadata, + ) + num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ -268,10 +280,15 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, # for local attention local_attn_metadata = None if self.runner.attention_chunk_size is not None: + query_start_loc_np = (common_attn_metadata.query_start_loc_np + if common_attn_metadata.query_start_loc_np + is not None else + self.runner.query_start_loc_np[:num_reqs + + 1]) seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ virt_block_table_tensor = make_local_attention_virtual_batches( self.runner.attention_chunk_size, - self.runner.query_start_loc_np[:num_reqs + 1], + query_start_loc_np, self.runner.seq_lens_np[:num_reqs], block_table_tensor, self.block_size, @@ -375,6 +392,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, local_attn_metadata=local_attn_metadata, prefix_scheduler_metadata=prefix_scheduler_metadata, max_num_splits=max_num_splits, + decode_only_attn_metadata=decode_only_attn_metadata, ) return attn_metadata @@ -477,6 +495,11 @@ def forward( # Profiling run. return output + if (self.kv_sharing_target_layer_name is not None + and attn_metadata.decode_only_attn_metadata is not None): + # Override with decode-only attention metadata + attn_metadata = attn_metadata.decode_only_attn_metadata + # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 4ae595c976b..b45ef8cab37 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -423,8 +423,17 @@ def _plan(self, attn_metadata: FlashInferMetadata): kv_data_type=attn_metadata.kv_data_type, ) - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + decode_only_common_attn_metadata: Optional[ + CommonAttentionMetadata] = None, + ): + if decode_only_common_attn_metadata is not None: + raise NotImplementedError( + "FlashInfer backend does not support decode-only attention yet." 
+ ) num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index a8c5f464aa3..08cda5f1e58 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -271,8 +271,17 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, self.kv_cache_spec = kv_cache_spec self.block_table = block_table - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + decode_only_common_attn_metadata: Optional[ + CommonAttentionMetadata] = None, + ): + if decode_only_common_attn_metadata is not None: + raise NotImplementedError( + "FlexAttention backend does not support decode-only " + "attention yet.") num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 970de229e13..2b74d6edfac 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -604,8 +604,16 @@ def build_for_cudagraph_capture( self._num_prefill_tokens = 0 return self.build(0, m) - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata) -> M: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + decode_only_common_attn_metadata: Optional[ + CommonAttentionMetadata] = None, + ) -> M: + if decode_only_common_attn_metadata is not None: + raise NotImplementedError( + "MLA backend does not support decode-only attention yet.") num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 88adc32406e..6b94e429271 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -47,6 +47,9 @@ class CommonAttentionMetadata: max_query_len: int """Longest query in batch""" + query_start_loc_np: Optional[np.ndarray] = None + """(batch_size + 1,), cpu numpy version of query_start_loc""" + M = TypeVar("M") @@ -56,8 +59,13 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): full_cudagraph_supported: ClassVar[bool] = False @abstractmethod - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata) -> M: + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + decode_only_common_attn_metadata: Optional[ + CommonAttentionMetadata] = None, + ) -> M: """ Central method that builds attention metadata. Some builders (MLA) require reorder_batch to be called prior to build. @@ -351,3 +359,52 @@ def make_local_attention_virtual_batches( return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \ block_table_local + +def compute_decode_only_common_attn_metadata( + num_reqs: int, + decode_indices: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens: torch.Tensor, +): + """ + Compute new query related attention data only considering + token positions corresponding to decode_indices. + + Used to skip tokens during prefill for some Attention layers + that re-use KV cache from earlier layers in the model. 
+ """ + # Inputs: + # decode_indices: [14, 18, 19, 27] + # query_start_loc: [0, 15, 20, 28] + # seq_lens: [41, 31, 40] + + # Find how many decode indices belong to each request + # request_ids: [0, 1, 1, 2] + request_ids = torch.bucketize(decode_indices, + query_start_loc[1:], + right=True) + + # Figure out how many tokens are in each request + # num_decode_tokens: [1, 2, 1] + num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) + + # Calculate new query_start_loc only considering tokens in decode_indices + # decode_query_start_loc: [0, 1, 3, 4] + decode_query_start_loc = torch.empty(num_reqs + 1, + device=query_start_loc.device, + dtype=query_start_loc.dtype) + decode_query_start_loc[0] = 0 + decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) + decode_max_query_len = num_decode_tokens.max().item() + total_num_decode_tokens = num_decode_tokens.sum().item() + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=decode_query_start_loc, + # TODO(sarckk): optimize + query_start_loc_np=decode_query_start_loc.cpu().numpy(), + seq_lens=seq_lens, + num_reqs=num_reqs, + num_actual_tokens=total_num_decode_tokens, + max_query_len=decode_max_query_len, + ) + return common_attn_metadata diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f3279fa5fa8..5da229e85dc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -47,7 +47,8 @@ is_pin_memory_available, round_up) from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) + CommonAttentionMetadata, + compute_decode_only_common_attn_metadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, MambaSpec, @@ -696,16 +697,57 @@ def _prepare_inputs( self.query_start_loc_cpu[num_reqs].item()) query_start_loc = self.query_start_loc[:num_reqs + 1] + query_start_loc_np = self.query_start_loc_np[:num_reqs + 1] seq_lens = self.seq_lens[:num_reqs] common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, + query_start_loc_np=query_start_loc_np, seq_lens=seq_lens, num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, ) + use_spec_decode = len( + scheduler_output.scheduled_spec_decode_tokens) > 0 + if not use_spec_decode: + # NOTE(woosuk): Due to chunked prefills, the batch may contain + # partial requests. While we should not sample any token + # from these partial requests, we do so for simplicity. + # We will ignore the sampled tokens from the partial requests. + # TODO: Support prompt logprobs. + logits_indices = query_start_loc[1:] - 1 + spec_decode_metadata = None + else: + # Get the number of draft tokens for each request. + # Iterate over the dictionary rather than all requests since not all + # requests have draft tokens. 
+ num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) + for req_id, draft_token_ids in ( + scheduler_output.scheduled_spec_decode_tokens.items()): + req_idx = self.input_batch.req_id_to_index[req_id] + num_draft_tokens[req_idx] = len(draft_token_ids) + + spec_decode_metadata = self._calc_spec_decode_metadata( + num_draft_tokens, cu_num_tokens) + logits_indices = spec_decode_metadata.logits_indices + + decode_only_common_attn_metadata = None + if envs.VLLM_V1_KV_SHARING_SKIP_PREFILL: + decode_only_common_attn_metadata = ( + compute_decode_only_common_attn_metadata( + num_reqs=num_reqs, + # TODO(sarckk): logits_indices contains tokens for partial + # prefill requests, so we can optimize further by only + # considering tokens if its index is more than or equal to + # input_batch.num_prompt_tokens[req_index], which correspond + # to positions that are required for sampling output tokens + decode_indices=logits_indices, + query_start_loc=query_start_loc, + seq_lens=seq_lens, + )) + attn_metadata: dict[str, Any] = {} # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. @@ -727,6 +769,8 @@ def _prepare_inputs( attn_metadata_i = (builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, + decode_only_common_attn_metadata= + decode_only_common_attn_metadata, )) for layer_name in kv_cache_group_spec.layer_names: @@ -736,30 +780,6 @@ def _prepare_inputs( b.can_run_in_cudagraph(common_attn_metadata) for b in self.attn_metadata_builders) - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 - if not use_spec_decode: - # NOTE(woosuk): Due to chunked prefills, the batch may contain - # partial requests. While we should not sample any token - # from these partial requests, we do so for simplicity. - # We will ignore the sampled tokens from the partial requests. - # TODO: Support prompt logprobs. - logits_indices = query_start_loc[1:] - 1 - spec_decode_metadata = None - else: - # Get the number of draft tokens for each request. - # Iterate over the dictionary rather than all requests since not all - # requests have draft tokens. - num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) - for req_id, draft_token_ids in ( - scheduler_output.scheduled_spec_decode_tokens.items()): - req_idx = self.input_batch.req_id_to_index[req_id] - num_draft_tokens[req_idx] = len(draft_token_ids) - - spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens) - logits_indices = spec_decode_metadata.logits_indices - # Hot-Swap lora model if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) @@ -1358,13 +1378,14 @@ def execute_model( # Run the model. # Use persistent buffers for CUDA graphs. 
- with set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - skip_cuda_graphs=skip_cuda_graphs, - ): + decode_indices = (logits_indices + if envs.VLLM_V1_KV_SHARING_SKIP_PREFILL else None) + with set_forward_context(attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + skip_cuda_graphs=skip_cuda_graphs, + decode_indices=decode_indices): self.maybe_setup_kv_connector(scheduler_output) model_output = self.model( @@ -1961,6 +1982,7 @@ def _dummy_run( attn_metadata = {} query_start_loc = self.query_start_loc[:num_reqs + 1] + query_start_loc_np = self.query_start_loc_np[:num_reqs + 1] # Make sure max_model_len is used at the graph capture time. self.seq_lens_np[:num_reqs] = self.max_model_len self.seq_lens_np[num_reqs:] = 0 @@ -1970,6 +1992,7 @@ def _dummy_run( common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, + query_start_loc_np=query_start_loc_np, seq_lens=seq_lens, num_reqs=num_reqs, num_actual_tokens=num_tokens, From dba8e8685f0ca79adcbda22d8fad97ae05f1cc45 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Tue, 1 Jul 2025 19:07:56 -0700 Subject: [PATCH 02/13] Add piecewise cudagraph support + refactor Signed-off-by: Yong Hoon Shin --- tests/v1/e2e/test_kv_sharing_skip_prefill.py | 144 ++++++++++++++----- vllm/compilation/backends.py | 7 +- vllm/compilation/decorators.py | 9 +- vllm/config.py | 2 + vllm/forward_context.py | 2 + vllm/v1/attention/backends/cpu_attn.py | 19 +-- vllm/v1/attention/backends/flash_attn.py | 86 ++++++++--- vllm/v1/attention/backends/flashinfer.py | 13 +- vllm/v1/attention/backends/flex_attention.py | 13 +- vllm/v1/attention/backends/mla/common.py | 12 +- vllm/v1/attention/backends/utils.py | 57 +------- vllm/v1/worker/gpu_model_runner.py | 94 +++++++----- 12 files changed, 266 insertions(+), 192 deletions(-) diff --git a/tests/v1/e2e/test_kv_sharing_skip_prefill.py b/tests/v1/e2e/test_kv_sharing_skip_prefill.py index 8948627950c..8364d8d8c90 100644 --- a/tests/v1/e2e/test_kv_sharing_skip_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_skip_prefill.py @@ -3,7 +3,7 @@ import gc from collections.abc import Iterable -from typing import Optional, Union +from typing import List, Optional, Union import pytest import torch @@ -11,7 +11,11 @@ from transformers import Qwen2Config from vllm import LLM, SamplingParams -from vllm.config import CacheConfig, VllmConfig +from vllm.compilation.backends import set_model_tag +from vllm.compilation.decorators import (skip_torch_compile, + support_torch_compile) +from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, + VllmConfig) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -52,6 +56,7 @@ def __init__( target_layer_idx = layer_idx % 5 kv_sharing_target_layer_name = f"{attn_prefix}.attn".replace( str(layer_idx), str(target_layer_idx)) + self.self_attn = Qwen2Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -99,8 +104,72 @@ def forward( return hidden_states, residual +@support_torch_compile +class DecoderLayerGroup(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layers: List[nn.Module], + ): + super().__init__() + self.layers = layers + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: 
Optional[torch.Tensor] = None, + ): + for layer in self.layers: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + return hidden_states, residual + + +@skip_torch_compile class Qwen2ModelWithKVSharing(Qwen2Model): + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[ + nn.Module] = Qwen2DecoderLayerWithKVSharing): + super().__init__( + vllm_config=vllm_config, + prefix=prefix, + decoder_layer_type=decoder_layer_type, + ) + + with set_model_tag("first_layer_group"): + self.first_layer_group = DecoderLayerGroup( + vllm_config=vllm_config, + prefix=f"{prefix}.first_layer_group", + layers=self.layers[self.start_layer:START_KV_SHARING_LAYER], + ) + + with set_model_tag("second_layer_group"): + self.second_layer_group = DecoderLayerGroup( + vllm_config=vllm_config, + prefix=f"{prefix}.second_layer_group", + layers=self.layers[START_KV_SHARING_LAYER:self.end_layer], + ) + + # Pre-allocate static buffers for CUDA graph + self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + self.dtype = vllm_config.model_config.dtype + self.device = next(self.parameters()).device + self.hidden_size = vllm_config.model_config.get_hidden_size() + self.residual = torch.zeros((self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device) + def forward( self, input_ids: torch.Tensor, @@ -112,46 +181,40 @@ def forward( hidden_states = inputs_embeds else: hidden_states = self.get_input_embeddings(input_ids) + residual = None + first_hidden_states, first_residual = self.first_layer_group( + positions, + hidden_states, + residual, # no residual, assume no pipeline parallel + ) decode_indices = get_forward_context().decode_indices if decode_indices is None: decode_indices = torch.arange(positions.size(0), device=positions.device) - - # Forward with full inputs up to the first layer that shares KV cache - for layer in self.layers[self.start_layer:START_KV_SHARING_LAYER]: - hidden_states, residual = layer( - positions, - hidden_states, - residual, - ) - - if decode_indices is not None: - decode_hidden_states = hidden_states[decode_indices] - decode_positions = positions[decode_indices] - decode_residual = (residual[decode_indices] - if residual is not None else None) - else: - decode_hidden_states = hidden_states - decode_positions = positions - decode_residual = residual - - # Optimization: forward with partial inputs only for last N layers - for layer in self.layers[START_KV_SHARING_LAYER:self.end_layer]: - decode_hidden_states, decode_residual = layer( - decode_positions, - decode_hidden_states, - decode_residual, - ) + num_decodes = decode_indices.shape[0] + assert num_decodes >= 1 + assert first_residual is not None + + # CUDA graph expects static tensor addresses + # Copy output of first layer group to second layer group + self.residual[:num_decodes].copy_(first_residual[decode_indices]) + hidden_states[:num_decodes].copy_(first_hidden_states[decode_indices]) + positions[:num_decodes].copy_(positions[decode_indices]) + + second_hidden_states, second_residual = self.second_layer_group( + positions[:num_decodes], + hidden_states[:num_decodes], + self.residual[:num_decodes], + ) # Merge results back - if decode_hidden_states is not None: - hidden_states[decode_indices] = decode_hidden_states - if residual is not None: - residual[decode_indices] = decode_residual + first_hidden_states[decode_indices] = second_hidden_states + if first_residual is not None: + first_residual[decode_indices] = second_residual 
- hidden_states, _ = self.norm(hidden_states, residual) + hidden_states, _ = self.norm(first_hidden_states, first_residual) return hidden_states @@ -205,20 +268,24 @@ def load_weights(self, weights: Iterable[tuple[str, return loader.load_weights(weights) -# TODO: make it work with torch.compile @fork_new_process_for_each_test -@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("enforce_eager", [False, True]) def test_kv_sharing_skip_prefill(monkeypatch, enforce_eager): prompt = "What is the capital of France?" ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM) - sampling_params = SamplingParams(temperature=0.0, max_tokens=40) + sampling_params = SamplingParams(temperature=0.0, max_tokens=100) single_prompt = [prompt] + compilation_config = CompilationConfig( + level=CompilationLevel.PIECEWISE + if not enforce_eager else CompilationLevel.NO_COMPILATION, + cudagraph_share_memory_pool=False) with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", - enforce_eager=enforce_eager) + enforce_eager=enforce_eager, + compilation_config=compilation_config) responses = llm.generate(single_prompt, sampling_params) ref_output = responses[0].outputs[0].text @@ -229,7 +296,8 @@ def test_kv_sharing_skip_prefill(monkeypatch, enforce_eager): m.setenv("VLLM_V1_KV_SHARING_SKIP_PREFILL", "1") llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", - enforce_eager=enforce_eager) + enforce_eager=enforce_eager, + compilation_config=compilation_config) responses = llm.generate(single_prompt, sampling_params) output = responses[0].outputs[0].text assert output == ref_output diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 5148c289d86..9a029d73018 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -412,8 +412,11 @@ def __init__( # them, e.g. backbone (default), eagle_head, etc. self.prefix = prefix or model_tag - global global_graph_pool - if global_graph_pool is None: + if vllm_config.compilation_config.cudagraph_share_memory_pool: + global global_graph_pool + if global_graph_pool is None: + global_graph_pool = current_platform.graph_pool_handle() + else: global_graph_pool = current_platform.graph_pool_handle() # TODO: in the future, if we want to use multiple diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 05e4ca9f08b..61af5001df0 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -23,6 +23,13 @@ _T = TypeVar("_T", bound=type[nn.Module]) +def skip_torch_compile(cls: _T) -> _T: + cls._skip_compile_vllm = True + for base in cls.__bases__: + base._skip_compile_vllm = True + return cls + + @overload def support_torch_compile( *, @@ -156,7 +163,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): self.do_not_compile = \ vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS - ] or not supports_dynamo() + ] or not supports_dynamo() or getattr(self, "_skip_compile_vllm", False) if self.do_not_compile: return compilation_counter.num_models_seen += 1 diff --git a/vllm/config.py b/vllm/config.py index b1f7f9e57a7..9b67da42114 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4114,6 +4114,8 @@ class CompilationConfig: """Sizes to capture cudagraph. - None (default): capture sizes are inferred from vllm config. 
- list[int]: capture sizes are specified as given.""" + cudagraph_share_memory_pool: bool = True + """Whether to share a single global memory pool for each CUDA graph captured""" cudagraph_copy_inputs: bool = False """Whether to copy input tensors for cudagraph. If the caller can guarantee that the same input buffers diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 22a751f6cc8..57a47b48ace 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -95,7 +95,9 @@ class ForwardContext: # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None skip_cuda_graphs: bool = False + decode_indices: Optional[torch.Tensor] = None + """indices used for decoding""" _forward_context: Optional[ForwardContext] = None diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index d89f115eeca..d6270fbf319 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -374,22 +374,11 @@ def reorder_batch(self, input_batch: InputBatch, return True - def build( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - decode_only_common_attn_metadata: Optional[ - CommonAttentionMetadata] = None, - ): - if decode_only_common_attn_metadata is not None: - raise NotImplementedError( - "CPU backend does not support decode-only attention yet.") + def build(self, common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata): num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - query_start_loc_np = (common_attn_metadata.query_start_loc_np - if common_attn_metadata.query_start_loc_np - is not None else self.runner.query_start_loc_np) runner = self.runner block_table = self.block_table @@ -401,8 +390,8 @@ def build( ) if num_prompt_req < num_reqs else 0 self.seq_start_loc_np[0] = 0 np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1]) - num_prefill_tokens = query_start_loc_np[num_prompt_req].item() - num_decode_tokens = query_start_loc_np[num_reqs].item( + num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item() + num_decode_tokens = runner.query_start_loc_np[num_reqs].item( ) - num_prefill_tokens slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long() block_table_tensor = block_table.get_device_tensor() diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 97a7a71f679..e2f98a8619f 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -142,7 +142,7 @@ class LocalAttentionMetadata: local_attn_metadata: Optional[LocalAttentionMetadata] = None - decode_only_attn_metadata: Optional["FlashAttentionMetadata"] = None + prefill_skipped_attn_metadata: Optional["FlashAttentionMetadata"] = None def _get_sliding_window_configs( @@ -209,26 +209,71 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, # populated on first build() call. 
self.aot_sliding_window: Optional[tuple[int, int]] = None - def build( + def build_skip_prefill( self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, - decode_only_common_attn_metadata: Optional[ - CommonAttentionMetadata] = None, ) -> FlashAttentionMetadata: - decode_only_attn_metadata = None - if decode_only_common_attn_metadata is not None: - decode_only_attn_metadata = self.build( - common_prefix_len=0, # disable cascade attention - common_attn_metadata=decode_only_common_attn_metadata, - ) + num_reqs = common_attn_metadata.num_reqs + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + decode_indices = common_attn_metadata.decode_indices + # Example inputs + # num_reqs: 3 + # decode_indices: [14, 18, 19, 27] + # query_start_loc: [0, 15, 20, 28] + # seq_lens: [41, 31, 40] + + # Find how many decode indices belong to each request + # request_ids: [0, 1, 1, 2] + request_ids = torch.bucketize(decode_indices, + query_start_loc[1:], + right=True) + + # Figure out how many tokens are in each request + # num_decode_tokens: [1, 2, 1] + num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) + + # Calculate new query_start_loc only considering tokens in decode_indices + # decode_query_start_loc: [0, 1, 3, 4] + decode_query_start_loc = torch.empty(num_reqs + 1, + device=query_start_loc.device, + dtype=query_start_loc.dtype) + + decode_query_start_loc[0] = 0 + decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) + decode_max_query_len = num_decode_tokens.max().item() + total_num_decode_tokens = num_decode_tokens.sum().item() + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=decode_query_start_loc, + # TODO(sarckk): optimize + query_start_loc_np=decode_query_start_loc.cpu().numpy(), + seq_lens=seq_lens, + num_reqs=num_reqs, + num_actual_tokens=total_num_decode_tokens, + max_query_len=decode_max_query_len, + # Set to None so we don't recurse again + decode_indices=None, + ) + metadata = self.build( + common_prefix_len=common_prefix_len, + common_attn_metadata=common_attn_metadata, + ) + return metadata + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + ) -> FlashAttentionMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_np = common_attn_metadata.query_start_loc_np seq_lens = common_attn_metadata.seq_lens block_table = self.block_table block_table_tensor = block_table.get_device_tensor()[:num_reqs] @@ -280,11 +325,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, # for local attention local_attn_metadata = None if self.runner.attention_chunk_size is not None: - query_start_loc_np = (common_attn_metadata.query_start_loc_np - if common_attn_metadata.query_start_loc_np - is not None else - self.runner.query_start_loc_np[:num_reqs + - 1]) seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ virt_block_table_tensor = make_local_attention_virtual_batches( self.runner.attention_chunk_size, @@ -375,6 +415,12 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, # we only set num_splits when using cuda graphs. 
max_num_splits = self.max_num_splits + prefill_skipped_attn_metadata = None + if common_attn_metadata.decode_indices is not None: + prefill_skipped_attn_metadata = self.build_skip_prefill( + common_prefix_len=0, # disable cascade attention + common_attn_metadata=common_attn_metadata) + attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, @@ -392,7 +438,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, local_attn_metadata=local_attn_metadata, prefix_scheduler_metadata=prefix_scheduler_metadata, max_num_splits=max_num_splits, - decode_only_attn_metadata=decode_only_attn_metadata, + prefill_skipped_attn_metadata=prefill_skipped_attn_metadata, ) return attn_metadata @@ -459,6 +505,8 @@ def __init__( raise NotImplementedError( "FlashAttention does not support fp8 kv-cache on this device.") + self.kv_sharing_skip_prefill = False + def forward( self, layer: torch.nn.Module, @@ -496,9 +544,9 @@ def forward( return output if (self.kv_sharing_target_layer_name is not None - and attn_metadata.decode_only_attn_metadata is not None): - # Override with decode-only attention metadata - attn_metadata = attn_metadata.decode_only_attn_metadata + and self.kv_sharing_skip_prefill + and attn_metadata.prefill_skipped_attn_metadata is not None): + attn_metadata = attn_metadata.prefill_skipped_attn_metadata # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index b45ef8cab37..4ae595c976b 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -423,17 +423,8 @@ def _plan(self, attn_metadata: FlashInferMetadata): kv_data_type=attn_metadata.kv_data_type, ) - def build( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - decode_only_common_attn_metadata: Optional[ - CommonAttentionMetadata] = None, - ): - if decode_only_common_attn_metadata is not None: - raise NotImplementedError( - "FlashInfer backend does not support decode-only attention yet." 
- ) + def build(self, common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata): num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index 08cda5f1e58..a8c5f464aa3 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -271,17 +271,8 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, self.kv_cache_spec = kv_cache_spec self.block_table = block_table - def build( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - decode_only_common_attn_metadata: Optional[ - CommonAttentionMetadata] = None, - ): - if decode_only_common_attn_metadata is not None: - raise NotImplementedError( - "FlexAttention backend does not support decode-only " - "attention yet.") + def build(self, common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata): num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 2b74d6edfac..970de229e13 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -604,16 +604,8 @@ def build_for_cudagraph_capture( self._num_prefill_tokens = 0 return self.build(0, m) - def build( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - decode_only_common_attn_metadata: Optional[ - CommonAttentionMetadata] = None, - ) -> M: - if decode_only_common_attn_metadata is not None: - raise NotImplementedError( - "MLA backend does not support decode-only attention yet.") + def build(self, common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata) -> M: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 6b94e429271..14f8035f956 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -36,6 +36,10 @@ class CommonAttentionMetadata: query_start_loc: torch.Tensor """(batch_size + 1,), the start location of each request in query Tensor""" + + query_start_loc_np: np.ndarray + """(batch_size + 1,), numpy version of query_start_loc on the CPU""" + seq_lens: torch.Tensor """(batch_size,), the length of each request including both computed tokens and newly scheduled tokens""" @@ -47,8 +51,8 @@ class CommonAttentionMetadata: max_query_len: int """Longest query in batch""" - query_start_loc_np: Optional[np.ndarray] = None - """(batch_size + 1,), cpu numpy version of query_start_loc""" + decode_indices: Optional[torch.Tensor] = None + """indices used for decoding""" M = TypeVar("M") @@ -359,52 +363,3 @@ def make_local_attention_virtual_batches( return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \ block_table_local - -def compute_decode_only_common_attn_metadata( - num_reqs: int, - decode_indices: torch.Tensor, - query_start_loc: torch.Tensor, - seq_lens: torch.Tensor, -): - """ - Compute new query related attention data only considering - token positions corresponding to decode_indices. - - Used to skip tokens during prefill for some Attention layers - that re-use KV cache from earlier layers in the model. 
- """ - # Inputs: - # decode_indices: [14, 18, 19, 27] - # query_start_loc: [0, 15, 20, 28] - # seq_lens: [41, 31, 40] - - # Find how many decode indices belong to each request - # request_ids: [0, 1, 1, 2] - request_ids = torch.bucketize(decode_indices, - query_start_loc[1:], - right=True) - - # Figure out how many tokens are in each request - # num_decode_tokens: [1, 2, 1] - num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) - - # Calculate new query_start_loc only considering tokens in decode_indices - # decode_query_start_loc: [0, 1, 3, 4] - decode_query_start_loc = torch.empty(num_reqs + 1, - device=query_start_loc.device, - dtype=query_start_loc.dtype) - decode_query_start_loc[0] = 0 - decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) - decode_max_query_len = num_decode_tokens.max().item() - total_num_decode_tokens = num_decode_tokens.sum().item() - - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=decode_query_start_loc, - # TODO(sarckk): optimize - query_start_loc_np=decode_query_start_loc.cpu().numpy(), - seq_lens=seq_lens, - num_reqs=num_reqs, - num_actual_tokens=total_num_decode_tokens, - max_query_len=decode_max_query_len, - ) - return common_attn_metadata diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5da229e85dc..747078297e7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -45,10 +45,10 @@ GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up) +from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - compute_decode_only_common_attn_metadata) + CommonAttentionMetadata) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, MambaSpec, @@ -317,6 +317,10 @@ def __init__( # from the KV cache of `shared_kv_cache_layers[layer_name]`. self.shared_kv_cache_layers: dict[str, str] = {} + self.decode_indices = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ Update the order of requests in the batch based on the attention @@ -575,11 +579,31 @@ def _get_cumsum_and_arange( return cu_num_tokens, arange + def _calc_decode_indices(self, logits_indices: torch.Tensor): + """ + Pads logits_indices to align with CUDA graph capture sizes + """ + num_decodes = logits_indices.shape[0] + # TODO(sarckk): With chunked prefills, logits_indices contains + # indices for partial requests though we do not sample any token + # from these partial requests, for simplicity. 
In the future, we + # can calculate the 'true' decode indices based on logits_indices + self.decode_indices[:num_decodes].copy_(logits_indices) + # pad with last idx instead of zero + self.decode_indices[num_decodes:].fill_(logits_indices[-1].item()) + if (self.use_cuda_graph + and num_decodes <= self.cudagraph_batch_sizes[-1]): + num_decodes_padded = self.vllm_config.pad_for_cudagraph( + num_decodes) + else: + num_decodes_padded = num_decodes + return self.decode_indices[:num_decodes_padded] + def _prepare_inputs( self, scheduler_output: "SchedulerOutput", ) -> tuple[dict[str, Any], bool, torch.Tensor, - Optional[SpecDecodeMetadata], np.ndarray]: + Optional[SpecDecodeMetadata], np.ndarray, torch.Tensor]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -700,15 +724,6 @@ def _prepare_inputs( query_start_loc_np = self.query_start_loc_np[:num_reqs + 1] seq_lens = self.seq_lens[:num_reqs] - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc, - query_start_loc_np=query_start_loc_np, - seq_lens=seq_lens, - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - ) - use_spec_decode = len( scheduler_output.scheduled_spec_decode_tokens) > 0 if not use_spec_decode: @@ -733,20 +748,17 @@ def _prepare_inputs( num_draft_tokens, cu_num_tokens) logits_indices = spec_decode_metadata.logits_indices - decode_only_common_attn_metadata = None - if envs.VLLM_V1_KV_SHARING_SKIP_PREFILL: - decode_only_common_attn_metadata = ( - compute_decode_only_common_attn_metadata( - num_reqs=num_reqs, - # TODO(sarckk): logits_indices contains tokens for partial - # prefill requests, so we can optimize further by only - # considering tokens if its index is more than or equal to - # input_batch.num_prompt_tokens[req_index], which correspond - # to positions that are required for sampling output tokens - decode_indices=logits_indices, - query_start_loc=query_start_loc, - seq_lens=seq_lens, - )) + decode_indices = self._calc_decode_indices(logits_indices) + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_np=query_start_loc_np, + seq_lens=seq_lens, + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + decode_indices=decode_indices, + ) attn_metadata: dict[str, Any] = {} # Prepare the attention metadata for each KV cache group and make layers @@ -769,8 +781,6 @@ def _prepare_inputs( attn_metadata_i = (builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, - decode_only_common_attn_metadata= - decode_only_common_attn_metadata, )) for layer_name in kv_cache_group_spec.layer_names: @@ -785,7 +795,7 @@ def _prepare_inputs( self.set_active_loras(self.input_batch, num_scheduled_tokens) return (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, num_scheduled_tokens) + spec_decode_metadata, num_scheduled_tokens, decode_indices) def _compute_cascade_attn_prefix_len( self, @@ -1306,8 +1316,8 @@ def execute_model( # Prepare the decoder inputs. 
(attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, - num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output)) + spec_decode_metadata, num_scheduled_tokens_np, + decode_indices) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1378,8 +1388,6 @@ def execute_model( # Run the model. # Use persistent buffers for CUDA graphs. - decode_indices = (logits_indices - if envs.VLLM_V1_KV_SHARING_SKIP_PREFILL else None) with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_input_tokens, @@ -1978,6 +1986,10 @@ def _dummy_run( dtype=np.int32) attn_metadata: Optional[dict[str, Any]] = None + decode_indices = torch.arange(num_tokens, + device=self.device, + dtype=torch.int) + if capture_attn_cudagraph: attn_metadata = {} @@ -1997,6 +2009,7 @@ def _dummy_run( num_reqs=num_reqs, num_actual_tokens=num_tokens, max_query_len=num_tokens, + decode_indices=decode_indices, ) for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -2039,7 +2052,8 @@ def _dummy_run( attn_metadata, self.vllm_config, num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp): + num_tokens_across_dp=num_tokens_across_dp, + decode_indices=decode_indices): outputs = model( input_ids=input_ids, positions=positions, @@ -2672,6 +2686,18 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size=max_model_len, page_size_padded=page_size_padded) + # Second pass to determine if N-1 prompt tokens can be skipped + # during prefill for layers that re-use shared KV cache + # Iterate in reversed order and note shared kv cache layers where + # there is no layer after it that allocates its own KV cache + for layer_name in reversed(attn_layers.keys()): + if layer_name in self.shared_kv_cache_layers: + attn_module = attn_layers[layer_name] + if isinstance(attn_module.impl, FlashAttentionImpl): + attn_module.impl.kv_sharing_skip_prefill = True + else: + break + return kv_cache_spec def _maybe_pad_mamba_page_size( From fb5d610f581c350d4fd5e4c97950d29c483b2bd3 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Wed, 2 Jul 2025 16:34:28 -0700 Subject: [PATCH 03/13] Fix lint Signed-off-by: Yong Hoon Shin --- tests/v1/e2e/test_kv_sharing_skip_prefill.py | 7 ++++--- vllm/config.py | 2 +- vllm/v1/attention/backends/flash_attn.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/v1/e2e/test_kv_sharing_skip_prefill.py b/tests/v1/e2e/test_kv_sharing_skip_prefill.py index 8364d8d8c90..8b0245f38c4 100644 --- a/tests/v1/e2e/test_kv_sharing_skip_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_skip_prefill.py @@ -3,7 +3,7 @@ import gc from collections.abc import Iterable -from typing import List, Optional, Union +from typing import Optional, Union import pytest import torch @@ -112,7 +112,7 @@ def __init__( *, vllm_config: VllmConfig, prefix: str = "", - layers: List[nn.Module], + layers: list[nn.Module], ): super().__init__() self.layers = layers @@ -162,7 +162,8 @@ def __init__(self, ) # Pre-allocate static buffers for CUDA graph - self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + self.max_num_tokens =\ + vllm_config.scheduler_config.max_num_batched_tokens self.dtype = vllm_config.model_config.dtype self.device = next(self.parameters()).device self.hidden_size = vllm_config.model_config.get_hidden_size() diff --git a/vllm/config.py b/vllm/config.py index 9b67da42114..5cdbbbadd0b 100644 --- 
a/vllm/config.py +++ b/vllm/config.py @@ -4115,7 +4115,7 @@ class CompilationConfig: - None (default): capture sizes are inferred from vllm config. - list[int]: capture sizes are specified as given.""" cudagraph_share_memory_pool: bool = True - """Whether to share a single global memory pool for each CUDA graph captured""" + """Whether to share a single global memory pool for each graph capture""" cudagraph_copy_inputs: bool = False """Whether to copy input tensors for cudagraph. If the caller can guarantee that the same input buffers diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index e2f98a8619f..ae90a7b0b7d 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -234,7 +234,7 @@ def build_skip_prefill( # num_decode_tokens: [1, 2, 1] num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) - # Calculate new query_start_loc only considering tokens in decode_indices + # Calculate new query_start_loc with tokens in decode_indices # decode_query_start_loc: [0, 1, 3, 4] decode_query_start_loc = torch.empty(num_reqs + 1, device=query_start_loc.device, From c33bfd9a5a2e7ab585244206481c7f948f76a979 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Wed, 2 Jul 2025 19:34:05 -0700 Subject: [PATCH 04/13] Fix wrong prefill skip attn metadata Signed-off-by: Yong Hoon Shin --- tests/v1/e2e/test_kv_sharing_skip_prefill.py | 28 +++++++++++--------- vllm/config.py | 9 ++++++- vllm/engine/arg_utils.py | 4 +++ vllm/entrypoints/llm.py | 2 ++ vllm/envs.py | 3 --- vllm/v1/attention/backends/flash_attn.py | 16 ++++++----- vllm/v1/worker/gpu_model_runner.py | 7 +++-- 7 files changed, 45 insertions(+), 24 deletions(-) diff --git a/tests/v1/e2e/test_kv_sharing_skip_prefill.py b/tests/v1/e2e/test_kv_sharing_skip_prefill.py index 8b0245f38c4..39fce9806eb 100644 --- a/tests/v1/e2e/test_kv_sharing_skip_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_skip_prefill.py @@ -194,6 +194,7 @@ def forward( if decode_indices is None: decode_indices = torch.arange(positions.size(0), device=positions.device) + num_decodes = decode_indices.shape[0] assert num_decodes >= 1 assert first_residual is not None @@ -270,12 +271,14 @@ def load_weights(self, weights: Iterable[tuple[str, @fork_new_process_for_each_test -@pytest.mark.parametrize("enforce_eager", [False, True]) -def test_kv_sharing_skip_prefill(monkeypatch, enforce_eager): - prompt = "What is the capital of France?" 
+@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_kv_sharing_skip_prefill( + monkeypatch: pytest.MonkeyPatch, + enforce_eager: bool, +): ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM) sampling_params = SamplingParams(temperature=0.0, max_tokens=100) - single_prompt = [prompt] + prompts = ["What is the capital of France?"] compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE if not enforce_eager else CompilationLevel.NO_COMPILATION, @@ -284,21 +287,22 @@ def test_kv_sharing_skip_prefill(monkeypatch, enforce_eager): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", - enforce_eager=enforce_eager, - compilation_config=compilation_config) - responses = llm.generate(single_prompt, sampling_params) + llm = LLM( + model="Qwen/Qwen2-1.5B-Instruct", + enforce_eager=enforce_eager, + compilation_config=compilation_config, + ) + responses = llm.generate(prompts, sampling_params) ref_output = responses[0].outputs[0].text del llm gc.collect() torch.cuda.empty_cache() - m.setenv("VLLM_V1_KV_SHARING_SKIP_PREFILL", "1") - llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", enforce_eager=enforce_eager, - compilation_config=compilation_config) - responses = llm.generate(single_prompt, sampling_params) + compilation_config=compilation_config, + kv_sharing_skip_prefill=True) + responses = llm.generate(prompts, sampling_params) output = responses[0].outputs[0].text assert output == ref_output diff --git a/vllm/config.py b/vllm/config.py index 5cdbbbadd0b..15637370c2e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1564,6 +1564,10 @@ class CacheConfig: checkpoint if available. Otherwise, the scales will default to 1.0.""" cpu_kvcache_space_bytes: Optional[int] = None """(CPU backend only) CPU key-value cache space.""" + kv_sharing_skip_prefill: bool = False + """Skip prefill for tokens where applicable in KV cache sharing + scenarios where required key/value tensors have been populated + in earlier KV sharing target layers.""" # Will be set after profiling. num_gpu_blocks: Optional[int] = field(default=None, init=False) @@ -4115,7 +4119,10 @@ class CompilationConfig: - None (default): capture sizes are inferred from vllm config. - list[int]: capture sizes are specified as given.""" cudagraph_share_memory_pool: bool = True - """Whether to share a single global memory pool for each graph capture""" + """Whether to share a single global memory pool for each graph capture + When CUDA graphs are not replayed in the same order they are captured, + e.g. when compiling multiple modules in a model and modules take different + input shapes, it is unsafe to share memory across graph captures.""" cudagraph_copy_inputs: bool = False """Whether to copy input tensors for cudagraph. 
If the caller can guarantee that the same input buffers diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f47499309d8..d06077d43e3 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -472,6 +472,7 @@ class EngineArgs: override_attention_dtype: str = ModelConfig.override_attention_dtype calculate_kv_scales: bool = CacheConfig.calculate_kv_scales + kv_sharing_skip_prefill: bool = CacheConfig.kv_sharing_skip_prefill additional_config: dict[str, Any] = \ get_field(VllmConfig, "additional_config") @@ -748,6 +749,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **cache_kwargs["cpu_offload_gb"]) cache_group.add_argument("--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]) + cache_group.add_argument("--kv-sharing-skip-prefill", + **cache_kwargs["kv_sharing_skip_prefill"]) # Tokenizer arguments tokenizer_kwargs = get_kwargs(TokenizerPoolConfig) @@ -1158,6 +1161,7 @@ def create_engine_config( prefix_caching_hash_algo=self.prefix_caching_hash_algo, cpu_offload_gb=self.cpu_offload_gb, calculate_kv_scales=self.calculate_kv_scales, + kv_sharing_skip_prefill=self.kv_sharing_skip_prefill, ) # Get the current placement group if Ray is initialized and diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c60a566f585..bda905edd54 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -194,6 +194,7 @@ def __init__( override_pooler_config: Optional[PoolerConfig] = None, compilation_config: Optional[Union[int, dict[str, Any], CompilationConfig]] = None, + kv_sharing_skip_prefill: bool = False, **kwargs, ) -> None: """LLM constructor.""" @@ -267,6 +268,7 @@ def __init__( mm_processor_kwargs=mm_processor_kwargs, override_pooler_config=override_pooler_config, compilation_config=compilation_config_instance, + kv_sharing_skip_prefill=kv_sharing_skip_prefill, **kwargs, ) diff --git a/vllm/envs.py b/vllm/envs.py index ab0184174e4..b7814169c77 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -139,7 +139,6 @@ VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120 - VLLM_V1_KV_SHARING_SKIP_PREFILL: bool = False def get_default_cache_root(): @@ -966,8 +965,6 @@ def get_vllm_port() -> Optional[int]: "VLLM_USE_TRTLLM_DECODE_ATTENTION": lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None), - "VLLM_V1_KV_SHARING_SKIP_PREFILL": - lambda: os.environ.get("VLLM_V1_KV_SHARING_SKIP_PREFILL", "0") == "1", } # --8<-- [end:env-vars-definition] diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ae90a7b0b7d..1d48e2db96b 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -267,6 +267,16 @@ def build( common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, ) -> FlashAttentionMetadata: + prefill_skipped_attn_metadata = None + if common_attn_metadata.decode_indices is not None: + # NOTE(sarckk): attention metadata for partial prefill skip case + # needs to be built first, otherwise the line below + # block_table.slot_mapping[num_actual_tokens:].fill_(-1) + # will override the correct slot mapping + prefill_skipped_attn_metadata = self.build_skip_prefill( + common_prefix_len=0, # disable cascade attention + common_attn_metadata=common_attn_metadata) + num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ 
-415,12 +425,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, # we only set num_splits when using cuda graphs. max_num_splits = self.max_num_splits - prefill_skipped_attn_metadata = None - if common_attn_metadata.decode_indices is not None: - prefill_skipped_attn_metadata = self.build_skip_prefill( - common_prefix_len=0, # disable cascade attention - common_attn_metadata=common_attn_metadata) - attn_metadata = FlashAttentionMetadata( num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 747078297e7..e1b93eecaa2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -583,6 +583,8 @@ def _calc_decode_indices(self, logits_indices: torch.Tensor): """ Pads logits_indices to align with CUDA graph capture sizes """ + if not self.cache_config.kv_sharing_skip_prefill: + return None num_decodes = logits_indices.shape[0] # TODO(sarckk): With chunked prefills, logits_indices contains # indices for partial requests though we do not sample any token @@ -602,8 +604,9 @@ def _calc_decode_indices(self, logits_indices: torch.Tensor): def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[dict[str, Any], bool, torch.Tensor, - Optional[SpecDecodeMetadata], np.ndarray, torch.Tensor]: + ) -> tuple[dict[str, + Any], bool, torch.Tensor, Optional[SpecDecodeMetadata], + np.ndarray, Optional[torch.Tensor]]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, From fd764de6c2db6d57c641859bff27a14252a34855 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Wed, 2 Jul 2025 21:46:58 -0700 Subject: [PATCH 05/13] More rigorous correctness check Signed-off-by: Yong Hoon Shin --- tests/v1/e2e/conftest.py | 36 ++++++++++++++++++++ tests/v1/e2e/test_kv_sharing_skip_prefill.py | 7 ++-- tests/v1/e2e/test_spec_decode.py | 34 ------------------ 3 files changed, 40 insertions(+), 37 deletions(-) create mode 100644 tests/v1/e2e/conftest.py diff --git a/tests/v1/e2e/conftest.py b/tests/v1/e2e/conftest.py new file mode 100644 index 00000000000..a7f5b130ad2 --- /dev/null +++ b/tests/v1/e2e/conftest.py @@ -0,0 +1,36 @@ +import random + +import pytest + + +@pytest.fixture +def test_prompts(): + prompt_types = ["repeat", "sentence"] + num_prompts = 100 + prompts = [] + + random.seed(0) + random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) + + # Generate a mixed batch of prompts, some of which can be easily + # predicted by n-gram matching and some which likely cannot. + for kind in random_prompt_type_choices: + word_choices = ["test", "temp", "hello", "where"] + word = random.choice(word_choices) + if kind == "repeat": + prompt = f""" + please repeat the word '{word}' 10 times. + give no other output than the word at least ten times in a row, + in lowercase with spaces between each word and without quotes. + """ + elif kind == "sentence": + prompt = f""" + please give a ten-word sentence that + uses the word {word} at least once. + give no other output than that simple sentence without quotes. 
+ """ + else: + raise ValueError(f"Unknown prompt type: {kind}") + prompts.append([{"role": "user", "content": prompt}]) + + return prompts diff --git a/tests/v1/e2e/test_kv_sharing_skip_prefill.py b/tests/v1/e2e/test_kv_sharing_skip_prefill.py index 39fce9806eb..8a3b7850e5b 100644 --- a/tests/v1/e2e/test_kv_sharing_skip_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_skip_prefill.py @@ -3,7 +3,7 @@ import gc from collections.abc import Iterable -from typing import Optional, Union +from typing import Any, Optional, Union import pytest import torch @@ -275,10 +275,11 @@ def load_weights(self, weights: Iterable[tuple[str, def test_kv_sharing_skip_prefill( monkeypatch: pytest.MonkeyPatch, enforce_eager: bool, + test_prompts: list[list[dict[str, Any]]], ): ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM) - sampling_params = SamplingParams(temperature=0.0, max_tokens=100) - prompts = ["What is the capital of France?"] + sampling_params = SamplingParams(temperature=0.0, max_tokens=42) + prompts = [prompt[0]['content'] for prompt in test_prompts] compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE if not enforce_eager else CompilationLevel.NO_COMPILATION, diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 93e7c12f3a0..177e4350af3 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from __future__ import annotations -import random from typing import Any import pytest @@ -10,39 +9,6 @@ from vllm import LLM, SamplingParams -@pytest.fixture -def test_prompts(): - prompt_types = ["repeat", "sentence"] - num_prompts = 100 - prompts = [] - - random.seed(0) - random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) - - # Generate a mixed batch of prompts, some of which can be easily - # predicted by n-gram matching and some which likely cannot. - for kind in random_prompt_type_choices: - word_choices = ["test", "temp", "hello", "where"] - word = random.choice(word_choices) - if kind == "repeat": - prompt = f""" - please repeat the word '{word}' 10 times. - give no other output than the word at least ten times in a row, - in lowercase with spaces between each word and without quotes. - """ - elif kind == "sentence": - prompt = f""" - please give a ten-word sentence that - uses the word {word} at least once. - give no other output than that simple sentence without quotes. 
- """ - else: - raise ValueError(f"Unknown prompt type: {kind}") - prompts.append([{"role": "user", "content": prompt}]) - - return prompts - - @pytest.fixture def sampling_config(): return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False) From cc32a06b40d12e2a260654b0e0f9e798382b1782 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Wed, 2 Jul 2025 22:00:32 -0700 Subject: [PATCH 06/13] optimization to skip computing extra metadata if all requests on decode Signed-off-by: Yong Hoon Shin --- tests/v1/e2e/test_kv_sharing_skip_prefill.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/v1/e2e/test_kv_sharing_skip_prefill.py b/tests/v1/e2e/test_kv_sharing_skip_prefill.py index 8a3b7850e5b..940b6305ba0 100644 --- a/tests/v1/e2e/test_kv_sharing_skip_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_skip_prefill.py @@ -278,7 +278,7 @@ def test_kv_sharing_skip_prefill( test_prompts: list[list[dict[str, Any]]], ): ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM) - sampling_params = SamplingParams(temperature=0.0, max_tokens=42) + sampling_params = SamplingParams(temperature=0.0, max_tokens=100) prompts = [prompt[0]['content'] for prompt in test_prompts] compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e1b93eecaa2..723c7b4e7bf 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -585,6 +585,18 @@ def _calc_decode_indices(self, logits_indices: torch.Tensor): """ if not self.cache_config.kv_sharing_skip_prefill: return None + + num_decode_reqs = 0 + for req_index in range(self.input_batch.num_reqs): + if self.input_batch.num_computed_tokens_cpu[ + req_index] >= self.input_batch.num_prompt_tokens[ + req_index]: + num_decode_reqs += 1 + + if self.input_batch.num_reqs == num_decode_reqs: + # All requests are on decode, skip calculate decode only indices + return None + num_decodes = logits_indices.shape[0] # TODO(sarckk): With chunked prefills, logits_indices contains # indices for partial requests though we do not sample any token From cee69c4467c6cbe8e308886815d1f68962704e22 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Wed, 2 Jul 2025 22:02:04 -0700 Subject: [PATCH 07/13] minor fix Signed-off-by: Yong Hoon Shin --- vllm/envs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/envs.py b/vllm/envs.py index b7814169c77..7bff6ade815 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -964,7 +964,6 @@ def get_vllm_port() -> Optional[int]: # If set to 1, use the TRTLLM Decode Attention backend in flashinfer. 
"VLLM_USE_TRTLLM_DECODE_ATTENTION": lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None), - } # --8<-- [end:env-vars-definition] From b6cd35d919eb6954195e25ca3b66811aaf60df60 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Thu, 3 Jul 2025 09:30:19 -0700 Subject: [PATCH 08/13] Fix cudagraph issue with padding Signed-off-by: Yong Hoon Shin --- tests/v1/e2e/conftest.py | 36 ------ tests/v1/e2e/test_kv_sharing_skip_prefill.py | 113 ++++++++++++++++--- tests/v1/e2e/test_spec_decode.py | 34 ++++++ vllm/compilation/decorators.py | 2 +- 4 files changed, 131 insertions(+), 54 deletions(-) delete mode 100644 tests/v1/e2e/conftest.py diff --git a/tests/v1/e2e/conftest.py b/tests/v1/e2e/conftest.py deleted file mode 100644 index a7f5b130ad2..00000000000 --- a/tests/v1/e2e/conftest.py +++ /dev/null @@ -1,36 +0,0 @@ -import random - -import pytest - - -@pytest.fixture -def test_prompts(): - prompt_types = ["repeat", "sentence"] - num_prompts = 100 - prompts = [] - - random.seed(0) - random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) - - # Generate a mixed batch of prompts, some of which can be easily - # predicted by n-gram matching and some which likely cannot. - for kind in random_prompt_type_choices: - word_choices = ["test", "temp", "hello", "where"] - word = random.choice(word_choices) - if kind == "repeat": - prompt = f""" - please repeat the word '{word}' 10 times. - give no other output than the word at least ten times in a row, - in lowercase with spaces between each word and without quotes. - """ - elif kind == "sentence": - prompt = f""" - please give a ten-word sentence that - uses the word {word} at least once. - give no other output than that simple sentence without quotes. - """ - else: - raise ValueError(f"Unknown prompt type: {kind}") - prompts.append([{"role": "user", "content": prompt}]) - - return prompts diff --git a/tests/v1/e2e/test_kv_sharing_skip_prefill.py b/tests/v1/e2e/test_kv_sharing_skip_prefill.py index 940b6305ba0..f9a384ca728 100644 --- a/tests/v1/e2e/test_kv_sharing_skip_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_skip_prefill.py @@ -2,8 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import gc +import random from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Optional, Union import pytest import torch @@ -105,7 +106,7 @@ def forward( @support_torch_compile -class DecoderLayerGroup(nn.Module): +class FirstLayerGroup(nn.Module): def __init__( self, @@ -121,7 +122,35 @@ def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor] = None, + ): + residual = None + for layer in self.layers: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + return hidden_states, residual + + +@support_torch_compile +class SecondLayerGroup(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layers: list[nn.Module], + ): + super().__init__() + self.layers = layers + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, ): for layer in self.layers: hidden_states, residual = layer( @@ -147,15 +176,17 @@ def __init__(self, decoder_layer_type=decoder_layer_type, ) + self.vllm_config = vllm_config + with set_model_tag("first_layer_group"): - self.first_layer_group = DecoderLayerGroup( + self.first_layer_group = FirstLayerGroup( vllm_config=vllm_config, prefix=f"{prefix}.first_layer_group", 
layers=self.layers[self.start_layer:START_KV_SHARING_LAYER], ) with set_model_tag("second_layer_group"): - self.second_layer_group = DecoderLayerGroup( + self.second_layer_group = SecondLayerGroup( vllm_config=vllm_config, prefix=f"{prefix}.second_layer_group", layers=self.layers[START_KV_SHARING_LAYER:self.end_layer], @@ -170,6 +201,10 @@ def __init__(self, self.residual = torch.zeros((self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=self.device) + self.hidden_states = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device) def forward( self, @@ -183,11 +218,12 @@ def forward( else: hidden_states = self.get_input_embeddings(input_ids) - residual = None + num_input_tokens = input_ids.size(0) + self.hidden_states[:num_input_tokens].copy_(hidden_states) + first_hidden_states, first_residual = self.first_layer_group( positions, - hidden_states, - residual, # no residual, assume no pipeline parallel + self.hidden_states[:num_input_tokens], ) decode_indices = get_forward_context().decode_indices @@ -202,15 +238,24 @@ def forward( # CUDA graph expects static tensor addresses # Copy output of first layer group to second layer group self.residual[:num_decodes].copy_(first_residual[decode_indices]) - hidden_states[:num_decodes].copy_(first_hidden_states[decode_indices]) + self.hidden_states[:num_decodes].copy_( + first_hidden_states[decode_indices]) positions[:num_decodes].copy_(positions[decode_indices]) second_hidden_states, second_residual = self.second_layer_group( positions[:num_decodes], - hidden_states[:num_decodes], + self.hidden_states[:num_decodes], self.residual[:num_decodes], ) + # NOTE(sarckk): Due to cudagraph padding, decode_indices may have + # trailing repeated indices. Attention output is only valid at the + # last index in this case. + last_index_mask = decode_indices == decode_indices[-1] + second_hidden_states[last_index_mask] = second_hidden_states[-1].clone( + ) + second_residual[last_index_mask] = second_residual[-1].clone() + # Merge results back first_hidden_states[decode_indices] = second_hidden_states if first_residual is not None: @@ -270,16 +315,43 @@ def load_weights(self, weights: Iterable[tuple[str, return loader.load_weights(weights) +@pytest.fixture +def test_prompts(): + prompt_types = ["repeat", "sentence"] + # Setting higher num prompts increases the chance of numerics mismatch + # due to matrix multiplication numerics depending on batch dimension + num_prompts = 10 + prompts = [] + + random.seed(0) + random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) + + # Generate a mixed batch of prompts, some of which can be easily + # predicted by n-gram matching and some which likely cannot. 
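# ---------------------------------------------------------------------------
# NOTE: standalone illustrative sketch, not part of this patch series. It
# isolates the control flow used by Qwen2ModelWithKVSharing.forward above:
# all tokens go through the first layer group, only the generation (decode)
# positions are gathered and run through the KV-sharing layer group, and the
# partial output is scattered back. The function name and the lambda "layer
# groups" are stand-ins for the compiled FirstLayerGroup / SecondLayerGroup.
import torch

def two_stage_forward(hidden_states: torch.Tensor,
                      gen_indices: torch.Tensor,
                      first_group,
                      second_group) -> torch.Tensor:
    # 1) Full sequence through the layers that still write their own KV cache.
    hidden_states = first_group(hidden_states)
    # 2) Only positions that will sample a token go through the layers that
    #    re-use KV from earlier layers (their prefill work can be skipped).
    partial = second_group(hidden_states[gen_indices])
    # 3) Merge the partial results back into the full-length tensor.
    hidden_states = hidden_states.clone()
    hidden_states[gen_indices] = partial
    return hidden_states

# Example: 28 tokens in the flattened batch, sampling at positions 14, 19, 27.
x = torch.randn(28, 8)
out = two_stage_forward(x, torch.tensor([14, 19, 27]),
                        lambda h: h * 2.0, lambda h: h + 1.0)
# ---------------------------------------------------------------------------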
+ for kind in random_prompt_type_choices: + word_choices = ["test", "temp", "hello", "where"] + word = random.choice(word_choices) + if kind == "repeat": + prompt = f"""please repeat the word '{word}' 10 times.""" + elif kind == "sentence": + prompt = f"""please give a ten-word sentence that + uses the word {word} at least once.""" + else: + raise ValueError(f"Unknown prompt type: {kind}") + prompts.append(prompt) + + return prompts + + @fork_new_process_for_each_test @pytest.mark.parametrize("enforce_eager", [True, False]) def test_kv_sharing_skip_prefill( monkeypatch: pytest.MonkeyPatch, enforce_eager: bool, - test_prompts: list[list[dict[str, Any]]], + test_prompts: list[str], ): ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM) sampling_params = SamplingParams(temperature=0.0, max_tokens=100) - prompts = [prompt[0]['content'] for prompt in test_prompts] compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE if not enforce_eager else CompilationLevel.NO_COMPILATION, @@ -293,8 +365,7 @@ def test_kv_sharing_skip_prefill( enforce_eager=enforce_eager, compilation_config=compilation_config, ) - responses = llm.generate(prompts, sampling_params) - ref_output = responses[0].outputs[0].text + ref_responses = llm.generate(test_prompts, sampling_params) del llm gc.collect() @@ -304,6 +375,14 @@ def test_kv_sharing_skip_prefill( enforce_eager=enforce_eager, compilation_config=compilation_config, kv_sharing_skip_prefill=True) - responses = llm.generate(prompts, sampling_params) - output = responses[0].outputs[0].text - assert output == ref_output + optimized_responses = llm.generate(test_prompts, sampling_params) + + misses = 0 + + for ref_response, optimized_response in zip(ref_responses, + optimized_responses): + if ref_response.outputs[0].text != optimized_response.outputs[ + 0].text: + misses += 1 + + assert misses == 0 diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 177e4350af3..93e7c12f3a0 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from __future__ import annotations +import random from typing import Any import pytest @@ -9,6 +10,39 @@ from vllm import LLM, SamplingParams +@pytest.fixture +def test_prompts(): + prompt_types = ["repeat", "sentence"] + num_prompts = 100 + prompts = [] + + random.seed(0) + random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) + + # Generate a mixed batch of prompts, some of which can be easily + # predicted by n-gram matching and some which likely cannot. + for kind in random_prompt_type_choices: + word_choices = ["test", "temp", "hello", "where"] + word = random.choice(word_choices) + if kind == "repeat": + prompt = f""" + please repeat the word '{word}' 10 times. + give no other output than the word at least ten times in a row, + in lowercase with spaces between each word and without quotes. + """ + elif kind == "sentence": + prompt = f""" + please give a ten-word sentence that + uses the word {word} at least once. + give no other output than that simple sentence without quotes. 
+ """ + else: + raise ValueError(f"Unknown prompt type: {kind}") + prompts.append([{"role": "user", "content": prompt}]) + + return prompts + + @pytest.fixture def sampling_config(): return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 61af5001df0..0813deac0b9 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -26,7 +26,7 @@ def skip_torch_compile(cls: _T) -> _T: cls._skip_compile_vllm = True for base in cls.__bases__: - base._skip_compile_vllm = True + setattr(base,"_skip_compile_vllm",True) return cls From 55ddaa0be3eddd9905b46dfa3dad47acf342b5ba Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Thu, 3 Jul 2025 10:40:13 -0700 Subject: [PATCH 09/13] Fix pre-commit Signed-off-by: Yong Hoon Shin --- tests/v1/e2e/test_kv_sharing_skip_prefill.py | 5 +++-- vllm/compilation/decorators.py | 2 -- vllm/v1/attention/backends/flash_attn.py | 2 ++ vllm/v1/attention/backends/utils.py | 15 +++++---------- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/v1/e2e/test_kv_sharing_skip_prefill.py b/tests/v1/e2e/test_kv_sharing_skip_prefill.py index f9a384ca728..d45f38b05cf 100644 --- a/tests/v1/e2e/test_kv_sharing_skip_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_skip_prefill.py @@ -317,6 +317,9 @@ def load_weights(self, weights: Iterable[tuple[str, @pytest.fixture def test_prompts(): + """ + Adapted from tests/v1/e2e/test_spec_decode.py + """ prompt_types = ["repeat", "sentence"] # Setting higher num prompts increases the chance of numerics mismatch # due to matrix multiplication numerics depending on batch dimension @@ -326,8 +329,6 @@ def test_prompts(): random.seed(0) random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) - # Generate a mixed batch of prompts, some of which can be easily - # predicted by n-gram matching and some which likely cannot. 
for kind in random_prompt_type_choices: word_choices = ["test", "temp", "hello", "where"] word = random.choice(word_choices) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 0813deac0b9..8dd268d06a9 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -25,8 +25,6 @@ def skip_torch_compile(cls: _T) -> _T: cls._skip_compile_vllm = True - for base in cls.__bases__: - setattr(base,"_skip_compile_vllm",True) return cls diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 1d48e2db96b..3a4aa16e795 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -284,6 +284,8 @@ def build( max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) query_start_loc = common_attn_metadata.query_start_loc query_start_loc_np = common_attn_metadata.query_start_loc_np + if query_start_loc_np is None: + query_start_loc_np = self.runner.query_start_loc_np[:num_reqs + 1] seq_lens = common_attn_metadata.seq_lens block_table = self.block_table block_table_tensor = block_table.get_device_tensor()[:num_reqs] diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 14f8035f956..19ae0709631 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -37,9 +37,6 @@ class CommonAttentionMetadata: query_start_loc: torch.Tensor """(batch_size + 1,), the start location of each request in query Tensor""" - query_start_loc_np: np.ndarray - """(batch_size + 1,), numpy version of query_start_loc on the CPU""" - seq_lens: torch.Tensor """(batch_size,), the length of each request including both computed tokens and newly scheduled tokens""" @@ -54,6 +51,9 @@ class CommonAttentionMetadata: decode_indices: Optional[torch.Tensor] = None """indices used for decoding""" + query_start_loc_np: Optional[np.ndarray] = None + """(batch_size + 1,), numpy equivalent of query_start_loc on the CPU""" + M = TypeVar("M") @@ -63,13 +63,8 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): full_cudagraph_supported: ClassVar[bool] = False @abstractmethod - def build( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - decode_only_common_attn_metadata: Optional[ - CommonAttentionMetadata] = None, - ) -> M: + def build(self, common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata) -> M: """ Central method that builds attention metadata. Some builders (MLA) require reorder_batch to be called prior to build. 
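# ---------------------------------------------------------------------------
# NOTE: standalone illustrative sketch, not part of this patch series. It
# reproduces the index math used by build_skip_prefill in flash_attn.py to
# derive a decode-only query_start_loc from decode_indices, using the example
# values from the in-code comments of that function.
import torch

query_start_loc = torch.tensor([0, 15, 20, 28])   # 3 requests
decode_indices = torch.tensor([14, 18, 19, 27])
num_reqs = query_start_loc.numel() - 1

# Which request does each decode index belong to? -> [0, 1, 1, 2]
request_ids = torch.bucketize(decode_indices,
                              query_start_loc[1:],
                              right=True)

# How many decode tokens per request? -> [1, 2, 1]
num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)

# Decode-only query_start_loc -> [0, 1, 3, 4]
decode_query_start_loc = torch.zeros(num_reqs + 1, dtype=torch.long)
decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)

assert decode_query_start_loc.tolist() == [0, 1, 3, 4]
# ---------------------------------------------------------------------------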
From 3cd2474be45e7750aef14247017ef1d8cf6e4250 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Fri, 11 Jul 2025 14:18:26 -0700 Subject: [PATCH 10/13] Address comments Signed-off-by: Yong Hoon Shin --- tests/v1/e2e/test_kv_sharing_skip_prefill.py | 28 +++++++++----------- vllm/compilation/backends.py | 7 ++--- vllm/compilation/decorators.py | 6 ++--- vllm/config.py | 5 ---- vllm/v1/worker/gpu_model_runner.py | 10 ++++--- 5 files changed, 24 insertions(+), 32 deletions(-) diff --git a/tests/v1/e2e/test_kv_sharing_skip_prefill.py b/tests/v1/e2e/test_kv_sharing_skip_prefill.py index d45f38b05cf..907758c6b9a 100644 --- a/tests/v1/e2e/test_kv_sharing_skip_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_skip_prefill.py @@ -13,7 +13,7 @@ from vllm import LLM, SamplingParams from vllm.compilation.backends import set_model_tag -from vllm.compilation.decorators import (skip_torch_compile, +from vllm.compilation.decorators import (ignore_torch_compile, support_torch_compile) from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, VllmConfig) @@ -161,7 +161,7 @@ def forward( return hidden_states, residual -@skip_torch_compile +@ignore_torch_compile class Qwen2ModelWithKVSharing(Qwen2Model): def __init__(self, @@ -193,18 +193,17 @@ def __init__(self, ) # Pre-allocate static buffers for CUDA graph - self.max_num_tokens =\ - vllm_config.scheduler_config.max_num_batched_tokens - self.dtype = vllm_config.model_config.dtype - self.device = next(self.parameters()).device - self.hidden_size = vllm_config.model_config.get_hidden_size() - self.residual = torch.zeros((self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) + max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + dtype = vllm_config.model_config.dtype + device = next(self.parameters()).device + hidden_size = vllm_config.model_config.get_hidden_size() + self.residual = torch.zeros((max_num_tokens, hidden_size), + dtype=dtype, + device=device) self.hidden_states = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) + (max_num_tokens, hidden_size), + dtype=dtype, + device=device) def forward( self, @@ -355,8 +354,7 @@ def test_kv_sharing_skip_prefill( sampling_params = SamplingParams(temperature=0.0, max_tokens=100) compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE - if not enforce_eager else CompilationLevel.NO_COMPILATION, - cudagraph_share_memory_pool=False) + if not enforce_eager else CompilationLevel.NO_COMPILATION) with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 9a029d73018..5148c289d86 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -412,11 +412,8 @@ def __init__( # them, e.g. backbone (default), eagle_head, etc. 
self.prefix = prefix or model_tag - if vllm_config.compilation_config.cudagraph_share_memory_pool: - global global_graph_pool - if global_graph_pool is None: - global_graph_pool = current_platform.graph_pool_handle() - else: + global global_graph_pool + if global_graph_pool is None: global_graph_pool = current_platform.graph_pool_handle() # TODO: in the future, if we want to use multiple diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 8dd268d06a9..2e7a5f4ee04 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -23,8 +23,8 @@ _T = TypeVar("_T", bound=type[nn.Module]) -def skip_torch_compile(cls: _T) -> _T: - cls._skip_compile_vllm = True +def ignore_torch_compile(cls: _T) -> _T: + cls._ignore_compile_vllm = True return cls @@ -161,7 +161,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): self.do_not_compile = \ vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS - ] or not supports_dynamo() or getattr(self, "_skip_compile_vllm", False) + ] or not supports_dynamo() or getattr(self, "_ignore_compile_vllm", False) if self.do_not_compile: return compilation_counter.num_models_seen += 1 diff --git a/vllm/config.py b/vllm/config.py index 15637370c2e..7fb9d9808cc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -4118,11 +4118,6 @@ class CompilationConfig: """Sizes to capture cudagraph. - None (default): capture sizes are inferred from vllm config. - list[int]: capture sizes are specified as given.""" - cudagraph_share_memory_pool: bool = True - """Whether to share a single global memory pool for each graph capture - When CUDA graphs are not replayed in the same order they are captured, - e.g. when compiling multiple modules in a model and modules take different - input shapes, it is unsafe to share memory across graph captures.""" cudagraph_copy_inputs: bool = False """Whether to copy input tensors for cudagraph. If the caller can guarantee that the same input buffers diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 723c7b4e7bf..929c245c71d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -317,9 +317,11 @@ def __init__( # from the KV cache of `shared_kv_cache_layers[layer_name]`. 
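# ---------------------------------------------------------------------------
# NOTE: standalone illustrative sketch, not part of this patch series. It
# shows, in heavily simplified form, how the ignore_torch_compile decorator
# interacts with a compiled parent class: the decorator only sets a class
# attribute, and the compile wrapper checks that attribute before compiling.
# CompiledBase / Child are hypothetical classes; the real check also looks at
# the compilation level and dynamo support inside support_torch_compile.
from torch import nn

def ignore_torch_compile(cls):
    cls._ignore_compile_vllm = True
    return cls

class CompiledBase(nn.Module):
    def __init__(self):
        super().__init__()
        # Simplified stand-in for the check in support_torch_compile.
        self.do_not_compile = getattr(self, "_ignore_compile_vllm", False)

@ignore_torch_compile
class Child(CompiledBase):
    # Inherits a compiled parent, but its own forward is not compiled.
    # Submodules decorated with support_torch_compile are still compiled.
    pass

assert Child().do_not_compile is True
assert CompiledBase().do_not_compile is False
# ---------------------------------------------------------------------------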
self.shared_kv_cache_layers: dict[str, str] = {} - self.decode_indices = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=self.device) + self.decode_indices = None + if self.cache_config.kv_sharing_skip_prefill: + self.decode_indices = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ @@ -583,7 +585,7 @@ def _calc_decode_indices(self, logits_indices: torch.Tensor): """ Pads logits_indices to align with CUDA graph capture sizes """ - if not self.cache_config.kv_sharing_skip_prefill: + if self.decode_indices is None: return None num_decode_reqs = 0 From d20f5cc0bc3cfa22035001ec4bcacbbc76d5f9ea Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Fri, 11 Jul 2025 14:52:40 -0700 Subject: [PATCH 11/13] Rename decode_indices -> generation_indices Signed-off-by: Yong Hoon Shin --- tests/v1/e2e/test_kv_sharing_skip_prefill.py | 35 +++++++++--------- vllm/compilation/decorators.py | 14 +++++++- vllm/forward_context.py | 13 ++++--- vllm/v1/attention/backends/flash_attn.py | 12 +++---- vllm/v1/attention/backends/utils.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 38 ++++++++++---------- 6 files changed, 65 insertions(+), 49 deletions(-) diff --git a/tests/v1/e2e/test_kv_sharing_skip_prefill.py b/tests/v1/e2e/test_kv_sharing_skip_prefill.py index 907758c6b9a..ccc95fab2ad 100644 --- a/tests/v1/e2e/test_kv_sharing_skip_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_skip_prefill.py @@ -200,10 +200,9 @@ def __init__(self, self.residual = torch.zeros((max_num_tokens, hidden_size), dtype=dtype, device=device) - self.hidden_states = torch.zeros( - (max_num_tokens, hidden_size), - dtype=dtype, - device=device) + self.hidden_states = torch.zeros((max_num_tokens, hidden_size), + dtype=dtype, + device=device) def forward( self, @@ -225,21 +224,21 @@ def forward( self.hidden_states[:num_input_tokens], ) - decode_indices = get_forward_context().decode_indices - if decode_indices is None: - decode_indices = torch.arange(positions.size(0), - device=positions.device) + generation_indices = get_forward_context().generation_indices + if generation_indices is None: + generation_indices = torch.arange(positions.size(0), + device=positions.device) - num_decodes = decode_indices.shape[0] + num_decodes = generation_indices.shape[0] assert num_decodes >= 1 assert first_residual is not None # CUDA graph expects static tensor addresses # Copy output of first layer group to second layer group - self.residual[:num_decodes].copy_(first_residual[decode_indices]) + self.residual[:num_decodes].copy_(first_residual[generation_indices]) self.hidden_states[:num_decodes].copy_( - first_hidden_states[decode_indices]) - positions[:num_decodes].copy_(positions[decode_indices]) + first_hidden_states[generation_indices]) + positions[:num_decodes].copy_(positions[generation_indices]) second_hidden_states, second_residual = self.second_layer_group( positions[:num_decodes], @@ -247,18 +246,18 @@ def forward( self.residual[:num_decodes], ) - # NOTE(sarckk): Due to cudagraph padding, decode_indices may have + # NOTE(sarckk): Due to cudagraph padding, generation_indices may have # trailing repeated indices. Attention output is only valid at the # last index in this case. 
- last_index_mask = decode_indices == decode_indices[-1] + last_index_mask = generation_indices == generation_indices[-1] second_hidden_states[last_index_mask] = second_hidden_states[-1].clone( ) second_residual[last_index_mask] = second_residual[-1].clone() # Merge results back - first_hidden_states[decode_indices] = second_hidden_states + first_hidden_states[generation_indices] = second_hidden_states if first_residual is not None: - first_residual[decode_indices] = second_residual + first_residual[generation_indices] = second_residual hidden_states, _ = self.norm(first_hidden_states, first_residual) return hidden_states @@ -353,8 +352,8 @@ def test_kv_sharing_skip_prefill( ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM) sampling_params = SamplingParams(temperature=0.0, max_tokens=100) compilation_config = CompilationConfig( - level=CompilationLevel.PIECEWISE - if not enforce_eager else CompilationLevel.NO_COMPILATION) + level=CompilationLevel. + PIECEWISE if not enforce_eager else CompilationLevel.NO_COMPILATION) with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 2e7a5f4ee04..e165a6d1f40 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -24,6 +24,17 @@ def ignore_torch_compile(cls: _T) -> _T: + """ + A decorator to ignore support_torch_compile decorator + on the class. This is useful when a parent class has + a support_torch_compile decorator, but we don't want to + compile the class `cls` that inherits the parent class. + + This only ignores compiling the forward of the class the + decorator is applied to. If the class has one or more submodules + that have support_torch_compile decorator applied, compile will + not be ignored for those submodules. + """ cls._ignore_compile_vllm = True return cls @@ -161,7 +172,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): self.do_not_compile = \ vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS - ] or not supports_dynamo() or getattr(self, "_ignore_compile_vllm", False) + ] or not supports_dynamo() or getattr( + self, "_ignore_compile_vllm", False) if self.do_not_compile: return compilation_counter.num_models_seen += 1 diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 57a47b48ace..0b97718740e 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -96,8 +96,13 @@ class ForwardContext: dp_metadata: Optional[DPMetadata] = None skip_cuda_graphs: bool = False - decode_indices: Optional[torch.Tensor] = None - """indices used for decoding""" + generation_indices: Optional[torch.Tensor] = None + """ + Indices of tokens used for sampling output tokens. + Includes the last prefill token and all decode tokens. + Given N prompt tokens, the first N-1 tokens are not included as + they are not used to sample tokens for generation. + """ _forward_context: Optional[ForwardContext] = None @@ -119,7 +124,7 @@ def set_forward_context( num_tokens: Optional[int] = None, num_tokens_across_dp: Optional[torch.Tensor] = None, skip_cuda_graphs: bool = False, - decode_indices: Optional[torch.Tensor] = None, + generation_indices: Optional[torch.Tensor] = None, ): """A context manager that stores the current forward context, can be attention metadata, etc. 
@@ -145,7 +150,7 @@ def set_forward_context( attn_metadata=attn_metadata, dp_metadata=dp_metadata, skip_cuda_graphs=skip_cuda_graphs, - decode_indices=decode_indices, + generation_indices=generation_indices, ) try: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 3a4aa16e795..ea8f3322faa 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -217,16 +217,16 @@ def build_skip_prefill( num_reqs = common_attn_metadata.num_reqs query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - decode_indices = common_attn_metadata.decode_indices + generation_indices = common_attn_metadata.generation_indices # Example inputs # num_reqs: 3 - # decode_indices: [14, 18, 19, 27] + # generation_indices: [14, 18, 19, 27] # query_start_loc: [0, 15, 20, 28] # seq_lens: [41, 31, 40] # Find how many decode indices belong to each request # request_ids: [0, 1, 1, 2] - request_ids = torch.bucketize(decode_indices, + request_ids = torch.bucketize(generation_indices, query_start_loc[1:], right=True) @@ -234,7 +234,7 @@ def build_skip_prefill( # num_decode_tokens: [1, 2, 1] num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) - # Calculate new query_start_loc with tokens in decode_indices + # Calculate new query_start_loc with tokens in generation_indices # decode_query_start_loc: [0, 1, 3, 4] decode_query_start_loc = torch.empty(num_reqs + 1, device=query_start_loc.device, @@ -254,7 +254,7 @@ def build_skip_prefill( num_actual_tokens=total_num_decode_tokens, max_query_len=decode_max_query_len, # Set to None so we don't recurse again - decode_indices=None, + generation_indices=None, ) metadata = self.build( common_prefix_len=common_prefix_len, @@ -268,7 +268,7 @@ def build( common_attn_metadata: CommonAttentionMetadata, ) -> FlashAttentionMetadata: prefill_skipped_attn_metadata = None - if common_attn_metadata.decode_indices is not None: + if common_attn_metadata.generation_indices is not None: # NOTE(sarckk): attention metadata for partial prefill skip case # needs to be built first, otherwise the line below # block_table.slot_mapping[num_actual_tokens:].fill_(-1) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 19ae0709631..1d6569e2288 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -48,7 +48,7 @@ class CommonAttentionMetadata: max_query_len: int """Longest query in batch""" - decode_indices: Optional[torch.Tensor] = None + generation_indices: Optional[torch.Tensor] = None """indices used for decoding""" query_start_loc_np: Optional[np.ndarray] = None diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 929c245c71d..9754d1c3f61 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -317,11 +317,11 @@ def __init__( # from the KV cache of `shared_kv_cache_layers[layer_name]`. 
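# ---------------------------------------------------------------------------
# NOTE: standalone illustrative sketch, not part of this patch series. It
# spells out what the generation_indices introduced above contain: for each
# request, only the position used to sample the next token (the last prefill
# token or the decode token); the first N-1 prompt tokens are excluded. The
# batch layout below is an assumed example.
import torch

# Three requests flattened into one token batch:
#   req 0: 15 prefill tokens -> only position 14 samples a token
#   req 1:  5 prefill tokens -> only position 19 samples a token
#   req 2:  1 decode token   -> position 20 samples a token
query_start_loc = torch.tensor([0, 15, 20, 21])

# Without speculative decoding, each request's sampling position is its last
# scheduled token, i.e. query_start_loc[1:] - 1.
generation_indices = query_start_loc[1:] - 1
assert generation_indices.tolist() == [14, 19, 20]
# ---------------------------------------------------------------------------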
self.shared_kv_cache_layers: dict[str, str] = {} - self.decode_indices = None + self.generation_indices = None if self.cache_config.kv_sharing_skip_prefill: - self.decode_indices = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=self.device) + self.generation_indices = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ @@ -581,11 +581,11 @@ def _get_cumsum_and_arange( return cu_num_tokens, arange - def _calc_decode_indices(self, logits_indices: torch.Tensor): + def _calc_generation_indices(self, logits_indices: torch.Tensor): """ Pads logits_indices to align with CUDA graph capture sizes """ - if self.decode_indices is None: + if self.generation_indices is None: return None num_decode_reqs = 0 @@ -604,16 +604,16 @@ def _calc_decode_indices(self, logits_indices: torch.Tensor): # indices for partial requests though we do not sample any token # from these partial requests, for simplicity. In the future, we # can calculate the 'true' decode indices based on logits_indices - self.decode_indices[:num_decodes].copy_(logits_indices) + self.generation_indices[:num_decodes].copy_(logits_indices) # pad with last idx instead of zero - self.decode_indices[num_decodes:].fill_(logits_indices[-1].item()) + self.generation_indices[num_decodes:].fill_(logits_indices[-1].item()) if (self.use_cuda_graph and num_decodes <= self.cudagraph_batch_sizes[-1]): num_decodes_padded = self.vllm_config.pad_for_cudagraph( num_decodes) else: num_decodes_padded = num_decodes - return self.decode_indices[:num_decodes_padded] + return self.generation_indices[:num_decodes_padded] def _prepare_inputs( self, @@ -765,7 +765,7 @@ def _prepare_inputs( num_draft_tokens, cu_num_tokens) logits_indices = spec_decode_metadata.logits_indices - decode_indices = self._calc_decode_indices(logits_indices) + generation_indices = self._calc_generation_indices(logits_indices) common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -774,7 +774,7 @@ def _prepare_inputs( num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, - decode_indices=decode_indices, + generation_indices=generation_indices, ) attn_metadata: dict[str, Any] = {} @@ -812,7 +812,7 @@ def _prepare_inputs( self.set_active_loras(self.input_batch, num_scheduled_tokens) return (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, num_scheduled_tokens, decode_indices) + spec_decode_metadata, num_scheduled_tokens, generation_indices) def _compute_cascade_attn_prefix_len( self, @@ -1334,7 +1334,7 @@ def execute_model( # Prepare the decoder inputs. 
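# ---------------------------------------------------------------------------
# NOTE: standalone illustrative sketch, not part of this patch series. It
# isolates the early-exit check in _calc_generation_indices: when every
# request has already finished its prompt (a pure decode batch), every
# scheduled token is a generation token, so no decode-only metadata needs to
# be built. The list names are stand-ins for the InputBatch fields used above.
num_computed_tokens = [41, 31, 40]   # tokens already in the KV cache
num_prompt_tokens = [35, 30, 12]     # prompt lengths

num_decode_reqs = sum(
    1 for computed, prompt in zip(num_computed_tokens, num_prompt_tokens)
    if computed >= prompt)

if num_decode_reqs == len(num_prompt_tokens):
    # All requests are decoding: skip the generation-index computation.
    generation_indices = None
# ---------------------------------------------------------------------------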
(attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, - decode_indices) = (self._prepare_inputs(scheduler_output)) + generation_indices) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1410,7 +1410,7 @@ def execute_model( num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, skip_cuda_graphs=skip_cuda_graphs, - decode_indices=decode_indices): + generation_indices=generation_indices): self.maybe_setup_kv_connector(scheduler_output) model_output = self.model( @@ -2003,9 +2003,9 @@ def _dummy_run( dtype=np.int32) attn_metadata: Optional[dict[str, Any]] = None - decode_indices = torch.arange(num_tokens, - device=self.device, - dtype=torch.int) + generation_indices = torch.arange(num_tokens, + device=self.device, + dtype=torch.int) if capture_attn_cudagraph: attn_metadata = {} @@ -2026,7 +2026,7 @@ def _dummy_run( num_reqs=num_reqs, num_actual_tokens=num_tokens, max_query_len=num_tokens, - decode_indices=decode_indices, + generation_indices=generation_indices, ) for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -2070,7 +2070,7 @@ def _dummy_run( self.vllm_config, num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, - decode_indices=decode_indices): + generation_indices=generation_indices): outputs = model( input_ids=input_ids, positions=positions, From dfc9de3a356d930e7eb2d4ac9236b732d8376c13 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Fri, 11 Jul 2025 16:31:13 -0700 Subject: [PATCH 12/13] Remove cudagraph padding for attention layers Signed-off-by: Yong Hoon Shin --- tests/v1/e2e/test_kv_sharing_skip_prefill.py | 49 ++++++++++--------- vllm/forward_context.py | 32 ++++++++---- vllm/v1/attention/backends/flash_attn.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 51 +++++++++++++------- 4 files changed, 82 insertions(+), 54 deletions(-) diff --git a/tests/v1/e2e/test_kv_sharing_skip_prefill.py b/tests/v1/e2e/test_kv_sharing_skip_prefill.py index ccc95fab2ad..7d509f666b4 100644 --- a/tests/v1/e2e/test_kv_sharing_skip_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_skip_prefill.py @@ -224,40 +224,41 @@ def forward( self.hidden_states[:num_input_tokens], ) - generation_indices = get_forward_context().generation_indices - if generation_indices is None: - generation_indices = torch.arange(positions.size(0), - device=positions.device) - - num_decodes = generation_indices.shape[0] - assert num_decodes >= 1 + generation_metadata = get_forward_context().generation_metadata + gen_indices_padded = ( + generation_metadata.generation_indices_padded + if generation_metadata is not None + else torch.arange(num_input_tokens, device=positions.device) + ) + num_gen_tokens_padded = gen_indices_padded.shape[0] assert first_residual is not None # CUDA graph expects static tensor addresses # Copy output of first layer group to second layer group - self.residual[:num_decodes].copy_(first_residual[generation_indices]) - self.hidden_states[:num_decodes].copy_( - first_hidden_states[generation_indices]) - positions[:num_decodes].copy_(positions[generation_indices]) + self.residual[:num_gen_tokens_padded].copy_( + first_residual[gen_indices_padded]) + self.hidden_states[:num_gen_tokens_padded].copy_( + first_hidden_states[gen_indices_padded]) + positions[:num_gen_tokens_padded].copy_(positions[gen_indices_padded]) second_hidden_states, second_residual = self.second_layer_group( - 
positions[:num_decodes], - self.hidden_states[:num_decodes], - self.residual[:num_decodes], + positions[:num_gen_tokens_padded], + self.hidden_states[:num_gen_tokens_padded], + self.residual[:num_gen_tokens_padded], ) - # NOTE(sarckk): Due to cudagraph padding, generation_indices may have - # trailing repeated indices. Attention output is only valid at the - # last index in this case. - last_index_mask = generation_indices == generation_indices[-1] - second_hidden_states[last_index_mask] = second_hidden_states[-1].clone( + # NOTE: we need to pad generation indices for CUDA graph but only the + # first num_gen_tokens positions are actually valid. + num_gen_tokens = ( + generation_metadata.num_generation_tokens + if generation_metadata is not None + else num_gen_tokens_padded ) - second_residual[last_index_mask] = second_residual[-1].clone() - - # Merge results back - first_hidden_states[generation_indices] = second_hidden_states + gen_indices = gen_indices_padded[:num_gen_tokens] + first_hidden_states[ + gen_indices] = second_hidden_states[:num_gen_tokens] if first_residual is not None: - first_residual[generation_indices] = second_residual + first_residual[gen_indices] = second_residual[:num_gen_tokens] hidden_states, _ = self.norm(first_hidden_states, first_residual) return hidden_states diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 0b97718740e..a5733bf0a2e 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -26,6 +26,25 @@ batchsize_forward_time: defaultdict = defaultdict(list) +@dataclass +class GenerationMetadata: + # Set dynamically for each forward pass + num_generation_tokens: int + """ + No. of generation indices without padding + """ + generation_indices_padded: torch.Tensor + """ + Indices of tokens used for sampling output tokens. + Includes the last prefill token and all decode tokens. + Given N prompt tokens, the first N-1 tokens are not included as + they are not used to sample tokens for generation. + """ + + def generation_indices_unpadded(self) -> torch.Tensor: + return self.generation_indices_padded[:self.num_generation_tokens] + + @dataclass class DPMetadata: max_tokens_across_dp_cpu: torch.Tensor @@ -95,14 +114,7 @@ class ForwardContext: # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None skip_cuda_graphs: bool = False - - generation_indices: Optional[torch.Tensor] = None - """ - Indices of tokens used for sampling output tokens. - Includes the last prefill token and all decode tokens. - Given N prompt tokens, the first N-1 tokens are not included as - they are not used to sample tokens for generation. - """ + generation_metadata: Optional[GenerationMetadata] = None _forward_context: Optional[ForwardContext] = None @@ -124,7 +136,7 @@ def set_forward_context( num_tokens: Optional[int] = None, num_tokens_across_dp: Optional[torch.Tensor] = None, skip_cuda_graphs: bool = False, - generation_indices: Optional[torch.Tensor] = None, + generation_metadata: Optional[GenerationMetadata] = None, ): """A context manager that stores the current forward context, can be attention metadata, etc. 
@@ -150,7 +162,7 @@ def set_forward_context( attn_metadata=attn_metadata, dp_metadata=dp_metadata, skip_cuda_graphs=skip_cuda_graphs, - generation_indices=generation_indices, + generation_metadata=generation_metadata, ) try: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ea8f3322faa..e3a4af18eb1 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -242,8 +242,8 @@ def build_skip_prefill( decode_query_start_loc[0] = 0 decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) - decode_max_query_len = num_decode_tokens.max().item() - total_num_decode_tokens = num_decode_tokens.sum().item() + decode_max_query_len = int(num_decode_tokens.max().item()) + total_num_decode_tokens = int(num_decode_tokens.sum().item()) common_attn_metadata = CommonAttentionMetadata( query_start_loc=decode_query_start_loc, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9754d1c3f61..450de51addd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -27,8 +27,8 @@ from vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, graph_capture, is_global_first_rank, prepare_communication_buffer_for_model) -from vllm.forward_context import (DPMetadata, get_forward_context, - set_forward_context) +from vllm.forward_context import (DPMetadata, GenerationMetadata, + get_forward_context, set_forward_context) from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding @@ -581,11 +581,14 @@ def _get_cumsum_and_arange( return cu_num_tokens, arange - def _calc_generation_indices(self, logits_indices: torch.Tensor): + def _calc_generation_metadata( + self, + logits_indices: torch.Tensor) -> Optional[GenerationMetadata]: """ Pads logits_indices to align with CUDA graph capture sizes """ if self.generation_indices is None: + assert not self.cache_config.kv_sharing_skip_prefill return None num_decode_reqs = 0 @@ -599,28 +602,32 @@ def _calc_generation_indices(self, logits_indices: torch.Tensor): # All requests are on decode, skip calculate decode only indices return None - num_decodes = logits_indices.shape[0] + num_gen_tokens = logits_indices.shape[0] # TODO(sarckk): With chunked prefills, logits_indices contains # indices for partial requests though we do not sample any token # from these partial requests, for simplicity. 
In the future, we # can calculate the 'true' decode indices based on logits_indices - self.generation_indices[:num_decodes].copy_(logits_indices) + self.generation_indices[:num_gen_tokens].copy_(logits_indices) # pad with last idx instead of zero - self.generation_indices[num_decodes:].fill_(logits_indices[-1].item()) + self.generation_indices[num_gen_tokens:].fill_( + logits_indices[-1].item()) if (self.use_cuda_graph - and num_decodes <= self.cudagraph_batch_sizes[-1]): - num_decodes_padded = self.vllm_config.pad_for_cudagraph( - num_decodes) + and num_gen_tokens <= self.cudagraph_batch_sizes[-1]): + num_gen_tokens_padded = self.vllm_config.pad_for_cudagraph( + num_gen_tokens) else: - num_decodes_padded = num_decodes - return self.generation_indices[:num_decodes_padded] + num_gen_tokens_padded = num_gen_tokens + + return GenerationMetadata( + num_generation_tokens=num_gen_tokens, + generation_indices_padded=self.generation_indices[:num_gen_tokens_padded]) def _prepare_inputs( self, scheduler_output: "SchedulerOutput", ) -> tuple[dict[str, Any], bool, torch.Tensor, Optional[SpecDecodeMetadata], - np.ndarray, Optional[torch.Tensor]]: + np.ndarray, Optional[GenerationMetadata]]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -765,7 +772,7 @@ def _prepare_inputs( num_draft_tokens, cu_num_tokens) logits_indices = spec_decode_metadata.logits_indices - generation_indices = self._calc_generation_indices(logits_indices) + generation_metadata = self._calc_generation_metadata(logits_indices) common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, @@ -774,7 +781,9 @@ def _prepare_inputs( num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, - generation_indices=generation_indices, + generation_indices=( + generation_metadata.generation_indices_unpadded() + if generation_metadata is not None else None), ) attn_metadata: dict[str, Any] = {} @@ -812,7 +821,8 @@ def _prepare_inputs( self.set_active_loras(self.input_batch, num_scheduled_tokens) return (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, num_scheduled_tokens, generation_indices) + spec_decode_metadata, num_scheduled_tokens, + generation_metadata) def _compute_cascade_attn_prefix_len( self, @@ -1334,7 +1344,7 @@ def execute_model( # Prepare the decoder inputs. 
(attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, - generation_indices) = (self._prepare_inputs(scheduler_output)) + generation_metadata) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1410,7 +1420,7 @@ def execute_model( num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, skip_cuda_graphs=skip_cuda_graphs, - generation_indices=generation_indices): + generation_metadata=generation_metadata): self.maybe_setup_kv_connector(scheduler_output) model_output = self.model( @@ -2065,12 +2075,17 @@ def _dummy_run( intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) + dummy_generation_metadata = GenerationMetadata( + num_generation_tokens=num_tokens, + generation_indices_padded=generation_indices, + ) + with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, - generation_indices=generation_indices): + generation_metadata=dummy_generation_metadata): outputs = model( input_ids=input_ids, positions=positions, From a945f6f38d7839b8c133b93e03edee38fc7d5e99 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Sun, 13 Jul 2025 18:09:32 -0700 Subject: [PATCH 13/13] Build truncated prefill attn metadata outside of backend impl Signed-off-by: Yong Hoon Shin --- ...y => test_kv_sharing_truncated_prefill.py} | 74 +++---- vllm/config.py | 2 +- vllm/engine/arg_utils.py | 11 +- vllm/entrypoints/llm.py | 5 +- vllm/forward_context.py | 19 +- vllm/v1/attention/backends/flash_attn.py | 73 ------- vllm/v1/attention/backends/utils.py | 3 - vllm/v1/kv_cache_interface.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 195 +++++++++++------- vllm/v1/worker/tpu_model_runner.py | 3 + vllm/v1/worker/utils.py | 11 + 11 files changed, 202 insertions(+), 198 deletions(-) rename tests/v1/e2e/{test_kv_sharing_skip_prefill.py => test_kv_sharing_truncated_prefill.py} (85%) diff --git a/tests/v1/e2e/test_kv_sharing_skip_prefill.py b/tests/v1/e2e/test_kv_sharing_truncated_prefill.py similarity index 85% rename from tests/v1/e2e/test_kv_sharing_skip_prefill.py rename to tests/v1/e2e/test_kv_sharing_truncated_prefill.py index 7d509f666b4..d2052b6003b 100644 --- a/tests/v1/e2e/test_kv_sharing_skip_prefill.py +++ b/tests/v1/e2e/test_kv_sharing_truncated_prefill.py @@ -53,8 +53,7 @@ def __init__( kv_sharing_target_layer_name = None if layer_idx >= START_KV_SHARING_LAYER: - # re-use KV cache from first 5 layers - target_layer_idx = layer_idx % 5 + target_layer_idx = START_KV_SHARING_LAYER - 1 kv_sharing_target_layer_name = f"{attn_prefix}.attn".replace( str(layer_idx), str(target_layer_idx)) @@ -219,48 +218,49 @@ def forward( num_input_tokens = input_ids.size(0) self.hidden_states[:num_input_tokens].copy_(hidden_states) - first_hidden_states, first_residual = self.first_layer_group( + hidden_states, residual = self.first_layer_group( positions, self.hidden_states[:num_input_tokens], ) - generation_metadata = get_forward_context().generation_metadata - gen_indices_padded = ( - generation_metadata.generation_indices_padded - if generation_metadata is not None - else torch.arange(num_input_tokens, device=positions.device) - ) - num_gen_tokens_padded = gen_indices_padded.shape[0] - assert first_residual is not None - - # CUDA graph expects static tensor addresses - # Copy output of 
first layer group to second layer group - self.residual[:num_gen_tokens_padded].copy_( - first_residual[gen_indices_padded]) - self.hidden_states[:num_gen_tokens_padded].copy_( - first_hidden_states[gen_indices_padded]) - positions[:num_gen_tokens_padded].copy_(positions[gen_indices_padded]) + truncated_prefill_metadata = \ + get_forward_context().truncated_prefill_metadata + if truncated_prefill_metadata is not None: + gen_indices_padded = \ + truncated_prefill_metadata.generation_indices_padded + num_tokens = gen_indices_padded.shape[0] + # CUDA graph expects static tensor addresses + # Copy output of first layer group to second layer group + # TODO(sarckk): Move logic to @support_torch_compile + self.residual[:num_tokens].copy_(residual[gen_indices_padded]) + self.hidden_states[:num_tokens].copy_( + hidden_states[gen_indices_padded]) + positions[:num_tokens].copy_(positions[gen_indices_padded]) + else: + num_tokens = num_input_tokens + self.residual[:num_tokens].copy_(residual) + self.hidden_states[:num_tokens].copy_(hidden_states) second_hidden_states, second_residual = self.second_layer_group( - positions[:num_gen_tokens_padded], - self.hidden_states[:num_gen_tokens_padded], - self.residual[:num_gen_tokens_padded], + positions[:num_tokens], + self.hidden_states[:num_tokens], + self.residual[:num_tokens], ) - # NOTE: we need to pad generation indices for CUDA graph but only the - # first num_gen_tokens positions are actually valid. - num_gen_tokens = ( - generation_metadata.num_generation_tokens - if generation_metadata is not None - else num_gen_tokens_padded - ) - gen_indices = gen_indices_padded[:num_gen_tokens] - first_hidden_states[ - gen_indices] = second_hidden_states[:num_gen_tokens] - if first_residual is not None: - first_residual[gen_indices] = second_residual[:num_gen_tokens] + if truncated_prefill_metadata is not None: + gen_indices_padded =\ + truncated_prefill_metadata.generation_indices_padded + # NOTE: we need to pad generation indices for CUDA graph + # but only the first num_gen_indices positions are actually valid. + num_gen_indices = truncated_prefill_metadata.num_generation_indices + gen_indices = gen_indices_padded[:num_gen_indices] + hidden_states[gen_indices] = second_hidden_states[:num_gen_indices] + residual[gen_indices] = second_residual[:num_gen_indices] + else: + hidden_states = second_hidden_states + residual = second_residual - hidden_states, _ = self.norm(first_hidden_states, first_residual) + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -345,7 +345,7 @@ def test_prompts(): @fork_new_process_for_each_test @pytest.mark.parametrize("enforce_eager", [True, False]) -def test_kv_sharing_skip_prefill( +def test_kv_sharing_truncated_prefill( monkeypatch: pytest.MonkeyPatch, enforce_eager: bool, test_prompts: list[str], @@ -373,7 +373,7 @@ def test_kv_sharing_skip_prefill( llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", enforce_eager=enforce_eager, compilation_config=compilation_config, - kv_sharing_skip_prefill=True) + enable_kv_sharing_truncated_prefill=True) optimized_responses = llm.generate(test_prompts, sampling_params) misses = 0 diff --git a/vllm/config.py b/vllm/config.py index 7fb9d9808cc..1fe24bf01c3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1564,7 +1564,7 @@ class CacheConfig: checkpoint if available. 
Otherwise, the scales will default to 1.0.""" cpu_kvcache_space_bytes: Optional[int] = None """(CPU backend only) CPU key-value cache space.""" - kv_sharing_skip_prefill: bool = False + enable_kv_sharing_truncated_prefill: bool = False """Skip prefill for tokens where applicable in KV cache sharing scenarios where required key/value tensors have been populated in earlier KV sharing target layers.""" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d06077d43e3..9fc0221d705 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -472,7 +472,8 @@ class EngineArgs: override_attention_dtype: str = ModelConfig.override_attention_dtype calculate_kv_scales: bool = CacheConfig.calculate_kv_scales - kv_sharing_skip_prefill: bool = CacheConfig.kv_sharing_skip_prefill + enable_kv_sharing_truncated_prefill: bool = \ + CacheConfig.enable_kv_sharing_truncated_prefill additional_config: dict[str, Any] = \ get_field(VllmConfig, "additional_config") @@ -749,8 +750,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **cache_kwargs["cpu_offload_gb"]) cache_group.add_argument("--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]) - cache_group.add_argument("--kv-sharing-skip-prefill", - **cache_kwargs["kv_sharing_skip_prefill"]) + cache_group.add_argument( + "--enable-kv-sharing-truncated-prefill", + **cache_kwargs["enable_kv_sharing_truncated_prefill"]) # Tokenizer arguments tokenizer_kwargs = get_kwargs(TokenizerPoolConfig) @@ -1161,7 +1163,8 @@ def create_engine_config( prefix_caching_hash_algo=self.prefix_caching_hash_algo, cpu_offload_gb=self.cpu_offload_gb, calculate_kv_scales=self.calculate_kv_scales, - kv_sharing_skip_prefill=self.kv_sharing_skip_prefill, + enable_kv_sharing_truncated_prefill=self. + enable_kv_sharing_truncated_prefill, ) # Get the current placement group if Ray is initialized and diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index bda905edd54..fed24840609 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -194,7 +194,7 @@ def __init__( override_pooler_config: Optional[PoolerConfig] = None, compilation_config: Optional[Union[int, dict[str, Any], CompilationConfig]] = None, - kv_sharing_skip_prefill: bool = False, + enable_kv_sharing_truncated_prefill: bool = False, **kwargs, ) -> None: """LLM constructor.""" @@ -268,7 +268,8 @@ def __init__( mm_processor_kwargs=mm_processor_kwargs, override_pooler_config=override_pooler_config, compilation_config=compilation_config_instance, - kv_sharing_skip_prefill=kv_sharing_skip_prefill, + enable_kv_sharing_truncated_prefill= + enable_kv_sharing_truncated_prefill, **kwargs, ) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index a5733bf0a2e..dbc6e35560d 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -27,11 +27,12 @@ @dataclass -class GenerationMetadata: - # Set dynamically for each forward pass - num_generation_tokens: int +class TruncatedPrefillMetadata: + num_generation_indices: int """ - No. of generation indices without padding + No. of generation indices without CUDA graph padding. + + Set dynamically for each forward pass. """ generation_indices_padded: torch.Tensor """ @@ -39,10 +40,12 @@ class GenerationMetadata: Includes the last prefill token and all decode tokens. Given N prompt tokens, the first N-1 tokens are not included as they are not used to sample tokens for generation. + + Set dynamically for each forward pass. 
""" def generation_indices_unpadded(self) -> torch.Tensor: - return self.generation_indices_padded[:self.num_generation_tokens] + return self.generation_indices_padded[:self.num_generation_indices] @dataclass @@ -114,7 +117,7 @@ class ForwardContext: # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None skip_cuda_graphs: bool = False - generation_metadata: Optional[GenerationMetadata] = None + truncated_prefill_metadata: Optional[TruncatedPrefillMetadata] = None _forward_context: Optional[ForwardContext] = None @@ -136,7 +139,7 @@ def set_forward_context( num_tokens: Optional[int] = None, num_tokens_across_dp: Optional[torch.Tensor] = None, skip_cuda_graphs: bool = False, - generation_metadata: Optional[GenerationMetadata] = None, + truncated_prefill_metadata: Optional[TruncatedPrefillMetadata] = None, ): """A context manager that stores the current forward context, can be attention metadata, etc. @@ -162,7 +165,7 @@ def set_forward_context( attn_metadata=attn_metadata, dp_metadata=dp_metadata, skip_cuda_graphs=skip_cuda_graphs, - generation_metadata=generation_metadata, + truncated_prefill_metadata=truncated_prefill_metadata, ) try: diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index e3a4af18eb1..80e769e4df9 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -142,8 +142,6 @@ class LocalAttentionMetadata: local_attn_metadata: Optional[LocalAttentionMetadata] = None - prefill_skipped_attn_metadata: Optional["FlashAttentionMetadata"] = None - def _get_sliding_window_configs( vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]: @@ -209,74 +207,11 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, # populated on first build() call. 
self.aot_sliding_window: Optional[tuple[int, int]] = None - def build_skip_prefill( - self, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - ) -> FlashAttentionMetadata: - num_reqs = common_attn_metadata.num_reqs - query_start_loc = common_attn_metadata.query_start_loc - seq_lens = common_attn_metadata.seq_lens - generation_indices = common_attn_metadata.generation_indices - # Example inputs - # num_reqs: 3 - # generation_indices: [14, 18, 19, 27] - # query_start_loc: [0, 15, 20, 28] - # seq_lens: [41, 31, 40] - - # Find how many decode indices belong to each request - # request_ids: [0, 1, 1, 2] - request_ids = torch.bucketize(generation_indices, - query_start_loc[1:], - right=True) - - # Figure out how many tokens are in each request - # num_decode_tokens: [1, 2, 1] - num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) - - # Calculate new query_start_loc with tokens in generation_indices - # decode_query_start_loc: [0, 1, 3, 4] - decode_query_start_loc = torch.empty(num_reqs + 1, - device=query_start_loc.device, - dtype=query_start_loc.dtype) - - decode_query_start_loc[0] = 0 - decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) - decode_max_query_len = int(num_decode_tokens.max().item()) - total_num_decode_tokens = int(num_decode_tokens.sum().item()) - - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=decode_query_start_loc, - # TODO(sarckk): optimize - query_start_loc_np=decode_query_start_loc.cpu().numpy(), - seq_lens=seq_lens, - num_reqs=num_reqs, - num_actual_tokens=total_num_decode_tokens, - max_query_len=decode_max_query_len, - # Set to None so we don't recurse again - generation_indices=None, - ) - metadata = self.build( - common_prefix_len=common_prefix_len, - common_attn_metadata=common_attn_metadata, - ) - return metadata - def build( self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, ) -> FlashAttentionMetadata: - prefill_skipped_attn_metadata = None - if common_attn_metadata.generation_indices is not None: - # NOTE(sarckk): attention metadata for partial prefill skip case - # needs to be built first, otherwise the line below - # block_table.slot_mapping[num_actual_tokens:].fill_(-1) - # will override the correct slot mapping - prefill_skipped_attn_metadata = self.build_skip_prefill( - common_prefix_len=0, # disable cascade attention - common_attn_metadata=common_attn_metadata) - num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len @@ -444,7 +379,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, local_attn_metadata=local_attn_metadata, prefix_scheduler_metadata=prefix_scheduler_metadata, max_num_splits=max_num_splits, - prefill_skipped_attn_metadata=prefill_skipped_attn_metadata, ) return attn_metadata @@ -511,8 +445,6 @@ def __init__( raise NotImplementedError( "FlashAttention does not support fp8 kv-cache on this device.") - self.kv_sharing_skip_prefill = False - def forward( self, layer: torch.nn.Module, @@ -549,11 +481,6 @@ def forward( # Profiling run. return output - if (self.kv_sharing_target_layer_name is not None - and self.kv_sharing_skip_prefill - and attn_metadata.prefill_skipped_attn_metadata is not None): - attn_metadata = attn_metadata.prefill_skipped_attn_metadata - # IMPORTANT! # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in # eager-mode PyTorch. 
Thus, we need to be careful about any CPU overhead diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 1d6569e2288..e2f3d4af27b 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -48,9 +48,6 @@ class CommonAttentionMetadata: max_query_len: int """Longest query in batch""" - generation_indices: Optional[torch.Tensor] = None - """indices used for decoding""" - query_start_loc_np: Optional[np.ndarray] = None """(batch_size + 1,), numpy equivalent of query_start_loc on the CPU""" diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 43456a987de..ca732ac56ae 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy -from dataclasses import dataclass +from dataclasses import dataclass, field from math import prod from typing import Optional @@ -202,6 +202,8 @@ class KVCacheGroupSpec: layer_names: list[str] # The KV cache spec of this manager layer kv_cache_spec: KVCacheSpec + # The names of model layers for which prefill can be truncated + truncated_prefill_eligible_layers: list[str] = field(default_factory=list) @dataclass diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 450de51addd..bedab4539a6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -27,7 +27,7 @@ from vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, graph_capture, is_global_first_rank, prepare_communication_buffer_for_model) -from vllm.forward_context import (DPMetadata, GenerationMetadata, +from vllm.forward_context import (DPMetadata, TruncatedPrefillMetadata, get_forward_context, set_forward_context) from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase @@ -45,7 +45,6 @@ GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up) -from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) @@ -318,7 +317,7 @@ def __init__( self.shared_kv_cache_layers: dict[str, str] = {} self.generation_indices = None - if self.cache_config.kv_sharing_skip_prefill: + if self.cache_config.enable_kv_sharing_truncated_prefill: self.generation_indices = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=self.device) @@ -581,15 +580,9 @@ def _get_cumsum_and_arange( return cu_num_tokens, arange - def _calc_generation_metadata( - self, - logits_indices: torch.Tensor) -> Optional[GenerationMetadata]: - """ - Pads logits_indices to align with CUDA graph capture sizes - """ - if self.generation_indices is None: - assert not self.cache_config.kv_sharing_skip_prefill - return None + def _truncate_prefill(self) -> bool: + if not self.cache_config.enable_kv_sharing_truncated_prefill: + return False num_decode_reqs = 0 for req_index in range(self.input_batch.num_reqs): @@ -599,35 +592,67 @@ def _calc_generation_metadata( num_decode_reqs += 1 if self.input_batch.num_reqs == num_decode_reqs: - # All requests are on decode, skip calculate decode only indices - return None + # All requests on decode, no need to truncate prefill + return False - num_gen_tokens = logits_indices.shape[0] - # TODO(sarckk): With chunked prefills, logits_indices contains - 
# indices for partial requests though we do not sample any token - # from these partial requests, for simplicity. In the future, we - # can calculate the 'true' decode indices based on logits_indices - self.generation_indices[:num_gen_tokens].copy_(logits_indices) - # pad with last idx instead of zero - self.generation_indices[num_gen_tokens:].fill_( - logits_indices[-1].item()) - if (self.use_cuda_graph - and num_gen_tokens <= self.cudagraph_batch_sizes[-1]): - num_gen_tokens_padded = self.vllm_config.pad_for_cudagraph( - num_gen_tokens) - else: - num_gen_tokens_padded = num_gen_tokens + for kv_cache_group_spec in self.kv_cache_config.kv_cache_groups: + if kv_cache_group_spec.truncated_prefill_eligible_layers: + return True + + return False - return GenerationMetadata( - num_generation_tokens=num_gen_tokens, - generation_indices_padded=self.generation_indices[:num_gen_tokens_padded]) + def _calc_truncated_prefill_attn_metadata( + self, + logits_indices: torch.Tensor, + common_attn_metadata: CommonAttentionMetadata, + ) -> CommonAttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + # Example inputs + # num_reqs: 3 + # generation_indices: [14, 18, 19, 27] + # query_start_loc: [0, 15, 20, 28] + # seq_lens: [41, 31, 40] + + # Find how many decode indices belong to each request + # request_ids: [0, 1, 1, 2] + request_ids = torch.bucketize(logits_indices, + query_start_loc[1:], + right=True) + + # Figure out how many tokens are in each request + # num_decode_tokens: [1, 2, 1] + num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) + + # Calculate new query_start_loc with tokens in generation_indices + # decode_query_start_loc: [0, 1, 3, 4] + decode_query_start_loc = torch.empty(num_reqs + 1, + device=query_start_loc.device, + dtype=query_start_loc.dtype) + + decode_query_start_loc[0] = 0 + decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) + decode_max_query_len = int(num_decode_tokens.max().item()) + total_num_decode_tokens = int(num_decode_tokens.sum().item()) + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=decode_query_start_loc, + # TODO(sarckk): optimize + query_start_loc_np=decode_query_start_loc.cpu().numpy(), + seq_lens=seq_lens, + num_reqs=num_reqs, + num_actual_tokens=total_num_decode_tokens, + max_query_len=decode_max_query_len, + ) + return common_attn_metadata def _prepare_inputs( self, scheduler_output: "SchedulerOutput", ) -> tuple[dict[str, Any], bool, torch.Tensor, Optional[SpecDecodeMetadata], - np.ndarray, Optional[GenerationMetadata]]: + np.ndarray, Optional[TruncatedPrefillMetadata]]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -772,8 +797,6 @@ def _prepare_inputs( num_draft_tokens, cu_num_tokens) logits_indices = spec_decode_metadata.logits_indices - generation_metadata = self._calc_generation_metadata(logits_indices) - common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, query_start_loc_np=query_start_loc_np, @@ -781,11 +804,43 @@ def _prepare_inputs( num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, - generation_indices=( - generation_metadata.generation_indices_unpadded() - if generation_metadata is not None else None), ) + truncate_prefill = self._truncate_prefill() + truncated_prefill_metadata = None + truncated_prefill_common_attn_metadata = None + + if truncate_prefill: + assert 
self.generation_indices is not None + # TODO(sarckk): With chunked prefills, logits_indices contains + # indices for partial requests though we do not sample any token + # from these partial requests, for simplicity. In the future, we + # can calculate the 'true' decode indices based on logits_indices, + # hence the distinction from logits_indices + num_generation_indices = logits_indices.shape[0] + self.generation_indices[:num_generation_indices].copy_( + logits_indices) + # pad with last idx instead of zero + self.generation_indices[num_generation_indices:].fill_( + logits_indices[-1].item()) + if (self.use_cuda_graph and num_generation_indices + <= self.cudagraph_batch_sizes[-1]): + num_gen_indices_padded = self.vllm_config.pad_for_cudagraph( + num_generation_indices) + else: + num_gen_indices_padded = num_generation_indices + + truncated_prefill_metadata = TruncatedPrefillMetadata( + num_generation_indices=num_generation_indices, + generation_indices_padded=( + self.generation_indices[:num_gen_indices_padded])) + truncated_prefill_common_attn_metadata =\ + self._calc_truncated_prefill_attn_metadata( + # Use generation indices without CUDA graph padding for attn + truncated_prefill_metadata.generation_indices_unpadded(), + common_attn_metadata, + ) + attn_metadata: dict[str, Any] = {} # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. @@ -804,6 +859,18 @@ def _prepare_inputs( builder, ) + common_attn_metadata = common_attn_metadata + truncated_prefill_attn_metadata_i = None + if (truncated_prefill_common_attn_metadata is not None + and kv_cache_group_spec.truncated_prefill_eligible_layers): + truncated_prefill_attn_metadata_i = ( + builder.build( + # TODO(sarckk): Cascade attn for truncated prefill + common_prefix_len=0, + common_attn_metadata=( + truncated_prefill_common_attn_metadata), + )) + attn_metadata_i = (builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, @@ -812,6 +879,14 @@ def _prepare_inputs( for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i + if (kv_cache_group_spec.truncated_prefill_eligible_layers + is not None + and truncated_prefill_attn_metadata_i is not None): + for layer_name in \ + kv_cache_group_spec.truncated_prefill_eligible_layers: + attn_metadata[layer_name] =\ + truncated_prefill_attn_metadata_i + attention_cuda_graphs = all( b.can_run_in_cudagraph(common_attn_metadata) for b in self.attn_metadata_builders) @@ -822,7 +897,7 @@ def _prepare_inputs( return (attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens, - generation_metadata) + truncated_prefill_metadata) def _compute_cascade_attn_prefix_len( self, @@ -1344,7 +1419,7 @@ def execute_model( # Prepare the decoder inputs. (attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, - generation_metadata) = (self._prepare_inputs(scheduler_output)) + truncated_prefill_metadata) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1415,12 +1490,13 @@ def execute_model( # Run the model. # Use persistent buffers for CUDA graphs. 
- with set_forward_context(attn_metadata, - self.vllm_config, - num_tokens=num_input_tokens, - num_tokens_across_dp=num_tokens_across_dp, - skip_cuda_graphs=skip_cuda_graphs, - generation_metadata=generation_metadata): + with set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_input_tokens, + num_tokens_across_dp=num_tokens_across_dp, + skip_cuda_graphs=skip_cuda_graphs, + truncated_prefill_metadata=truncated_prefill_metadata): self.maybe_setup_kv_connector(scheduler_output) model_output = self.model( @@ -2013,9 +2089,6 @@ def _dummy_run( dtype=np.int32) attn_metadata: Optional[dict[str, Any]] = None - generation_indices = torch.arange(num_tokens, - device=self.device, - dtype=torch.int) if capture_attn_cudagraph: attn_metadata = {} @@ -2036,7 +2109,6 @@ def _dummy_run( num_reqs=num_reqs, num_actual_tokens=num_tokens, max_query_len=num_tokens, - generation_indices=generation_indices, ) for kv_cache_group_id, kv_cache_group_spec in enumerate( @@ -2075,17 +2147,11 @@ def _dummy_run( intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) - dummy_generation_metadata = GenerationMetadata( - num_generation_tokens=num_tokens, - generation_indices_padded=generation_indices, - ) - with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp, - generation_metadata=dummy_generation_metadata): + num_tokens_across_dp=num_tokens_across_dp): outputs = model( input_ids=input_ids, positions=positions, @@ -2607,7 +2673,10 @@ def initialize_kv_cache_tensors( # Setup `kv_cache_config` and `kv_caches` for models # with cross-layer KV sharing if self.shared_kv_cache_layers: + attn_layers = get_layers_from_vllm_config(self.vllm_config, + Attention) initialize_kv_cache_for_kv_sharing( + list(attn_layers.keys()), self.shared_kv_cache_layers, kv_cache_config.kv_cache_groups, kv_caches, @@ -2718,18 +2787,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: block_size=max_model_len, page_size_padded=page_size_padded) - # Second pass to determine if N-1 prompt tokens can be skipped - # during prefill for layers that re-use shared KV cache - # Iterate in reversed order and note shared kv cache layers where - # there is no layer after it that allocates its own KV cache - for layer_name in reversed(attn_layers.keys()): - if layer_name in self.shared_kv_cache_layers: - attn_module = attn_layers[layer_name] - if isinstance(attn_module.impl, FlashAttentionImpl): - attn_module.impl.kv_sharing_skip_prefill = True - else: - break - return kv_cache_spec def _maybe_pad_mamba_page_size( diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 5af052e6851..e0a7da3d21e 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1584,7 +1584,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: # Setup `kv_cache_config` and `kv_caches` for models # with cross-layer KV sharing if self.shared_kv_cache_layers: + attn_layers = get_layers_from_vllm_config(self.vllm_config, + Attention) initialize_kv_cache_for_kv_sharing( + list(attn_layers.keys()), self.shared_kv_cache_layers, kv_cache_config.kv_cache_groups, kv_caches, diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 70339ff2f00..41aa6e81a56 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -79,6 +79,7 @@ def gather_mm_placeholders( def initialize_kv_cache_for_kv_sharing( + attn_layer_names: 
list[str], shared_kv_cache_layers: dict[str, str], kv_cache_groups: list[KVCacheGroupSpec], kv_caches: dict[str, torch.Tensor], @@ -106,7 +107,17 @@ def initialize_kv_cache_for_kv_sharing( for layer_name in kv_cache_group.layer_names: layer_to_kv_cache_group_idx[layer_name] = i + truncated_prefill_eligible_layers = set() + for layer_name in reversed(attn_layer_names): + if layer_name in shared_kv_cache_layers: + truncated_prefill_eligible_layers.add(layer_name) + else: + break + for layer_name, target_layer_name in shared_kv_cache_layers.items(): kv_caches[layer_name] = kv_caches[target_layer_name] group_idx = layer_to_kv_cache_group_idx[target_layer_name] kv_cache_groups[group_idx].layer_names.append(layer_name) + if layer_name in truncated_prefill_eligible_layers: + kv_cache_groups[ + group_idx].truncated_prefill_eligible_layers.append(layer_name)
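
For reference, the core of _calc_truncated_prefill_attn_metadata above can be exercised in isolation. The sketch below is a minimal, standalone reproduction of the compacted query_start_loc derivation using only torch; the helper name is illustrative (not part of vLLM), and the input values are taken directly from the worked example in the patch's comments.

    # Minimal sketch: derive the truncated-prefill query_start_loc from the
    # generation indices, assuming only torch. Helper name is illustrative.
    import torch


    def build_truncated_prefill_query_start_loc(
            generation_indices: torch.Tensor,
            query_start_loc: torch.Tensor,
            num_reqs: int) -> tuple[torch.Tensor, int, int]:
        # Map each generation index to the request it belongs to.
        request_ids = torch.bucketize(generation_indices,
                                      query_start_loc[1:],
                                      right=True)
        # Count how many generation tokens each request contributes.
        num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs)
        # Cumulative sum gives the new query start locations over those
        # tokens only.
        decode_query_start_loc = torch.empty(num_reqs + 1,
                                             dtype=query_start_loc.dtype)
        decode_query_start_loc[0] = 0
        decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0)
        return (decode_query_start_loc,
                int(num_decode_tokens.max().item()),
                int(num_decode_tokens.sum().item()))


    # Worked example from the patch: 3 requests, generation indices
    # [14, 18, 19, 27], query_start_loc [0, 15, 20, 28].
    gen_indices = torch.tensor([14, 18, 19, 27])
    query_start_loc = torch.tensor([0, 15, 20, 28])
    qsl, max_query_len, num_tokens = build_truncated_prefill_query_start_loc(
        gen_indices, query_start_loc, num_reqs=3)
    assert qsl.tolist() == [0, 1, 3, 4]
    assert (max_query_len, num_tokens) == (2, 4)

With these inputs, bucketize assigns each generation index to its request, bincount yields the per-request generation token counts [1, 2, 1], and the cumulative sum produces the compacted query_start_loc [0, 1, 3, 4] used to attend over generation tokens only in the truncated-prefill layer group.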