diff --git a/tests/v1/e2e/test_kv_sharing_truncated_prefill.py b/tests/v1/e2e/test_kv_sharing_truncated_prefill.py new file mode 100644 index 00000000000..d2052b6003b --- /dev/null +++ b/tests/v1/e2e/test_kv_sharing_truncated_prefill.py @@ -0,0 +1,387 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import gc +import random +from collections.abc import Iterable +from typing import Optional, Union + +import pytest +import torch +from torch import nn +from transformers import Qwen2Config + +from vllm import LLM, SamplingParams +from vllm.compilation.backends import set_model_tag +from vllm.compilation.decorators import (ignore_torch_compile, + support_torch_compile) +from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel, + VllmConfig) +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.qwen2 import (Qwen2Attention, Qwen2MLP, + Qwen2Model) +from vllm.model_executor.models.registry import ModelRegistry +from vllm.model_executor.models.utils import (AutoWeightsLoader, + extract_layer_index, + maybe_prefix) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from ...utils import fork_new_process_for_each_test + +START_KV_SHARING_LAYER = 10 + + +class Qwen2DecoderLayerWithKVSharing(nn.Module): + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + attn_prefix = f"{prefix}.self_attn" + layer_idx = extract_layer_index(prefix) + kv_sharing_target_layer_name = None + + if layer_idx >= START_KV_SHARING_LAYER: + target_layer_idx = START_KV_SHARING_LAYER - 1 + kv_sharing_target_layer_name = f"{attn_prefix}.attn".replace( + str(layer_idx), str(target_layer_idx)) + + self.self_attn = Qwen2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=attn_prefix, + kv_sharing_target_layer_name=kv_sharing_target_layer_name, + ) + + self.mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + 
hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class FirstLayerGroup(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layers: list[nn.Module], + ): + super().__init__() + self.layers = layers + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ): + residual = None + for layer in self.layers: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + return hidden_states, residual + + +@support_torch_compile +class SecondLayerGroup(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layers: list[nn.Module], + ): + super().__init__() + self.layers = layers + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + ): + for layer in self.layers: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + return hidden_states, residual + + +@ignore_torch_compile +class Qwen2ModelWithKVSharing(Qwen2Model): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + decoder_layer_type: type[ + nn.Module] = Qwen2DecoderLayerWithKVSharing): + super().__init__( + vllm_config=vllm_config, + prefix=prefix, + decoder_layer_type=decoder_layer_type, + ) + + self.vllm_config = vllm_config + + with set_model_tag("first_layer_group"): + self.first_layer_group = FirstLayerGroup( + vllm_config=vllm_config, + prefix=f"{prefix}.first_layer_group", + layers=self.layers[self.start_layer:START_KV_SHARING_LAYER], + ) + + with set_model_tag("second_layer_group"): + self.second_layer_group = SecondLayerGroup( + vllm_config=vllm_config, + prefix=f"{prefix}.second_layer_group", + layers=self.layers[START_KV_SHARING_LAYER:self.end_layer], + ) + + # Pre-allocate static buffers for CUDA graph + max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + dtype = vllm_config.model_config.dtype + device = next(self.parameters()).device + hidden_size = vllm_config.model_config.get_hidden_size() + self.residual = torch.zeros((max_num_tokens, hidden_size), + dtype=dtype, + device=device) + self.hidden_states = torch.zeros((max_num_tokens, hidden_size), + dtype=dtype, + device=device) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + + num_input_tokens = input_ids.size(0) + self.hidden_states[:num_input_tokens].copy_(hidden_states) + + hidden_states, residual = self.first_layer_group( + positions, + self.hidden_states[:num_input_tokens], + ) + + truncated_prefill_metadata = \ + get_forward_context().truncated_prefill_metadata + if truncated_prefill_metadata is not None: + gen_indices_padded = \ + truncated_prefill_metadata.generation_indices_padded + num_tokens = gen_indices_padded.shape[0] + # CUDA graph expects static tensor addresses + # Copy output of first layer group to second layer group + # TODO(sarckk): Move logic to @support_torch_compile + self.residual[:num_tokens].copy_(residual[gen_indices_padded]) + self.hidden_states[:num_tokens].copy_( + hidden_states[gen_indices_padded]) + positions[:num_tokens].copy_(positions[gen_indices_padded]) + else: + num_tokens = num_input_tokens + 
self.residual[:num_tokens].copy_(residual) + self.hidden_states[:num_tokens].copy_(hidden_states) + + second_hidden_states, second_residual = self.second_layer_group( + positions[:num_tokens], + self.hidden_states[:num_tokens], + self.residual[:num_tokens], + ) + + if truncated_prefill_metadata is not None: + gen_indices_padded =\ + truncated_prefill_metadata.generation_indices_padded + # NOTE: we need to pad generation indices for CUDA graph + # but only the first num_gen_indices positions are actually valid. + num_gen_indices = truncated_prefill_metadata.num_generation_indices + gen_indices = gen_indices_padded[:num_gen_indices] + hidden_states[gen_indices] = second_hidden_states[:num_gen_indices] + residual[gen_indices] = second_residual[:num_gen_indices] + else: + hidden_states = second_hidden_states + residual = second_residual + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class TestQwen2ForCausalLM(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2ModelWithKVSharing( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + decoder_layer_type=Qwen2DecoderLayerWithKVSharing) + self.lm_head = self.model.embed_tokens + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + +@pytest.fixture +def test_prompts(): + """ + Adapted from tests/v1/e2e/test_spec_decode.py + """ + prompt_types = ["repeat", "sentence"] + # Setting higher num prompts increases the chance of numerics mismatch + # due to matrix multiplication numerics depending on batch dimension + num_prompts = 10 + prompts = [] + + random.seed(0) + random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) + + for kind in random_prompt_type_choices: + word_choices = ["test", "temp", "hello", "where"] + word = random.choice(word_choices) + if kind == "repeat": + prompt = f"""please repeat the word '{word}' 10 times.""" + elif kind == "sentence": + prompt = f"""please give a ten-word sentence that + uses the word {word} at least once.""" + else: + raise ValueError(f"Unknown prompt type: {kind}") + prompts.append(prompt) + + return prompts + + +@fork_new_process_for_each_test +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_kv_sharing_truncated_prefill( + monkeypatch: pytest.MonkeyPatch, + enforce_eager: bool, + test_prompts: list[str], +): + 
ModelRegistry.register_model("Qwen2ForCausalLM", TestQwen2ForCausalLM) + sampling_params = SamplingParams(temperature=0.0, max_tokens=100) + compilation_config = CompilationConfig( + level=CompilationLevel. + PIECEWISE if not enforce_eager else CompilationLevel.NO_COMPILATION) + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + llm = LLM( + model="Qwen/Qwen2-1.5B-Instruct", + enforce_eager=enforce_eager, + compilation_config=compilation_config, + ) + ref_responses = llm.generate(test_prompts, sampling_params) + + del llm + gc.collect() + torch.cuda.empty_cache() + + llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", + enforce_eager=enforce_eager, + compilation_config=compilation_config, + enable_kv_sharing_truncated_prefill=True) + optimized_responses = llm.generate(test_prompts, sampling_params) + + misses = 0 + + for ref_response, optimized_response in zip(ref_responses, + optimized_responses): + if ref_response.outputs[0].text != optimized_response.outputs[ + 0].text: + misses += 1 + + assert misses == 0 diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 05e4ca9f08b..e165a6d1f40 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -23,6 +23,22 @@ _T = TypeVar("_T", bound=type[nn.Module]) +def ignore_torch_compile(cls: _T) -> _T: + """ + A decorator to ignore support_torch_compile decorator + on the class. This is useful when a parent class has + a support_torch_compile decorator, but we don't want to + compile the class `cls` that inherits the parent class. + + This only ignores compiling the forward of the class the + decorator is applied to. If the class has one or more submodules + that have support_torch_compile decorator applied, compile will + not be ignored for those submodules. + """ + cls._ignore_compile_vllm = True + return cls + + @overload def support_torch_compile( *, @@ -156,7 +172,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): self.do_not_compile = \ vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS - ] or not supports_dynamo() + ] or not supports_dynamo() or getattr( + self, "_ignore_compile_vllm", False) if self.do_not_compile: return compilation_counter.num_models_seen += 1 diff --git a/vllm/config.py b/vllm/config.py index b1f7f9e57a7..1fe24bf01c3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1564,6 +1564,10 @@ class CacheConfig: checkpoint if available. Otherwise, the scales will default to 1.0.""" cpu_kvcache_space_bytes: Optional[int] = None """(CPU backend only) CPU key-value cache space.""" + enable_kv_sharing_truncated_prefill: bool = False + """Skip prefill for tokens where applicable in KV cache sharing + scenarios where required key/value tensors have been populated + in earlier KV sharing target layers.""" # Will be set after profiling. 
num_gpu_blocks: Optional[int] = field(default=None, init=False) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f47499309d8..9fc0221d705 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -472,6 +472,8 @@ class EngineArgs: override_attention_dtype: str = ModelConfig.override_attention_dtype calculate_kv_scales: bool = CacheConfig.calculate_kv_scales + enable_kv_sharing_truncated_prefill: bool = \ + CacheConfig.enable_kv_sharing_truncated_prefill additional_config: dict[str, Any] = \ get_field(VllmConfig, "additional_config") @@ -748,6 +750,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **cache_kwargs["cpu_offload_gb"]) cache_group.add_argument("--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]) + cache_group.add_argument( + "--enable-kv-sharing-truncated-prefill", + **cache_kwargs["enable_kv_sharing_truncated_prefill"]) # Tokenizer arguments tokenizer_kwargs = get_kwargs(TokenizerPoolConfig) @@ -1158,6 +1163,8 @@ def create_engine_config( prefix_caching_hash_algo=self.prefix_caching_hash_algo, cpu_offload_gb=self.cpu_offload_gb, calculate_kv_scales=self.calculate_kv_scales, + enable_kv_sharing_truncated_prefill=self. + enable_kv_sharing_truncated_prefill, ) # Get the current placement group if Ray is initialized and diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c60a566f585..fed24840609 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -194,6 +194,7 @@ def __init__( override_pooler_config: Optional[PoolerConfig] = None, compilation_config: Optional[Union[int, dict[str, Any], CompilationConfig]] = None, + enable_kv_sharing_truncated_prefill: bool = False, **kwargs, ) -> None: """LLM constructor.""" @@ -267,6 +268,8 @@ def __init__( mm_processor_kwargs=mm_processor_kwargs, override_pooler_config=override_pooler_config, compilation_config=compilation_config_instance, + enable_kv_sharing_truncated_prefill= + enable_kv_sharing_truncated_prefill, **kwargs, ) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index dd55b19feea..dbc6e35560d 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -26,6 +26,28 @@ batchsize_forward_time: defaultdict = defaultdict(list) +@dataclass +class TruncatedPrefillMetadata: + num_generation_indices: int + """ + No. of generation indices without CUDA graph padding. + + Set dynamically for each forward pass. + """ + generation_indices_padded: torch.Tensor + """ + Indices of tokens used for sampling output tokens. + Includes the last prefill token and all decode tokens. + Given N prompt tokens, the first N-1 tokens are not included as + they are not used to sample tokens for generation. + + Set dynamically for each forward pass. 
+ """ + + def generation_indices_unpadded(self) -> torch.Tensor: + return self.generation_indices_padded[:self.num_generation_indices] + + @dataclass class DPMetadata: max_tokens_across_dp_cpu: torch.Tensor @@ -95,6 +117,7 @@ class ForwardContext: # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None skip_cuda_graphs: bool = False + truncated_prefill_metadata: Optional[TruncatedPrefillMetadata] = None _forward_context: Optional[ForwardContext] = None @@ -116,6 +139,7 @@ def set_forward_context( num_tokens: Optional[int] = None, num_tokens_across_dp: Optional[torch.Tensor] = None, skip_cuda_graphs: bool = False, + truncated_prefill_metadata: Optional[TruncatedPrefillMetadata] = None, ): """A context manager that stores the current forward context, can be attention metadata, etc. @@ -141,6 +165,7 @@ def set_forward_context( attn_metadata=attn_metadata, dp_metadata=dp_metadata, skip_cuda_graphs=skip_cuda_graphs, + truncated_prefill_metadata=truncated_prefill_metadata, ) try: diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 7ef9d248da4..e4b5d674ff6 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -110,6 +110,7 @@ def __init__( prefix: str = "", attn_type: str = AttentionType.DECODER, dual_chunk_attention_config: Optional[dict[str, Any]] = None, + **attn_kwargs, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -171,7 +172,8 @@ def __init__( **{ "layer_idx": extract_layer_index(prefix), "dual_chunk_attention_config": dual_chunk_attention_config, - } if dual_chunk_attention_config else {}) + } if dual_chunk_attention_config else {}, + **attn_kwargs) def forward( self, diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index fbc13c06c65..80e769e4df9 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -208,8 +208,9 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, self.aot_sliding_window: Optional[tuple[int, int]] = None def build( - self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, ) -> FlashAttentionMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens @@ -217,6 +218,9 @@ def build( max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_np = common_attn_metadata.query_start_loc_np + if query_start_loc_np is None: + query_start_loc_np = self.runner.query_start_loc_np[:num_reqs + 1] seq_lens = common_attn_metadata.seq_lens block_table = self.block_table block_table_tensor = block_table.get_device_tensor()[:num_reqs] @@ -271,7 +275,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ virt_block_table_tensor = make_local_attention_virtual_batches( self.runner.attention_chunk_size, - self.runner.query_start_loc_np[:num_reqs + 1], + query_start_loc_np, self.runner.seq_lens_np[:num_reqs], block_table_tensor, self.block_size, diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 88adc32406e..e2f3d4af27b 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -36,6 +36,7 @@ class CommonAttentionMetadata: query_start_loc: torch.Tensor """(batch_size + 1,), the 
start location of each request in query Tensor""" + seq_lens: torch.Tensor """(batch_size,), the length of each request including both computed tokens and newly scheduled tokens""" @@ -47,6 +48,9 @@ class CommonAttentionMetadata: max_query_len: int """Longest query in batch""" + query_start_loc_np: Optional[np.ndarray] = None + """(batch_size + 1,), numpy equivalent of query_start_loc on the CPU""" + M = TypeVar("M") diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 43456a987de..ca732ac56ae 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy -from dataclasses import dataclass +from dataclasses import dataclass, field from math import prod from typing import Optional @@ -202,6 +202,8 @@ class KVCacheGroupSpec: layer_names: list[str] # The KV cache spec of this manager layer kv_cache_spec: KVCacheSpec + # The names of model layers for which prefill can be truncated + truncated_prefill_eligible_layers: list[str] = field(default_factory=list) @dataclass diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f3279fa5fa8..bedab4539a6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -27,8 +27,8 @@ from vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, graph_capture, is_global_first_rank, prepare_communication_buffer_for_model) -from vllm.forward_context import (DPMetadata, get_forward_context, - set_forward_context) +from vllm.forward_context import (DPMetadata, TruncatedPrefillMetadata, + get_forward_context, set_forward_context) from vllm.logger import init_logger from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding @@ -316,6 +316,12 @@ def __init__( # from the KV cache of `shared_kv_cache_layers[layer_name]`. 
self.shared_kv_cache_layers: dict[str, str] = {} + self.generation_indices = None + if self.cache_config.enable_kv_sharing_truncated_prefill: + self.generation_indices = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) + def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: """ Update the order of requests in the batch based on the attention @@ -574,11 +580,79 @@ def _get_cumsum_and_arange( return cu_num_tokens, arange + def _truncate_prefill(self) -> bool: + if not self.cache_config.enable_kv_sharing_truncated_prefill: + return False + + num_decode_reqs = 0 + for req_index in range(self.input_batch.num_reqs): + if self.input_batch.num_computed_tokens_cpu[ + req_index] >= self.input_batch.num_prompt_tokens[ + req_index]: + num_decode_reqs += 1 + + if self.input_batch.num_reqs == num_decode_reqs: + # All requests on decode, no need to truncate prefill + return False + + for kv_cache_group_spec in self.kv_cache_config.kv_cache_groups: + if kv_cache_group_spec.truncated_prefill_eligible_layers: + return True + + return False + + def _calc_truncated_prefill_attn_metadata( + self, + logits_indices: torch.Tensor, + common_attn_metadata: CommonAttentionMetadata, + ) -> CommonAttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + # Example inputs + # num_reqs: 3 + # generation_indices: [14, 18, 19, 27] + # query_start_loc: [0, 15, 20, 28] + # seq_lens: [41, 31, 40] + + # Find how many decode indices belong to each request + # request_ids: [0, 1, 1, 2] + request_ids = torch.bucketize(logits_indices, + query_start_loc[1:], + right=True) + + # Figure out how many tokens are in each request + # num_decode_tokens: [1, 2, 1] + num_decode_tokens = torch.bincount(request_ids, minlength=num_reqs) + + # Calculate new query_start_loc with tokens in generation_indices + # decode_query_start_loc: [0, 1, 3, 4] + decode_query_start_loc = torch.empty(num_reqs + 1, + device=query_start_loc.device, + dtype=query_start_loc.dtype) + + decode_query_start_loc[0] = 0 + decode_query_start_loc[1:] = torch.cumsum(num_decode_tokens, dim=0) + decode_max_query_len = int(num_decode_tokens.max().item()) + total_num_decode_tokens = int(num_decode_tokens.sum().item()) + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=decode_query_start_loc, + # TODO(sarckk): optimize + query_start_loc_np=decode_query_start_loc.cpu().numpy(), + seq_lens=seq_lens, + num_reqs=num_reqs, + num_actual_tokens=total_num_decode_tokens, + max_query_len=decode_max_query_len, + ) + return common_attn_metadata + def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[dict[str, Any], bool, torch.Tensor, - Optional[SpecDecodeMetadata], np.ndarray]: + ) -> tuple[dict[str, + Any], bool, torch.Tensor, Optional[SpecDecodeMetadata], + np.ndarray, Optional[TruncatedPrefillMetadata]]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -696,16 +770,77 @@ def _prepare_inputs( self.query_start_loc_cpu[num_reqs].item()) query_start_loc = self.query_start_loc[:num_reqs + 1] + query_start_loc_np = self.query_start_loc_np[:num_reqs + 1] seq_lens = self.seq_lens[:num_reqs] + use_spec_decode = len( + scheduler_output.scheduled_spec_decode_tokens) > 0 + if not use_spec_decode: + # NOTE(woosuk): Due to chunked prefills, the batch may contain + # partial requests. While we should not sample any token + # from these partial requests, we do so for simplicity. 
+ # We will ignore the sampled tokens from the partial requests. + # TODO: Support prompt logprobs. + logits_indices = query_start_loc[1:] - 1 + spec_decode_metadata = None + else: + # Get the number of draft tokens for each request. + # Iterate over the dictionary rather than all requests since not all + # requests have draft tokens. + num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) + for req_id, draft_token_ids in ( + scheduler_output.scheduled_spec_decode_tokens.items()): + req_idx = self.input_batch.req_id_to_index[req_id] + num_draft_tokens[req_idx] = len(draft_token_ids) + + spec_decode_metadata = self._calc_spec_decode_metadata( + num_draft_tokens, cu_num_tokens) + logits_indices = spec_decode_metadata.logits_indices + common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, + query_start_loc_np=query_start_loc_np, seq_lens=seq_lens, num_reqs=num_reqs, num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, ) + truncate_prefill = self._truncate_prefill() + truncated_prefill_metadata = None + truncated_prefill_common_attn_metadata = None + + if truncate_prefill: + assert self.generation_indices is not None + # TODO(sarckk): With chunked prefills, logits_indices contains + # indices for partial requests though we do not sample any token + # from these partial requests, for simplicity. In the future, we + # can calculate the 'true' decode indices based on logits_indices, + # hence the distinction from logits_indices + num_generation_indices = logits_indices.shape[0] + self.generation_indices[:num_generation_indices].copy_( + logits_indices) + # pad with last idx instead of zero + self.generation_indices[num_generation_indices:].fill_( + logits_indices[-1].item()) + if (self.use_cuda_graph and num_generation_indices + <= self.cudagraph_batch_sizes[-1]): + num_gen_indices_padded = self.vllm_config.pad_for_cudagraph( + num_generation_indices) + else: + num_gen_indices_padded = num_generation_indices + + truncated_prefill_metadata = TruncatedPrefillMetadata( + num_generation_indices=num_generation_indices, + generation_indices_padded=( + self.generation_indices[:num_gen_indices_padded])) + truncated_prefill_common_attn_metadata =\ + self._calc_truncated_prefill_attn_metadata( + # Use generation indices without CUDA graph padding for attn + truncated_prefill_metadata.generation_indices_unpadded(), + common_attn_metadata, + ) + attn_metadata: dict[str, Any] = {} # Prepare the attention metadata for each KV cache group and make layers # in the same group share the same metadata. 
@@ -724,6 +859,18 @@ def _prepare_inputs( builder, ) + common_attn_metadata = common_attn_metadata + truncated_prefill_attn_metadata_i = None + if (truncated_prefill_common_attn_metadata is not None + and kv_cache_group_spec.truncated_prefill_eligible_layers): + truncated_prefill_attn_metadata_i = ( + builder.build( + # TODO(sarckk): Cascade attn for truncated prefill + common_prefix_len=0, + common_attn_metadata=( + truncated_prefill_common_attn_metadata), + )) + attn_metadata_i = (builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, @@ -732,40 +879,25 @@ def _prepare_inputs( for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i + if (kv_cache_group_spec.truncated_prefill_eligible_layers + is not None + and truncated_prefill_attn_metadata_i is not None): + for layer_name in \ + kv_cache_group_spec.truncated_prefill_eligible_layers: + attn_metadata[layer_name] =\ + truncated_prefill_attn_metadata_i + attention_cuda_graphs = all( b.can_run_in_cudagraph(common_attn_metadata) for b in self.attn_metadata_builders) - use_spec_decode = len( - scheduler_output.scheduled_spec_decode_tokens) > 0 - if not use_spec_decode: - # NOTE(woosuk): Due to chunked prefills, the batch may contain - # partial requests. While we should not sample any token - # from these partial requests, we do so for simplicity. - # We will ignore the sampled tokens from the partial requests. - # TODO: Support prompt logprobs. - logits_indices = query_start_loc[1:] - 1 - spec_decode_metadata = None - else: - # Get the number of draft tokens for each request. - # Iterate over the dictionary rather than all requests since not all - # requests have draft tokens. - num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) - for req_id, draft_token_ids in ( - scheduler_output.scheduled_spec_decode_tokens.items()): - req_idx = self.input_batch.req_id_to_index[req_id] - num_draft_tokens[req_idx] = len(draft_token_ids) - - spec_decode_metadata = self._calc_spec_decode_metadata( - num_draft_tokens, cu_num_tokens) - logits_indices = spec_decode_metadata.logits_indices - # Hot-Swap lora model if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) return (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, num_scheduled_tokens) + spec_decode_metadata, num_scheduled_tokens, + truncated_prefill_metadata) def _compute_cascade_attn_prefix_len( self, @@ -1286,8 +1418,8 @@ def execute_model( # Prepare the decoder inputs. 
(attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, - num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output)) + spec_decode_metadata, num_scheduled_tokens_np, + truncated_prefill_metadata) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1364,7 +1496,7 @@ def execute_model( num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, skip_cuda_graphs=skip_cuda_graphs, - ): + truncated_prefill_metadata=truncated_prefill_metadata): self.maybe_setup_kv_connector(scheduler_output) model_output = self.model( @@ -1957,10 +2089,12 @@ def _dummy_run( dtype=np.int32) attn_metadata: Optional[dict[str, Any]] = None + if capture_attn_cudagraph: attn_metadata = {} query_start_loc = self.query_start_loc[:num_reqs + 1] + query_start_loc_np = self.query_start_loc_np[:num_reqs + 1] # Make sure max_model_len is used at the graph capture time. self.seq_lens_np[:num_reqs] = self.max_model_len self.seq_lens_np[num_reqs:] = 0 @@ -1970,6 +2104,7 @@ def _dummy_run( common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, + query_start_loc_np=query_start_loc_np, seq_lens=seq_lens, num_reqs=num_reqs, num_actual_tokens=num_tokens, @@ -2538,7 +2673,10 @@ def initialize_kv_cache_tensors( # Setup `kv_cache_config` and `kv_caches` for models # with cross-layer KV sharing if self.shared_kv_cache_layers: + attn_layers = get_layers_from_vllm_config(self.vllm_config, + Attention) initialize_kv_cache_for_kv_sharing( + list(attn_layers.keys()), self.shared_kv_cache_layers, kv_cache_config.kv_cache_groups, kv_caches, diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 5af052e6851..e0a7da3d21e 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1584,7 +1584,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: # Setup `kv_cache_config` and `kv_caches` for models # with cross-layer KV sharing if self.shared_kv_cache_layers: + attn_layers = get_layers_from_vllm_config(self.vllm_config, + Attention) initialize_kv_cache_for_kv_sharing( + list(attn_layers.keys()), self.shared_kv_cache_layers, kv_cache_config.kv_cache_groups, kv_caches, diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 70339ff2f00..41aa6e81a56 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -79,6 +79,7 @@ def gather_mm_placeholders( def initialize_kv_cache_for_kv_sharing( + attn_layer_names: list[str], shared_kv_cache_layers: dict[str, str], kv_cache_groups: list[KVCacheGroupSpec], kv_caches: dict[str, torch.Tensor], @@ -106,7 +107,17 @@ def initialize_kv_cache_for_kv_sharing( for layer_name in kv_cache_group.layer_names: layer_to_kv_cache_group_idx[layer_name] = i + truncated_prefill_eligible_layers = set() + for layer_name in reversed(attn_layer_names): + if layer_name in shared_kv_cache_layers: + truncated_prefill_eligible_layers.add(layer_name) + else: + break + for layer_name, target_layer_name in shared_kv_cache_layers.items(): kv_caches[layer_name] = kv_caches[target_layer_name] group_idx = layer_to_kv_cache_group_idx[target_layer_name] kv_cache_groups[group_idx].layer_names.append(layer_name) + if layer_name in truncated_prefill_eligible_layers: + kv_cache_groups[ + group_idx].truncated_prefill_eligible_layers.append(layer_name)
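
Usage sketch (a minimal example, assuming a vLLM build that includes this patch; the model name is only an illustration, and any model whose later attention layers set kv_sharing_target_layer_name is eligible):

    from vllm import LLM, SamplingParams

    # Enable the new CacheConfig flag added by this patch.
    # The equivalent CLI option is --enable-kv-sharing-truncated-prefill.
    llm = LLM(
        model="Qwen/Qwen2-1.5B-Instruct",
        enable_kv_sharing_truncated_prefill=True,
    )

    # Greedy sampling mirrors the e2e test above, which checks that outputs
    # match a reference run with the flag disabled.
    params = SamplingParams(temperature=0.0, max_tokens=100)
    outputs = llm.generate(["please repeat the word 'test' 10 times."], params)
    print(outputs[0].outputs[0].text)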