From f1c861d584b216a23ee669a5c223ee92b7841b74 Mon Sep 17 00:00:00 2001 From: qizixi Date: Wed, 2 Jul 2025 15:46:27 -0700 Subject: [PATCH 1/9] [Meta] Llama4 EAGLE Support Co-authored-by: Zixi Qi Signed-off-by: qizixi --- examples/offline_inference/spec_decode.py | 3 +- vllm/model_executor/models/llama4_eagle.py | 197 +++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 3 files changed, 200 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/models/llama4_eagle.py diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 26e492fed25..eb949afc3a0 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -81,9 +81,10 @@ def main(): tensor_parallel_size=args.tp, enable_chunked_prefill=args.enable_chunked_prefill, enforce_eager=args.enforce_eager, - gpu_memory_utilization=0.8, + gpu_memory_utilization=0.7, speculative_config=speculative_config, disable_log_stats=False, + max_model_len=16384, ) sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py new file mode 100644 index 00000000000..5a7a7e9613e --- /dev/null +++ b/vllm/model_executor/models/llama4_eagle.py @@ -0,0 +1,197 @@ +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Iterable + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.torchao import TorchAOConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama4 import (Llama4DecoderLayer, + Llama4ForCausalLM) +from vllm.model_executor.models.utils import extract_layer_index + +from .utils import AutoWeightsLoader, maybe_prefix +from typing import Optional + +logger = init_logger(__name__) + + +@support_torch_compile +class LlamaModel(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + start_layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = vllm_config. 
\ + speculative_config.draft_model_config.hf_config + self.validate_and_update_config(start_layer_id, quant_config) + self.vocab_size = self.config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + + self.layers = nn.ModuleList([ + Llama4DecoderLayer( + self.config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + ) for i in range(self.config.num_hidden_layers) + ]) + self.fc = torch.nn.Linear(self.config.hidden_size * 2, + self.config.hidden_size, + bias=False) + self.norm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + input_embeds = self.embed_tokens(input_ids) + hidden_states = self.fc( + torch.cat((input_embeds, hidden_states), dim=-1)) + residual = None + for layer in self.layers: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states, hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + name = name.removeprefix("model.") + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # if PP disabled then draft will share embed with target + if get_pp_group().world_size == 1 and \ + "embed_tokens." in name: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + for name in params_dict: + # if PP disabled then draft will share embed with target + if get_pp_group().world_size == 1 and \ + "embed_tokens." in name: + continue + assert name in loaded_params, f"{name} is not loaded!" 
+ return loaded_params + + def validate_and_update_config( + self, + start_layer_id: int, + quant_config: Optional[QuantizationConfig] = None) -> None: + # yoco and moe is not supported by draft model yet + assert self.config.yoco_global_kv_layer is None + assert self.config.yoco_local_kv_layer is None + assert len(self.config.moe_layers) == 0 + # draft model layer index is increased by start_layer_id, + # so we need to pad relevant configs accordingly + self.config.no_rope_layers = [ + 0 + ] * start_layer_id + self.config.no_rope_layers + # currently only TorchAO quantization is supported + if isinstance(quant_config, TorchAOConfig): + + def pad_layer_name(layer: str) -> str: + layer_index = extract_layer_index(layer) + return layer.replace(str(layer_index), + str(layer_index + start_layer_id)) + + quant_config.torchao_config.module_fqn_to_config = { + pad_layer_name(layer): quantization + for layer, quantization in + quant_config.torchao_config.module_fqn_to_config.items() + } + + +class EagleLlama4ForCausalLM(Llama4ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + self.config = vllm_config. \ + speculative_config.draft_model_config.hf_config + target_layer_num = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config) + # draft model quantization config may differ from target model + quant_config = VllmConfig.get_quantization_config( + vllm_config.speculative_config.draft_model_config, + vllm_config.load_config) + self.model = LlamaModel(vllm_config=vllm_config, + prefix="model", + start_layer_id=target_layer_num, + quant_config=quant_config) + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.config.vocab_size, + scale=logit_scale) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.model(input_ids, positions, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + # lm_head is tied with target model (Llama4ForCausalLM) + skip_prefixes=(["lm_head."]), + ) + + model_weights = {} + weights = [ + self.permute_qk_weight_for_rotary(name, loaded_weight) + for name, loaded_weight in weights + ] + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." 
+ name + model_weights[name] = loaded_weight + + loader.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 27d47692985..fd43ed836f2 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -239,6 +239,7 @@ "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"), "EAGLEModel": ("eagle", "EAGLE"), "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"), + "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"), "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), From 41785434e61d698bcc5235946d415bf5946cb8a2 Mon Sep 17 00:00:00 2001 From: morgendave Date: Tue, 8 Jul 2025 07:33:04 -0700 Subject: [PATCH 2/9] linting fixes --- vllm/model_executor/models/llama4_eagle.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index 5a7a7e9613e..e65e991bac0 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Iterable +from typing import Optional import torch import torch.nn as nn @@ -22,7 +23,6 @@ from vllm.model_executor.models.utils import extract_layer_index from .utils import AutoWeightsLoader, maybe_prefix -from typing import Optional logger = init_logger(__name__) @@ -39,8 +39,8 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() - self.config = vllm_config. \ - speculative_config.draft_model_config.hf_config + self.config = ( + vllm_config.speculative_config.draft_model_config.hf_config) self.validate_and_update_config(start_layer_id, quant_config) self.vocab_size = self.config.vocab_size self.embed_tokens = VocabParallelEmbedding( @@ -153,8 +153,8 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) - self.config = vllm_config. 
\ - speculative_config.draft_model_config.hf_config + self.config = ( + vllm_config.speculative_config.draft_model_config.hf_config) target_layer_num = vllm_config.model_config.get_num_layers( vllm_config.parallel_config) # draft model quantization config may differ from target model @@ -177,7 +177,10 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor]: return self.model(input_ids, positions, hidden_states) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]] + ) -> None: loader = AutoWeightsLoader( self, # lm_head is tied with target model (Llama4ForCausalLM) From 2f8fe498c461110dd6c0a6d9efd1e63769e81c21 Mon Sep 17 00:00:00 2001 From: morgendave Date: Tue, 8 Jul 2025 13:31:45 -0700 Subject: [PATCH 3/9] Add llama4 eagle to e2e test for spec decode --- tests/v1/e2e/test_spec_decode.py | 46 ++++++++++++++-------- vllm/model_executor/models/llama4_eagle.py | 7 ++-- 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 93e7c12f3a0..df1dbdd4816 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -6,8 +6,10 @@ from typing import Any import pytest +import torch from vllm import LLM, SamplingParams +from vllm.distributed import cleanup_dist_env_and_memory @pytest.fixture @@ -53,14 +55,6 @@ def model_name(): return "meta-llama/Llama-3.1-8B-Instruct" -def eagle_model_name(): - return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" - - -def eagle3_model_name(): - return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" - - def test_ngram_correctness( monkeypatch: pytest.MonkeyPatch, test_prompts: list[list[dict[str, Any]]], @@ -77,6 +71,8 @@ def test_ngram_correctness( ref_llm = LLM(model=model_name, max_model_len=1024) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() spec_llm = LLM( model=model_name, @@ -103,34 +99,48 @@ def test_ngram_correctness( # Upon failure, inspect the outputs to check for inaccuracy. assert matches > int(0.7 * len(ref_outputs)) del spec_llm - - -@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"]) + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + +@pytest.mark.parametrize("model_setup", [ + ("eagle", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), + ("eagle3", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), + ("eagle", "/home/zhiweiz/local/models/scout_base_HF_20250605_201140", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), +], + ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"]) def test_eagle_correctness( monkeypatch: pytest.MonkeyPatch, test_prompts: list[list[dict[str, Any]]], sampling_config: SamplingParams, - model_name: str, - use_eagle3: bool, + model_setup: tuple[str, str, str, int], ): ''' Compare the outputs of a original LLM and a speculative LLM should be the same when using eagle speculative decoding. 
+ model_setup: (method, model_name, eagle_model_name, tp_size) ''' with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + method, model_name, spec_model_name, tp_size = model_setup - ref_llm = LLM(model=model_name, max_model_len=2048) + ref_llm = LLM(model=model_name, + max_model_len=2048, + tensor_parallel_size=tp_size) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() - spec_model_name = eagle3_model_name( - ) if use_eagle3 else eagle_model_name() spec_llm = LLM( model=model_name, trust_remote_code=True, + tensor_parallel_size=tp_size, speculative_config={ - "method": "eagle3" if use_eagle3 else "eagle", + "method": method, "model": spec_model_name, "num_speculative_tokens": 3, "max_model_len": 2048, @@ -152,3 +162,5 @@ def test_eagle_correctness( # Upon failure, inspect the outputs to check for inaccuracy. assert matches > int(0.66 * len(ref_outputs)) del spec_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index e65e991bac0..1feae5d60c3 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable from typing import Optional @@ -177,10 +178,8 @@ def forward( ) -> tuple[torch.Tensor, torch.Tensor]: return self.model(input_ids, positions, hidden_states) - def load_weights( - self, - weights: Iterable[tuple[str, torch.Tensor]] - ) -> None: + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> None: loader = AutoWeightsLoader( self, # lm_head is tied with target model (Llama4ForCausalLM) From 4238b3a796f03a6403bcaf583d10f82f6f027afb Mon Sep 17 00:00:00 2001 From: morgendave Date: Thu, 10 Jul 2025 07:47:05 -0700 Subject: [PATCH 4/9] Add Meta copyright and license --- examples/offline_inference/spec_decode.py | 2 +- tests/v1/e2e/test_spec_decode.py | 2 +- vllm/model_executor/models/llama4_eagle.py | 15 +++++++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index eb949afc3a0..ce735f3b27d 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -81,7 +81,7 @@ def main(): tensor_parallel_size=args.tp, enable_chunked_prefill=args.enable_chunked_prefill, enforce_eager=args.enforce_eager, - gpu_memory_utilization=0.7, + gpu_memory_utilization=0.8, speculative_config=speculative_config, disable_log_stats=False, max_model_len=16384, diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index df1dbdd4816..ef278ae928a 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -108,7 +108,7 @@ def test_ngram_correctness( "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), ("eagle3", "meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), - ("eagle", "/home/zhiweiz/local/models/scout_base_HF_20250605_201140", + ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), ], ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"]) diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index 1feae5d60c3..222ab5dfaee 100644 --- 
a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -1,5 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team. +# All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from collections.abc import Iterable from typing import Optional From 716ec8cf79c8e3cbfc0d956730fccc1b3f4bbcad Mon Sep 17 00:00:00 2001 From: morgendave Date: Mon, 14 Jul 2025 13:35:42 -0700 Subject: [PATCH 5/9] add pytest skip for eagle llama4-scout to avoid OOM in CI --- tests/models/registry.py | 7 ++++++- tests/v1/e2e/test_spec_decode.py | 6 ++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 48302f9d664..b94376970ce 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -451,6 +451,11 @@ def check_available_online( trust_remote_code=True, speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", tokenizer="meta-llama/Llama-3.1-8B-Instruct"), + "EagleLlama4ForCausalLM": _HfExamplesInfo( + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", + trust_remote_code=True, + speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", + tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501 "EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16", trust_remote_code=True, is_available_online=False, @@ -500,4 +505,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo: raise ValueError(f"No example model defined for {model_id}") -HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) \ No newline at end of file +HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index ef278ae928a..2423f966acf 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -108,8 +108,10 @@ def test_ngram_correctness( "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), ("eagle3", "meta-llama/Llama-3.1-8B-Instruct", "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), - ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + pytest.param( + ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), ], ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"]) def test_eagle_correctness( From d64bf91ee33c6283470020224ccb5fec42167458 Mon Sep 17 00:00:00 2001 From: morgendave Date: Fri, 9 May 2025 16:44:17 -0700 Subject: [PATCH 6/9] eagle mm support, primarily llama4 Signed-off-by: morgendave --- examples/offline_inference/spec_decode.py | 60 ++++++++++++++++++--- tests/v1/e2e/test_spec_decode.py | 57 ++++++++++++++------ vllm/model_executor/models/llama4.py | 1 + vllm/model_executor/models/llama4_eagle.py | 35 +++++++++++-- vllm/model_executor/models/llama_eagle.py | 6 +++ 
vllm/model_executor/models/llama_eagle3.py | 5 ++ vllm/v1/spec_decode/eagle.py | 61 ++++++++++++++++++---- vllm/v1/worker/gpu_model_runner.py | 10 +++- 8 files changed, 199 insertions(+), 36 deletions(-) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index ce735f3b27d..43fc156e78c 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -13,6 +13,38 @@ from argparse import ArgumentParser as FlexibleArgumentParser +QUESTION = "What is the content of each image?" +IMAGE_URLS = [ + "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg", + "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg", + "https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG", + "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg", + "https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg", + "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg", + "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg", + "https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg", + "https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg", + "https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg", + "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg", + "https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg", +] + + +def get_custom_mm_prompts(num_prompts): + prompts = [] + for url in IMAGE_URLS: + prompts.append( + [ + {"type": "image_url", "image_url": {"url": url}}, + {"type": "text", "text": QUESTION}, + ] + ) + if num_prompts > len(IMAGE_URLS): + prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1) + + return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]] + + def parse_args(): parser = FlexibleArgumentParser() add_dataset_parser(parser) @@ -35,6 +67,7 @@ def parse_args(): parser.add_argument("--output-len", type=int, default=256) parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) + parser.add_argument("--custom-mm-prompts", action="store_true") return parser.parse_args() @@ -46,12 +79,18 @@ def main(): if args.model_dir is None: model_dir = "meta-llama/Llama-3.1-8B-Instruct" tokenizer = AutoTokenizer.from_pretrained(model_dir) - - prompts = get_samples(args, tokenizer) - # add_special_tokens is False to avoid adding bos twice when using chat templates - prompt_ids = [ - tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts - ] + args.custom_skip_chat_template = True + + if not args.custom_mm_prompts: + prompts = get_samples(args, tokenizer) + # add_special_tokens is False to avoid adding bos twice + # when using chat templates + prompt_ids = [ + tokenizer.encode(prompt.prompt, add_special_tokens=False) + for prompt in 
prompts + ] + else: + prompts = get_custom_mm_prompts(args.num_prompts) if args.method == "eagle" or args.method == "eagle3": eagle_dir = args.eagle_dir @@ -85,10 +124,17 @@ def main(): speculative_config=speculative_config, disable_log_stats=False, max_model_len=16384, + limit_mm_per_prompt={"image": 5}, + disable_chunked_mm_input=True, ) sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) - outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) + if not args.custom_mm_prompts: + outputs = llm.generate( + prompt_token_ids=prompt_ids, sampling_params=sampling_params + ) + else: + outputs = llm.chat(prompts, sampling_params=sampling_params) # print the generated text if args.print_output: diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 2423f966acf..b603fabed5f 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -9,23 +9,28 @@ import torch from vllm import LLM, SamplingParams +from vllm.assets.base import VLLM_S3_BUCKET_URL +from vllm.assets.image import VLM_IMAGES_DIR from vllm.distributed import cleanup_dist_env_and_memory -@pytest.fixture -def test_prompts(): +def get_test_prompts(mm_enabled: bool): prompt_types = ["repeat", "sentence"] - num_prompts = 100 + if mm_enabled: + prompt_types.append("mm") + num_prompts = 10 prompts = [] random.seed(0) random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) + print(f"Prompt types: {random_prompt_type_choices}") # Generate a mixed batch of prompts, some of which can be easily # predicted by n-gram matching and some which likely cannot. for kind in random_prompt_type_choices: word_choices = ["test", "temp", "hello", "where"] word = random.choice(word_choices) + prompt: str | list[dict[str, Any]] = "" if kind == "repeat": prompt = f""" please repeat the word '{word}' 10 times. @@ -38,6 +43,21 @@ def test_prompts(): uses the word {word} at least once. give no other output than that simple sentence without quotes. 
""" + elif kind == "mm": + placeholders = [{ + "type": "image_url", + "image_url": { + "url": + f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg" + }, + }] + prompt = [ + *placeholders, + { + "type": "text", + "text": "The meaning of the image is" + }, + ] else: raise ValueError(f"Unknown prompt type: {kind}") prompts.append([{"role": "user", "content": prompt}]) @@ -103,23 +123,30 @@ def test_ngram_correctness( cleanup_dist_env_and_memory() -@pytest.mark.parametrize("model_setup", [ - ("eagle", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), - ("eagle3", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), - pytest.param( - ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), -], - ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"]) +@pytest.mark.parametrize( + "model_setup,mm_enabled", [ + (("eagle", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), + (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), + pytest.param( + (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), False), + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + pytest.param( + (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True), + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + ], + ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"]) def test_eagle_correctness( monkeypatch: pytest.MonkeyPatch, - test_prompts: list[list[dict[str, Any]]], sampling_config: SamplingParams, model_setup: tuple[str, str, str, int], + mm_enabled: bool, ): + # Generate test prompts inside the function instead of using fixture + test_prompts = get_test_prompts(mm_enabled) ''' Compare the outputs of a original LLM and a speculative LLM should be the same when using eagle speculative decoding. 
diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 0c9baab1f2e..442301e4e22 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -255,6 +255,7 @@ def __init__( super().__init__() self.layer_idx = extract_layer_index(prefix) + self.global_layer = config.no_rope_layers[self.layer_idx] == 0 self.hidden_size = config.hidden_size rope_theta = config.rope_theta rope_scaling = config.rope_scaling diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index 222ab5dfaee..ece490ff2f2 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -37,8 +37,9 @@ from vllm.model_executor.models.llama4 import (Llama4DecoderLayer, Llama4ForCausalLM) from vllm.model_executor.models.utils import extract_layer_index +from vllm.multimodal.inputs import NestedTensors -from .utils import AutoWeightsLoader, maybe_prefix +from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings logger = init_logger(__name__) @@ -78,15 +79,23 @@ def __init__( self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps) + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: Optional[torch.Tensor], positions: torch.Tensor, hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - input_embeds = self.embed_tokens(input_ids) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) hidden_states = self.fc( - torch.cat((input_embeds, hidden_states), dim=-1)) + torch.cat((inputs_embeds, hidden_states), dim=-1)) residual = None for layer in self.layers: hidden_states, residual = layer( @@ -190,8 +199,9 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - return self.model(input_ids, positions, hidden_states) + return self.model(input_ids, positions, hidden_states, inputs_embeds) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None: @@ -212,3 +222,20 @@ def load_weights(self, weights: Iterable[tuple[str, model_weights[name] = loaded_weight loader.load_weights(model_weights.items()) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + self.config.image_token_index, + ) + + return inputs_embeds diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index c7690604c1d..a4933b77e3a 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable +from typing import Optional import torch import torch.nn as nn @@ -148,7 +149,12 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is not None: + raise NotImplementedError( + f"{type(self).__name__} does not support 
multimodal inputs yet." + ) return self.model(input_ids, positions, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 7fc9fe2ebb6..71275f0d585 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -202,7 +202,12 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is not None: + raise NotImplementedError( + f"{type(self).__name__} does not support multimodal inputs yet." + ) return self.model(input_ids, positions, hidden_states) def compute_logits( diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 6661d984a77..e6ab156e966 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + import torch import torch.nn as nn @@ -39,6 +41,7 @@ def __init__( self.runner = runner self.dtype = vllm_config.model_config.dtype + self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size self.num_speculative_tokens = ( @@ -50,6 +53,9 @@ def __init__( # hidden size (e.g., Llama 3.3 70B). self.hidden_size = self.draft_model_config.get_hidden_size() + self.is_multimodal_model = vllm_config.model_config \ + .is_multimodal_model + self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not self.vllm_config.model_config.enforce_eager) @@ -75,6 +81,11 @@ def __init__( device=device, dtype=torch.int32) + self.inputs_embeds = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=device) + def propose( self, # [num_tokens] @@ -92,6 +103,7 @@ def propose( # [batch_size, max_num_blocks_per_req] block_table: torch.Tensor, sampling_metadata: SamplingMetadata, + mm_embeds: Optional[list[torch.Tensor]] = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -168,14 +180,28 @@ def propose( # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states + if self.is_multimodal_model: + input_ids = self.input_ids[:num_tokens] + if mm_embeds: + inputs_embeds = self.model.get_input_embeddings( + input_ids, mm_embeds) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + self.inputs_embeds[:num_tokens] = inputs_embeds + inputs_embeds = self.inputs_embeds[:num_input_tokens] + input_ids = None + else: + inputs_embeds = None + input_ids = self.input_ids[:num_input_tokens] with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens): ret_hidden_states = self.model( - self.input_ids[:num_input_tokens], - self.positions[:num_input_tokens], - self.hidden_states[:num_input_tokens], + input_ids=input_ids, + positions=self.positions[:num_input_tokens], + hidden_states=self.hidden_states[:num_input_tokens], + inputs_embeds=inputs_embeds, ) if self.method == "deepseek_mtp": last_hidden_states = ret_hidden_states @@ -253,15 +279,24 @@ def propose( self.input_ids[:batch_size] = input_ids self.positions[:batch_size] = clamped_positions self.hidden_states[:batch_size] = hidden_states + if 
self.is_multimodal_model: + inputs_embeds = self.model.get_input_embeddings(input_ids) + self.inputs_embeds[:batch_size] = inputs_embeds + inputs_embeds = self.inputs_embeds[:input_batch_size] + input_ids = None + else: + inputs_embeds = None + input_ids = self.input_ids[:input_batch_size] # Run the model. with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size): last_hidden_states, hidden_states = self.model( - self.input_ids[:input_batch_size], - self.positions[:input_batch_size], - self.hidden_states[:input_batch_size], + input_ids=input_ids, + positions=self.positions[:input_batch_size], + hidden_states=self.hidden_states[:input_batch_size], + inputs_embeds=inputs_embeds, ) hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size], @@ -372,10 +407,18 @@ def dummy_run( ) -> None: with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + self.model( - self.input_ids[:num_tokens], - self.positions[:num_tokens], - self.hidden_states[:num_tokens], + input_ids=input_ids, + positions=self.positions[:num_tokens], + hidden_states=self.hidden_states[:num_tokens], + inputs_embeds=inputs_embeds, ) def validate_same_kv_cache_group(self, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5a26e88db1f..f8d8661bdad 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1044,13 +1044,15 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", + shift_computed_tokens: int = 0, ) -> list[torch.Tensor]: mm_embeds: list[torch.Tensor] = [] for req_id in self.input_batch.req_ids: num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] req_state = self.requests[req_id] - num_computed_tokens = req_state.num_computed_tokens + num_computed_tokens = \ + req_state.num_computed_tokens + shift_computed_tokens mm_positions = req_state.mm_positions for i, pos_info in enumerate(mm_positions): start_pos = pos_info.offset @@ -1673,6 +1675,11 @@ def propose_draft_token_ids( target_hidden_states = hidden_states[token_indices] target_slot_mapping = eagle_attn_metadata.slot_mapping[ token_indices] + mm_embeds = None + if self.is_multimodal_model: + mm_embeds = self._gather_mm_embeddings(scheduler_output, + shift_computed_tokens=1) + draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, @@ -1682,6 +1689,7 @@ def propose_draft_token_ids( cu_num_tokens=cu_num_tokens, block_table=block_table, sampling_metadata=sampling_metadata, + mm_embeds=mm_embeds, ) spec_token_ids = draft_token_ids.tolist() return spec_token_ids From 5546b9f8881ed88f7809c5a8e112320f7007970e Mon Sep 17 00:00:00 2001 From: morgendave Date: Tue, 10 Jun 2025 15:54:30 -0700 Subject: [PATCH 7/9] initial commit for non-shifting prefill in eagle, prepare for kv sharing rope change rope change rebase --- examples/offline_inference/spec_decode.py | 3 + vllm/config.py | 13 + vllm/model_executor/models/llama4.py | 5 + vllm/v1/spec_decode/eagle.py | 308 ++++++++++++++++++++-- vllm/v1/worker/gpu_input_batch.py | 4 + vllm/v1/worker/gpu_model_runner.py | 91 +++++-- 6 files changed, 379 insertions(+), 45 deletions(-) diff --git a/examples/offline_inference/spec_decode.py 
b/examples/offline_inference/spec_decode.py index 43fc156e78c..604c4fbe0ef 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -68,6 +68,8 @@ def parse_args(): parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) parser.add_argument("--custom-mm-prompts", action="store_true") + parser.add_argument("--no-prefill-token-shift", dest="prefill_token_shift", + action="store_false", help="Disable prefill token shift (default: enabled)") return parser.parse_args() @@ -103,6 +105,7 @@ def main(): "method": args.method, "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, + "prefill_token_shift": args.prefill_token_shift, } elif args.method == "ngram": speculative_config = { diff --git a/vllm/config.py b/vllm/config.py index bac18e8175d..29aef108861 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2551,6 +2551,14 @@ class SpeculativeConfig: ParallelConfig] = None # type: ignore """The parallel configuration for the draft model initialized internal.""" + # Shift prefill tokens for draft, only used in eagle + prefill_token_shift: bool = True + """Shift tokens during draft prefill or not""" + + # Config for kv sharing, map from base model layer to draft layer + kv_sharing_mapping: SkipValidation[dict[str, str]] = None + """KV copy mapping for prefill stage from base to draft""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -2937,6 +2945,11 @@ def num_lookahead_slots(self) -> int: def use_eagle(self) -> bool: return self.method in ("eagle", "eagle3", "deepseek_mtp") + def eagle_shift_prefill_token(self) -> bool: + if self.use_eagle(): + return self.prefill_token_shift + return False + def __repr__(self) -> str: method = self.method model = None if method == "ngram" else self.draft_model_config.model diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 442301e4e22..b509d28a1b7 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -183,6 +183,11 @@ def __init__(self, is_gguf = quant_config and quant_config.get_name() == "gguf" if is_gguf and config.model_type == "llama": is_neox_style = False + elif config.model_type == "eagle": + # EAGLE draft model does not use neox style RoPE + is_neox_style = False + else: + is_neox_style = True self.rotary_emb = get_rope( self.head_dim, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index e6ab156e966..97a0fa8946a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -36,6 +36,7 @@ def __init__( self.vllm_config = vllm_config self.speculative_config = vllm_config.speculative_config self.draft_model_config = self.speculative_config.draft_model_config + self.kv_sharing_mapping = self.speculative_config.kv_sharing_mapping self.method = self.speculative_config.method self.runner = runner @@ -62,6 +63,7 @@ def __init__( self.cudagraph_batch_sizes = list( reversed( self.vllm_config.compilation_config.cudagraph_capture_sizes)) + self.draft_prefill_kv_sharing_from_base = self.kv_sharing_mapping is not None # persistent buffers for cuda graph self.input_ids = torch.zeros(self.max_num_tokens, @@ -80,12 +82,225 @@ def __init__( 1, device=device, dtype=torch.int32) - self.inputs_embeds = torch.zeros( (self.max_num_tokens, self.hidden_size), dtype=self.dtype, device=device) + def _prepare_adjusted_tensors( + self, + target_token_ids: torch.Tensor, + 
target_positions: torch.Tensor, + target_hidden_states: torch.Tensor, + target_slot_mapping: torch.Tensor, + cu_num_tokens: torch.Tensor, + decode_mask: torch.Tensor, + full_prefill_mask: torch.Tensor, + prefill_first_hiddens: torch.Tensor, + block_table: torch.Tensor, + batch_size: int, + num_tokens: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, + torch.Tensor]: + """ + Prepare adjusted tensors for different request types (partial prefill, full prefill, full decode). + + Args: + target_token_ids: Input token IDs tensor + target_positions: Input position IDs tensor + target_hidden_states: Input hidden states tensor + target_slot_mapping: Input slot mapping tensor + cu_num_tokens: Cumulative number of tokens per request + decode_mask: Mask indicating which tokens are for decoding + full_prefill_mask: Mask indicating which requests are full prefill + prefill_first_hiddens: First hidden states for prefill requests + block_table: Block table for KV cache mapping + batch_size: Number of requests in the batch + num_tokens: Total number of tokens + + Returns: + tuple: (target_positions, target_hidden_states, target_slot_mapping, + cu_num_tokens, current_pos, partial_prefill_mask) + + """ + # Count total number of full prefill requests to determine the size needed for adjusted tensors + num_full_prefill = full_prefill_mask.sum().item() + + # Create tensors with extra space for the additional positions from full prefill requests + adjusted_token_ids = torch.zeros( + num_tokens + num_full_prefill, + dtype=target_token_ids.dtype, + device=target_token_ids.device, + ) + adjusted_positions = torch.zeros( + num_tokens + num_full_prefill, + dtype=target_positions.dtype, + device=target_positions.device, + ) + adjusted_slot_mapping = torch.zeros( + num_tokens + num_full_prefill, + dtype=target_slot_mapping.dtype, + device=target_slot_mapping.device, + ) + adjusted_hidden_states = torch.zeros( + num_tokens + num_full_prefill, + self.hidden_size, + dtype=target_hidden_states.dtype, + device=target_hidden_states.device, + ) + + # Create updated cumulative token counts + updated_cu_num_tokens = torch.zeros_like(cu_num_tokens) + + # Track which requests are partial prefill (no decode tokens) + partial_prefill_mask = torch.zeros_like(full_prefill_mask) + + # Create masks for each category + has_decode_mask = torch.zeros(batch_size, + dtype=torch.bool, + device=decode_mask.device) + for i in range(batch_size): + start_idx = cu_num_tokens[i].item() + end_idx = cu_num_tokens[i + 1].item() + has_decode_mask[i] = decode_mask[start_idx:end_idx].any().item() + + # Category 1: Partial prefill (no decode tokens) + partial_prefill_mask = ~has_decode_mask + + # Process batched operations using masks + current_pos = 0 + cu_num_tokens_index = 0 + + # Process each request in the batch + # Process all requests in batch order but with optimized operations + # Create arrays to track request properties + req_starts = cu_num_tokens[:-1] + req_ends = cu_num_tokens[1:] + req_lens = req_ends - req_starts + + # Process each request in order + for i in range(batch_size): + # Get the start and end indices for this request + start_idx = req_starts[i].item() + end_idx = req_ends[i].item() + req_len = req_lens[i].item() + + # Check category + is_partial_prefill = partial_prefill_mask[i].item() + is_full_prefill = full_prefill_mask[i].item() + + if is_partial_prefill: + # Category 1: Partial prefill - just copy all tokens + if not self.draft_prefill_kv_sharing_from_base: + # Use torch operations for copying 
blocks of data + adjusted_token_ids[current_pos:current_pos + + req_len].copy_( + target_token_ids[start_idx:end_idx]) + adjusted_positions[current_pos:current_pos + + req_len].copy_( + target_positions[start_idx:end_idx]) + adjusted_slot_mapping[current_pos:current_pos + + req_len].copy_(target_slot_mapping[ + start_idx:end_idx]) + adjusted_hidden_states[current_pos + 1:current_pos + + req_len].copy_( + target_hidden_states[start_idx + + 1:end_idx]) + adjusted_hidden_states[ + current_pos] = prefill_first_hiddens[i] + current_pos += req_len + cu_num_tokens_index += 1 + + elif is_full_prefill: + # Category 2: Full prefill with decode - copy tokens and add one position + pos = target_positions[end_idx - 1] + 1 + block_number = pos // self.block_size + block_number = block_table[i][block_number].item() + block_offset = pos % self.block_size + + if not self.draft_prefill_kv_sharing_from_base: + # Copy token IDs, positions, slot mappings, and hidden states + adjusted_token_ids[current_pos:current_pos + + req_len].copy_( + target_token_ids[start_idx:end_idx]) + adjusted_positions[current_pos:current_pos + + req_len].copy_( + target_positions[start_idx:end_idx]) + adjusted_positions[current_pos + + req_len] = adjusted_positions[ + current_pos + req_len - 1] + 1 + + adjusted_slot_mapping[current_pos:current_pos + + req_len].copy_(target_slot_mapping[ + start_idx:end_idx]) + adjusted_slot_mapping[ + current_pos + + req_len] = block_number * self.block_size + block_offset + + adjusted_hidden_states[ + current_pos + 1:current_pos + req_len + 1].copy_( + target_hidden_states[start_idx:end_idx]) + adjusted_hidden_states[ + current_pos] = prefill_first_hiddens[i] + current_pos += req_len + 1 + else: + adjusted_positions[current_pos] = 0 + adjusted_slot_mapping[ + current_pos] = block_number * self.block_size + block_offset + adjusted_hidden_states[current_pos] = target_hidden_states[ + end_idx - 1] + current_pos += 1 + + cu_num_tokens_index += 1 + + else: + # Category 3: Full decode - shift tokens + # Shift operations using optimized copy operations + adjusted_token_ids[current_pos:current_pos + req_len - + 1].copy_(target_token_ids[start_idx + + 1:end_idx]) + adjusted_positions[current_pos:current_pos + req_len].copy_( + target_positions[start_idx:end_idx] + 1) + + adjusted_slot_mapping[current_pos:current_pos + req_len - + 1].copy_(target_slot_mapping[start_idx + + 1:end_idx]) + pos = adjusted_positions[current_pos + req_len - 1] + block_number = pos // self.block_size + block_number = block_table[i][block_number].item() + block_offset = pos % self.block_size + adjusted_slot_mapping[ + current_pos + req_len - + 1] = block_number * self.block_size + block_offset + + adjusted_hidden_states[current_pos:current_pos + + req_len].copy_(target_hidden_states[ + start_idx:end_idx]) + + current_pos += req_len + cu_num_tokens_index += 1 + + # Update the cumulative token count for this request + updated_cu_num_tokens[cu_num_tokens_index] = current_pos + + # Copy the adjusted tensors to the input buffers + self.input_ids[:current_pos] = adjusted_token_ids[:current_pos] + + # Update the variables used by the rest of the function + target_positions = adjusted_positions[:current_pos] + target_hidden_states = adjusted_hidden_states[:current_pos] + target_slot_mapping = adjusted_slot_mapping[:current_pos] + cu_num_tokens = updated_cu_num_tokens + + return ( + target_positions, + target_hidden_states, + target_slot_mapping, + cu_num_tokens, + current_pos, + partial_prefill_mask, + ) + def propose( self, # [num_tokens] 
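# [Editor's note: illustrative sketch, not part of this patch. The token and
# hidden-state names below (t0..t3, h0..h2, h_prev) are placeholders.]
# The helper added above rearranges the draft-model inputs when prefill tokens
# are NOT shifted (prefill_token_shift=False). For a single full-prefill
# request with prompt tokens [t0, t1, t2], target hidden states [h0, h1, h2],
# sampled next token t3, and h_prev = the hidden state cached for the request's
# first draft position (prefill_first_hiddens), the layouts are roughly:
#
#   shifted (default EAGLE):   input_ids = [t1, t2, t3]
#                              hidden    = [h0, h1, h2]
#   non-shifted (this patch):  input_ids = [t0, t1, t2, t3]
#                              hidden    = [h_prev, h0, h1, h2]
#
# One extra slot and position is appended per full-prefill request, which is
# why the adjusted tensors are allocated with num_tokens + num_full_prefill
# entries.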
@@ -103,11 +318,13 @@ def propose( # [batch_size, max_num_blocks_per_req] block_table: torch.Tensor, sampling_metadata: SamplingMetadata, + prefill_first_hiddens: torch.Tensor, mm_embeds: Optional[list[torch.Tensor]] = None, + decode_mask: torch.Tensor = None, + full_prefill_mask: torch.Tensor = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] - last_token_indices = cu_num_tokens[1:] - 1 if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) @@ -115,12 +332,67 @@ def propose( target_hidden_states) assert target_hidden_states.shape[-1] == self.hidden_size - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. - # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - self.input_ids[last_token_indices] = next_token_ids + prefill_shift_tokens = True + has_prefill = decode_mask is not None and ( + ~decode_mask.bool()).any().item() + if not self.speculative_config.eagle_shift_prefill_token() and ( + self.method in ["eagle", "eagle3"]): + assert decode_mask is not None + assert full_prefill_mask is not None + prefill_shift_tokens = False + + if not prefill_shift_tokens and has_prefill: + # Adjust the tensors for full prefill requests + ( + target_positions, + target_hidden_states, + target_slot_mapping, + cu_num_tokens, + num_tokens, + partial_prefill_mask, + ) = self._prepare_adjusted_tensors( + target_token_ids, + target_positions, + target_hidden_states, + target_slot_mapping, + cu_num_tokens, + decode_mask, + full_prefill_mask, + prefill_first_hiddens, + block_table, + batch_size, + num_tokens, + ) + else: + # Original behavior: shift all tokens by one + partial_prefill_mask = None + self.input_ids[:num_tokens - 1] = target_token_ids[1:] + if not prefill_shift_tokens: + target_positions += 1 + max_num_blocks_per_req = block_table.shape[1] + segment_indices = torch.arange(len(target_positions), + device=target_positions.device) + segment_indices = (segment_indices.unsqueeze(0) + >= cu_num_tokens[:-1].unsqueeze(1)).sum( + dim=0) - 1 + # Calculate the block table indices + block_table_indices = ( + target_positions // self.block_size + + segment_indices * max_num_blocks_per_req) + block_numbers = block_table.flatten()[block_table_indices] + block_offsets = target_positions % self.block_size + target_slot_mapping = block_numbers * self.block_size + block_offsets + + # Use the original last token indices + last_token_indices = cu_num_tokens[1:] - 1 + + # Replace the last token with the next token, but only for non-partial prefill requests + if not prefill_shift_tokens and has_prefill: + mask = ~partial_prefill_mask + self.input_ids[last_token_indices[mask]] = next_token_ids[mask] + else: + # Original behavior: apply to all requests + self.input_ids[last_token_indices] = next_token_ids # FA requires seq_len to have dtype int32. 
seq_lens = (target_positions[last_token_indices] + 1).int() @@ -172,8 +444,7 @@ def propose( per_layer_attn_metadata = {} for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens @@ -210,6 +481,7 @@ def propose( sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) draft_token_ids = logits.argmax(dim=-1) + # print("draft_tokens topK:", logits.topk(3, dim=-1)) # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: @@ -225,8 +497,7 @@ def propose( positions = target_positions[last_token_indices] hidden_states = hidden_states[last_token_indices] - if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: + if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size @@ -308,6 +579,7 @@ def propose( # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + # print("draft_token_ids:", draft_token_ids) return draft_token_ids @staticmethod @@ -327,8 +599,7 @@ def prepare_inputs( # a + b, a + b + 1, ..., a + b + c - n3 - 1] # [0, a, a + b, a + b + c] -> [a, b, c] - query_len_per_req = (cu_target_query_lens[1:] - - cu_target_query_lens[:-1]) + query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1] # [a, b, c] -> [a - n1, b - n2, c - n3] num_tokens_per_req = query_len_per_req - num_rejected_tokens @@ -352,8 +623,7 @@ def prepare_inputs( return cu_num_tokens, token_indices def load_model(self, target_model: nn.Module) -> None: - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config + draft_model_config = self.vllm_config.speculative_config.draft_model_config target_attn_layer_names = set( get_layers_from_vllm_config(self.vllm_config, Attention).keys()) @@ -433,12 +703,12 @@ def validate_same_kv_cache_group(self, for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): for layer_name in kv_cache_group.layer_names: kv_cache_groups[layer_name] = id - assert len( + assert (len( set([ kv_cache_groups[layer_name] for layer_name in self.attn_layer_names - ]) - ) == 1, "All eagle layers should belong to the same kv cache group" + ])) == 1 + ), "All eagle layers should belong to the same kv cache group" # NOTE(woosuk): Currently, the below code is not used and we always use argmax diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1a79d72be0a..6d5ae95e5da 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -44,6 +44,10 @@ class CachedRequestState: lora_request: Optional[LoRARequest] = None + # Bootstrap eagle and MTP related hidden states + # support caching last hidden states for partial prefill + prefill_hidden_states: Optional[torch.Tensor] = None + def __post_init__(self): self.num_prompt_tokens = len(self.prompt_token_ids) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f8d8661bdad..79a539d7f2a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -579,12 +579,12 @@ def _prepare_inputs( self, scheduler_output: 
"SchedulerOutput", ) -> tuple[dict[str, Any], bool, torch.Tensor, - Optional[SpecDecodeMetadata], np.ndarray]: + Optional[SpecDecodeMetadata], np.ndarray, torch.Tensor]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, attention_cuda_graphs: whether attention can run in cudagraph - logits_indices, spec_decode_metadata + logits_indices, spec_decode_metadata, decode_mask ] """ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -601,11 +601,18 @@ def _prepare_inputs( tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] num_scheduled_tokens = np.array(tokens, dtype=np.int32) max_num_scheduled_tokens = max(tokens) + req_last_prompt_index = np.array([ + self.requests[req_id].num_prompt_tokens - 1 + for req_id in self.input_batch.req_ids], + dtype=np.int32 + ) # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) + last_prompt_indices_np = np.repeat(req_last_prompt_index, + num_scheduled_tokens) # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] @@ -618,6 +625,8 @@ def _prepare_inputs( arange, out=positions_np) + decode_mask_np = positions_np >= last_prompt_indices_np + # Calculate M-RoPE positions. # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -765,8 +774,10 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) + decode_mask = torch.tensor(decode_mask_np, device=self.device) + return (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, num_scheduled_tokens) + spec_decode_metadata, num_scheduled_tokens, decode_mask) def _compute_cascade_attn_prefix_len( self, @@ -1293,7 +1304,8 @@ def execute_model( # Prepare the decoder inputs. (attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, - num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output)) + num_scheduled_tokens_np, + decode_mask) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1546,6 +1558,7 @@ def execute_model( aux_hidden_states, spec_decode_metadata, attn_metadata, + mm_embeds, ) # Clear KVConnector state after all KVs are generated. @@ -1577,6 +1590,7 @@ def propose_draft_token_ids( aux_hidden_states: Optional[torch.Tensor], spec_decode_metadata: Optional[SpecDecodeMetadata], attn_metadata: dict[str, Any], + mm_embeds: list[torch.Tensor], ) -> list[list[int]]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": @@ -1605,24 +1619,6 @@ def propose_draft_token_ids( ) elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) - # TODO(woosuk): Refactor the loop. - next_token_ids: list[int] = [] - for i, token_ids in enumerate(sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. 
- req_id = self.input_batch.req_ids[i] - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) # At this moment, we assume all eagle layers belong to the same KV # cache group, thus using the same attention metadata. eagle_attn_metadata = attn_metadata[ @@ -1675,11 +1671,51 @@ def propose_draft_token_ids( target_hidden_states = hidden_states[token_indices] target_slot_mapping = eagle_attn_metadata.slot_mapping[ token_indices] - mm_embeds = None - if self.is_multimodal_model: - mm_embeds = self._gather_mm_embeddings(scheduler_output, + draft_mm_embeds = mm_embeds if mm_embeds else None + if ( + self.is_multimodal_model and + self.speculative_config.eagle_shift_prefill_token()): + draft_mm_embeds = self._gather_mm_embeddings(scheduler_output, shift_computed_tokens=1) + # TODO(woosuk): Refactor the loop. + next_token_ids: list[int] = [] + prefill_first_hiddens = [] + full_prefill_mask = [] + for i, token_ids in enumerate(sampled_token_ids): + req_id = self.input_batch.req_ids[i] + req_state = self.requests[req_id] + if req_state.prefill_hidden_states is None: + req_state.prefill_hidden_states = target_hidden_states[ + cu_num_tokens[i] + ] + prefill_first_hiddens.append(req_state.prefill_hidden_states) + num_prompt_tokens = req_state.num_prompt_tokens + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] + full_prefill_mask.append( + req_state.num_computed_tokens < num_prompt_tokens + and req_state.num_computed_tokens+num_scheduled_tokens >= num_prompt_tokens) + if token_ids: + # Common case. + next_token_id = token_ids[-1] + else: + # Partial prefill (rare case). + # Get the next token id from the request state. 
+ seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + next_token_id = req_state.get_token_id(seq_len) + last_hidden_index = cu_num_tokens[i + 1] - 1 + req_state.prefill_hidden_states = target_hidden_states[ + last_hidden_index] + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.device) + prefill_first_hiddens = torch.cat(prefill_first_hiddens, + dim=0) + full_prefill_mask = torch.tensor(full_prefill_mask, + dtype=torch.bool, + device=self.device) draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, @@ -1689,7 +1725,10 @@ def propose_draft_token_ids( cu_num_tokens=cu_num_tokens, block_table=block_table, sampling_metadata=sampling_metadata, - mm_embeds=mm_embeds, + mm_embeds=draft_mm_embeds, + prefill_first_hiddens=prefill_first_hiddens, + decode_mask=decode_mask, + full_prefill_mask=full_prefill_mask, ) spec_token_ids = draft_token_ids.tolist() return spec_token_ids From 2e9541ff1cb7576a6717cbb91bc32d063e1d73dd Mon Sep 17 00:00:00 2001 From: morgendave Date: Thu, 12 Jun 2025 14:30:21 -0700 Subject: [PATCH 8/9] add kv copy logic and offline tests Signed-off-by: morgendave --- examples/offline_inference/spec_decode.py | 28 +++- vllm/config.py | 1 + vllm/envs.py | 5 +- vllm/model_executor/models/llama4.py | 5 - vllm/v1/spec_decode/eagle.py | 181 ++++++++++++++++++---- vllm/v1/worker/gpu_model_runner.py | 47 +++--- vllm/v1/worker/utils.py | 76 +++++++++ 7 files changed, 288 insertions(+), 55 deletions(-) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 604c4fbe0ef..6c20ffb6fa1 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -68,8 +68,19 @@ def parse_args(): parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) parser.add_argument("--custom-mm-prompts", action="store_true") - parser.add_argument("--no-prefill-token-shift", dest="prefill_token_shift", - action="store_false", help="Disable prefill token shift (default: enabled)") + parser.add_argument( + "--no-prefill-token-shift", + dest="prefill_token_shift", + action="store_false", + help="Disable prefill token shift (default: enabled)", + ) + parser.add_argument("--target_kv_layer_copy_from", type=int, default=-1) + parser.add_argument( + "--draft_kv_layer_copy_to", + type=str, + default="", + help="comma separated list of layer indices to copy to", + ) return parser.parse_args() @@ -101,11 +112,24 @@ def main(): elif args.method == "eagle3" and eagle_dir is None: eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" + target_kv_layer_copy_from = args.target_kv_layer_copy_from + draft_kv_layers_copy_to = ( + [int(layer) for layer in args.draft_kv_layer_copy_to.split(",")] + if args.draft_kv_layer_copy_to + else None + ) + kv_sharing_mapping = None + if args.target_kv_layer_copy_from >= 0 and draft_kv_layers_copy_to: + kv_sharing_mapping = { + f"{layer}": f"{target_kv_layer_copy_from}" + for layer in draft_kv_layers_copy_to + } speculative_config = { "method": args.method, "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, "prefill_token_shift": args.prefill_token_shift, + "kv_sharing_mapping": kv_sharing_mapping, } elif args.method == "ngram": speculative_config = { diff --git a/vllm/config.py b/vllm/config.py index 29aef108861..4301c1ac3d8 100644 --- a/vllm/config.py +++ b/vllm/config.py 
@@ -2556,6 +2556,7 @@ class SpeculativeConfig: """Shift tokens during draft prefill or not""" # Config for kv sharing, map from base model layer to draft layer + # Key is draft layer, value is base layer kv_sharing_mapping: SkipValidation[dict[str, str]] = None """KV copy mapping for prefill stage from base to draft""" diff --git a/vllm/envs.py b/vllm/envs.py index 0cc6792d72b..c7d06bb8d5b 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -138,6 +138,7 @@ VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE" VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None + VLLM_DECODE_ONLY_ATTN: bool = False def get_default_cache_root(): @@ -953,7 +954,9 @@ def get_vllm_port() -> Optional[int]: # generations on machines < 100 for compressed-tensors # models "VLLM_USE_NVFP4_CT_EMULATIONS": - lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))) + lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))), + "VLLM_DECODE_ONLY_ATTN": + lambda: os.environ.get("VLLM_DECODE_ONLY_ATTN", "0") == "1" } # --8<-- [end:env-vars-definition] diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index b509d28a1b7..442301e4e22 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -183,11 +183,6 @@ def __init__(self, is_gguf = quant_config and quant_config.get_name() == "gguf" if is_gguf and config.model_type == "llama": is_neox_style = False - elif config.model_type == "eagle": - # EAGLE draft model does not use neox style RoPE - is_neox_style = False - else: - is_neox_style = True self.rotary_emb = get_rope( self.head_dim, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 97a0fa8946a..f765ba81d83 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn +import vllm.envs as envs from vllm.attention.layer import Attention from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) @@ -14,11 +15,13 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.model_executor.models.utils import extract_layer_index from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel +from vllm.v1.worker.utils import copy_kv_cache_for_layers logger = init_logger(__name__) @@ -49,6 +52,10 @@ def __init__( self.speculative_config.num_speculative_tokens) self.max_num_tokens = ( vllm_config.scheduler_config.max_num_batched_tokens) + # For non-shifting case, consider full prefills each would add + # one more token + if not self.speculative_config.prefill_token_shift: + self.max_num_tokens = self.max_num_tokens * 2 # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's # hidden size (e.g., Llama 3.3 70B). 
@@ -63,7 +70,8 @@ def __init__( self.cudagraph_batch_sizes = list( reversed( self.vllm_config.compilation_config.cudagraph_capture_sizes)) - self.draft_prefill_kv_sharing_from_base = self.kv_sharing_mapping is not None + self.draft_prefill_kv_sharing_from_base = ( + self.kv_sharing_mapping is not None and envs.VLLM_DECODE_ONLY_ATTN) # persistent buffers for cuda graph self.input_ids = torch.zeros(self.max_num_tokens, @@ -103,7 +111,8 @@ def _prepare_adjusted_tensors( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, torch.Tensor]: """ - Prepare adjusted tensors for different request types (partial prefill, full prefill, full decode). + Prepare adjusted tensors for different request types + (partial prefill, full prefill, full decode). Args: target_token_ids: Input token IDs tensor @@ -123,10 +132,12 @@ def _prepare_adjusted_tensors( cu_num_tokens, current_pos, partial_prefill_mask) """ - # Count total number of full prefill requests to determine the size needed for adjusted tensors + # Count total number of full prefill requests to determine the + # size needed for adjusted tensors num_full_prefill = full_prefill_mask.sum().item() - # Create tensors with extra space for the additional positions from full prefill requests + # Create tensors with extra space for the additional + # positions from full prefill requests adjusted_token_ids = torch.zeros( num_tokens + num_full_prefill, dtype=target_token_ids.dtype, @@ -148,6 +159,27 @@ def _prepare_adjusted_tensors( dtype=target_hidden_states.dtype, device=target_hidden_states.device, ) + if self.draft_prefill_kv_sharing_from_base: + # Get the KV caches from the forward context + attentions = get_layers_from_vllm_config(self.vllm_config, + Attention) + kv_caches = { + layer: att.kv_cache[0] + for layer, att in attentions.items() + if layer in self.kv_sharing_mapping + or layer in self.kv_sharing_mapping.values() + } + copy_positions_mask = ~decode_mask + full_prefill_last_pos = cu_num_tokens[1:][full_prefill_mask] - 1 + copy_positions_mask[full_prefill_last_pos] = True + + # Call the function to copy KV cache values + copy_kv_cache_for_layers( + kv_caches=kv_caches, + kv_sharing_layers_mapping=self.kv_sharing_mapping, + copy_positions_mask=copy_positions_mask, + slot_mapping=target_slot_mapping, + ) # Create updated cumulative token counts updated_cu_num_tokens = torch.zeros_like(cu_num_tokens) @@ -191,87 +223,100 @@ def _prepare_adjusted_tensors( if is_partial_prefill: # Category 1: Partial prefill - just copy all tokens + # if we enable copy kv then all of the tokens are skipped if not self.draft_prefill_kv_sharing_from_base: - # Use torch operations for copying blocks of data adjusted_token_ids[current_pos:current_pos + req_len].copy_( target_token_ids[start_idx:end_idx]) + adjusted_positions[current_pos:current_pos + req_len].copy_( target_positions[start_idx:end_idx]) + adjusted_slot_mapping[current_pos:current_pos + req_len].copy_(target_slot_mapping[ start_idx:end_idx]) + + # Put the first prefill hidden state in the first position + # and shift all the other ones, this matches the sequence + # as non-shifting will include the first prefill token adjusted_hidden_states[current_pos + 1:current_pos + req_len].copy_( target_hidden_states[start_idx + 1:end_idx]) + adjusted_hidden_states[ current_pos] = prefill_first_hiddens[i] current_pos += req_len cu_num_tokens_index += 1 elif is_full_prefill: - # Category 2: Full prefill with decode - copy tokens and add one position + # Category 2: Full prefill with decode: + # copy 
tokens and add one position pos = target_positions[end_idx - 1] + 1 block_number = pos // self.block_size block_number = block_table[i][block_number].item() block_offset = pos % self.block_size + adjusted_slot = (block_number * self.block_size + block_offset) if not self.draft_prefill_kv_sharing_from_base: - # Copy token IDs, positions, slot mappings, and hidden states + # copy the original and adjust the one additional token + # for position, slot mapping and hidden state adjusted_token_ids[current_pos:current_pos + req_len].copy_( target_token_ids[start_idx:end_idx]) + adjusted_positions[current_pos:current_pos + req_len].copy_( target_positions[start_idx:end_idx]) - adjusted_positions[current_pos + - req_len] = adjusted_positions[ - current_pos + req_len - 1] + 1 + adjusted_positions[current_pos + req_len] = pos adjusted_slot_mapping[current_pos:current_pos + req_len].copy_(target_slot_mapping[ start_idx:end_idx]) - adjusted_slot_mapping[ - current_pos + - req_len] = block_number * self.block_size + block_offset + adjusted_slot_mapping[current_pos + + req_len] = (adjusted_slot) adjusted_hidden_states[ current_pos + 1:current_pos + req_len + 1].copy_( target_hidden_states[start_idx:end_idx]) + adjusted_hidden_states[ current_pos] = prefill_first_hiddens[i] current_pos += req_len + 1 else: - adjusted_positions[current_pos] = 0 - adjusted_slot_mapping[ - current_pos] = block_number * self.block_size + block_offset - adjusted_hidden_states[current_pos] = target_hidden_states[ - end_idx - 1] + # if we enable copy kv then all of the prefill tokens + # are skipped. Only keep the prefill output token + adjusted_positions[current_pos] = pos + adjusted_slot_mapping[current_pos] = adjusted_slot + adjusted_hidden_states[current_pos] = ( + target_hidden_states[end_idx - 1]) current_pos += 1 cu_num_tokens_index += 1 else: # Category 3: Full decode - shift tokens - # Shift operations using optimized copy operations + # Due to additional token in full prefill already, + # all the corresponding decode rounds will shift one tokens adjusted_token_ids[current_pos:current_pos + req_len - 1].copy_(target_token_ids[start_idx + 1:end_idx]) + adjusted_positions[current_pos:current_pos + req_len].copy_( target_positions[start_idx:end_idx] + 1) adjusted_slot_mapping[current_pos:current_pos + req_len - 1].copy_(target_slot_mapping[start_idx + 1:end_idx]) + pos = adjusted_positions[current_pos + req_len - 1] block_number = pos // self.block_size block_number = block_table[i][block_number].item() block_offset = pos % self.block_size - adjusted_slot_mapping[ - current_pos + req_len - - 1] = block_number * self.block_size + block_offset + adjusted_slot_mapping[current_pos + req_len - + 1] = (block_number * self.block_size + + block_offset) adjusted_hidden_states[current_pos:current_pos + req_len].copy_(target_hidden_states[ @@ -283,6 +328,7 @@ def _prepare_adjusted_tensors( # Update the cumulative token count for this request updated_cu_num_tokens[cu_num_tokens_index] = current_pos + # using current_pos to cap the actual number of tokens # Copy the adjusted tensors to the input buffers self.input_ids[:current_pos] = adjusted_token_ids[:current_pos] @@ -363,11 +409,26 @@ def propose( batch_size, num_tokens, ) + if (partial_prefill_mask.all() + and self.draft_prefill_kv_sharing_from_base): + # All requests are partial prefill and + # KV cache sharing is enabled + # Skip the rest of the function + # and return dummy draft tokens + return torch.zeros( + (batch_size, self.num_speculative_tokens), + 
dtype=target_token_ids.dtype, + device=target_token_ids.device, + ) + batch_size = cu_num_tokens.shape[0] - 1 else: # Original behavior: shift all tokens by one - partial_prefill_mask = None self.input_ids[:num_tokens - 1] = target_token_ids[1:] + partial_prefill_mask = torch.zeros_like(full_prefill_mask) if not prefill_shift_tokens: + # For pure decode in non-shifting prefill case + # Due to one additional token in prefill, all the decode + # rounds will shift one token target_positions += 1 max_num_blocks_per_req = block_table.shape[1] segment_indices = torch.arange(len(target_positions), @@ -381,15 +442,24 @@ def propose( segment_indices * max_num_blocks_per_req) block_numbers = block_table.flatten()[block_table_indices] block_offsets = target_positions % self.block_size - target_slot_mapping = block_numbers * self.block_size + block_offsets + target_slot_mapping = (block_numbers * self.block_size + + block_offsets) # Use the original last token indices last_token_indices = cu_num_tokens[1:] - 1 - # Replace the last token with the next token, but only for non-partial prefill requests if not prefill_shift_tokens and has_prefill: + # Replace the last token with the next token under non-shifting, + # but only for non-partial prefill requests mask = ~partial_prefill_mask - self.input_ids[last_token_indices[mask]] = next_token_ids[mask] + # if we enable copy kv then all of the partial prefills + # are completely skipped so they won't be in last_token_indices + input_indices = ( + last_token_indices[mask] + if not self.draft_prefill_kv_sharing_from_base else + last_token_indices[:batch_size - + partial_prefill_mask.sum().item()]) + self.input_ids[input_indices] = next_token_ids[mask] else: # Original behavior: apply to all requests self.input_ids[last_token_indices] = next_token_ids @@ -481,10 +551,24 @@ def propose( sample_hidden_states = last_hidden_states[last_token_indices] logits = self.model.compute_logits(sample_hidden_states, None) draft_token_ids = logits.argmax(dim=-1) - # print("draft_tokens topK:", logits.topk(3, dim=-1)) # Early exit if there is only one draft token to be generated. 
if self.num_speculative_tokens == 1: + if (self.draft_prefill_kv_sharing_from_base + and partial_prefill_mask.any().item()): + # if we do kv sharing and has partial prefill + # the original position for partial prefill will not + # have token output thus we need to pad the draft tokens + # with the correct positions + padded_draft_token_ids = torch.zeros( + partial_prefill_mask.shape[0], + dtype=draft_token_ids.dtype, + device=draft_token_ids.device) + draft_token_ids = draft_token_ids[:batch_size - + partial_prefill_mask.sum( + ).item()] + padded_draft_token_ids[~partial_prefill_mask] = draft_token_ids + draft_token_ids = padded_draft_token_ids # [batch_size, 1] return draft_token_ids.view(-1, 1) @@ -579,7 +663,21 @@ def propose( # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) - # print("draft_token_ids:", draft_token_ids) + if (self.draft_prefill_kv_sharing_from_base + and partial_prefill_mask.any().item()): + # if we do kv sharing and has partial prefill + # the original position for partial prefill will not + # have token output thus we need to pad the draft tokens + # with the correct positions + padded_draft_token_ids = torch.zeros( + (partial_prefill_mask.shape[0], self.num_speculative_tokens), + dtype=draft_token_ids.dtype, + device=draft_token_ids.device) + draft_token_ids = draft_token_ids[:batch_size - + partial_prefill_mask.sum().item( + )] + padded_draft_token_ids[~partial_prefill_mask] = draft_token_ids + draft_token_ids = padded_draft_token_ids return draft_token_ids @staticmethod @@ -623,7 +721,8 @@ def prepare_inputs( return cu_num_tokens, token_indices def load_model(self, target_model: nn.Module) -> None: - draft_model_config = self.vllm_config.speculative_config.draft_model_config + draft_model_config = ( + self.vllm_config.speculative_config.draft_model_config) target_attn_layer_names = set( get_layers_from_vllm_config(self.vllm_config, Attention).keys()) @@ -638,6 +737,30 @@ def load_model(self, target_model: nn.Module) -> None: self.attn_layer_names = list(draft_attn_layer_names) + if envs.VLLM_DECODE_ONLY_ATTN: + logger.info("Using half prefill for decoding only attention," + " Setting up KV sharing from base to draft.") + assert self.kv_sharing_mapping is not None + target_attn_layer_indices_dict = dict( + (str(extract_layer_index(target_layer)), target_layer) + for target_layer in target_attn_layer_names) + draft_attn_layer_indices_dict = dict( + (str(extract_layer_index(draft_layer)), draft_layer) + for draft_layer in draft_attn_layer_names) + updated_kv_sharing_mapping = { + draft_attn_layer_indices_dict[draft_index]: + target_attn_layer_indices_dict[target_index] + for draft_index, target_index in + self.kv_sharing_mapping.items() + } + logger.info("Updated KV sharing mapping: %s", + updated_kv_sharing_mapping) + assert len(updated_kv_sharing_mapping) == len( + self.kv_sharing_mapping), ( + "KV sharing mapping should be a subset of draft and" + " target attn layer indices") + self.kv_sharing_mapping = updated_kv_sharing_mapping + if supports_multimodal(target_model): # handle multimodality self.model.config.image_token_index = ( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 79a539d7f2a..ebd91f6f5b7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -602,17 +602,17 @@ def _prepare_inputs( num_scheduled_tokens = np.array(tokens, dtype=np.int32) max_num_scheduled_tokens = max(tokens) req_last_prompt_index = np.array([ - 
self.requests[req_id].num_prompt_tokens - 1 - for req_id in self.input_batch.req_ids], - dtype=np.int32 - ) + self.requests[req_id].num_prompt_tokens - 1 + for req_id in self.input_batch.req_ids + ], + dtype=np.int32) # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) last_prompt_indices_np = np.repeat(req_last_prompt_index, - num_scheduled_tokens) + num_scheduled_tokens) # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] @@ -1559,6 +1559,7 @@ def execute_model( spec_decode_metadata, attn_metadata, mm_embeds, + decode_mask, ) # Clear KVConnector state after all KVs are generated. @@ -1591,6 +1592,7 @@ def propose_draft_token_ids( spec_decode_metadata: Optional[SpecDecodeMetadata], attn_metadata: dict[str, Any], mm_embeds: list[torch.Tensor], + decode_mask: torch.Tensor, ) -> list[list[int]]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": @@ -1662,8 +1664,8 @@ def propose_draft_token_ids( num_tokens, ) target_token_ids = self.input_ids[token_indices] - # TODO(woosuk): Support M-RoPE. target_positions = self.positions[token_indices] + decode_mask = decode_mask[token_indices] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[token_indices] for h in aux_hidden_states], dim=-1) @@ -1672,29 +1674,36 @@ def propose_draft_token_ids( target_slot_mapping = eagle_attn_metadata.slot_mapping[ token_indices] draft_mm_embeds = mm_embeds if mm_embeds else None - if ( - self.is_multimodal_model and - self.speculative_config.eagle_shift_prefill_token()): - draft_mm_embeds = self._gather_mm_embeddings(scheduler_output, - shift_computed_tokens=1) + if envs.VLLM_DECODE_ONLY_ATTN: + draft_mm_embeds = None + elif (self.is_multimodal_model + and self.speculative_config.eagle_shift_prefill_token()): + draft_mm_embeds = self._gather_mm_embeddings( + scheduler_output, shift_computed_tokens=1) # TODO(woosuk): Refactor the loop. next_token_ids: list[int] = [] + # Get the first token's hidden state from cached for each request. + # This is used in non-shifted prefill for eagle draft. prefill_first_hiddens = [] full_prefill_mask = [] for i, token_ids in enumerate(sampled_token_ids): req_id = self.input_batch.req_ids[i] req_state = self.requests[req_id] + # Initialize the prefill hidden state if not set., + # from experimental results, the last token hidden state + # works very well for init the first prefill hidden state. if req_state.prefill_hidden_states is None: req_state.prefill_hidden_states = target_hidden_states[ - cu_num_tokens[i] - ] + cu_num_tokens[i]] prefill_first_hiddens.append(req_state.prefill_hidden_states) num_prompt_tokens = req_state.num_prompt_tokens - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] full_prefill_mask.append( - req_state.num_computed_tokens < num_prompt_tokens - and req_state.num_computed_tokens+num_scheduled_tokens >= num_prompt_tokens) + req_state.num_computed_tokens < num_prompt_tokens + and req_state.num_computed_tokens + num_scheduled_tokens + >= num_prompt_tokens) if token_ids: # Common case. 
next_token_id = token_ids[-1] @@ -1705,14 +1714,16 @@ def propose_draft_token_ids( scheduler_output.num_scheduled_tokens[req_id]) next_token_id = req_state.get_token_id(seq_len) last_hidden_index = cu_num_tokens[i + 1] - 1 + # For non-shifting prefill + partial prefill case, + # the current round last hidden state will be used + # as the first prefill hidden for the next round req_state.prefill_hidden_states = target_hidden_states[ last_hidden_index] next_token_ids.append(next_token_id) next_token_ids = torch.tensor(next_token_ids, dtype=torch.int32, device=self.device) - prefill_first_hiddens = torch.cat(prefill_first_hiddens, - dim=0) + prefill_first_hiddens = torch.cat(prefill_first_hiddens, dim=0) full_prefill_mask = torch.tensor(full_prefill_mask, dtype=torch.bool, device=self.device) diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 70339ff2f00..9a58854c2f3 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -110,3 +110,79 @@ def initialize_kv_cache_for_kv_sharing( kv_caches[layer_name] = kv_caches[target_layer_name] group_idx = layer_to_kv_cache_group_idx[target_layer_name] kv_cache_groups[group_idx].layer_names.append(layer_name) + + +def copy_kv_cache_for_layers( + kv_caches: dict[str, torch.Tensor], + kv_sharing_layers_mapping: dict[str, str], + copy_positions_mask: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + """ + Copies values from one set of layers to another + based on a mapping and a mask. + + This function is primarily used when KV sharing + is enabled especially for spec decoding to copy + target model's cache for the specified positions into + corresponding draft model layer's KV cache. + + Args: + kv_caches: + Dictionary mapping layer names to their cache tensors. + kv_sharing_layers_mapping: + Mapping the target layers to the source layers. + copy_positions_mask: + Boolean mask where True indicates positions to copy. + slot_mapping: + Mapping tensor that maps positions to cache slots. 
+ """ + # Get the positions to copy + positions = torch.nonzero(copy_positions_mask, as_tuple=True)[0] + + if positions.numel() == 0: + # No positions to copy + return + + # Get the corresponding slot mappings for the positions + slots = slot_mapping[positions] + + # Copy KV cache values from source layers to target layers + for target_layer, source_layer in kv_sharing_layers_mapping.items(): + if target_layer not in kv_caches or source_layer not in kv_caches: + continue + + target_kv_cache = kv_caches[target_layer] + source_kv_cache = kv_caches[source_layer] + + block_size = source_kv_cache.shape[2] + + kv_dim = 2 + # Process in smaller batches to reduce memory overhead + batch_size = 8192 + num_positions = positions.size(0) + + for start_idx in range(0, num_positions, batch_size): + end_idx = min(start_idx + batch_size, num_positions) + + # Get batch of slots + batch_slots = slots[start_idx:end_idx] + batch_block_indices = batch_slots // block_size + batch_block_offsets = batch_slots % block_size + + # Create batch-sized indexing tensors + batch_block_indices_expanded = batch_block_indices.view( + 1, -1, 1, 1, 1) + batch_block_offsets_expanded = batch_block_offsets.view( + 1, 1, -1, 1, 1) + + # Copy values for this batch + for kv_idx in range(kv_dim): + target_kv_cache[ + kv_idx, + batch_block_indices_expanded.squeeze(), + batch_block_offsets_expanded.squeeze(), :, :] = ( + source_kv_cache[ + kv_idx, + batch_block_indices_expanded.squeeze(), + batch_block_offsets_expanded.squeeze(), :, :]) From fb93e7f22019faa59afdf47d31b17b21219a44a3 Mon Sep 17 00:00:00 2001 From: morgendave Date: Mon, 30 Jun 2025 16:26:35 -0700 Subject: [PATCH 9/9] Add examples and algorithm for non-shifting, fixes some minor issues Signed-off-by: morgendave --- tests/models/test_initialization.py | 5 ++ tests/v1/e2e/test_spec_decode.py | 52 +++++++++++++-------- vllm/config.py | 2 +- vllm/v1/spec_decode/eagle.py | 71 ++++++++++++++++++----------- vllm/v1/worker/gpu_model_runner.py | 12 +++-- 5 files changed, 93 insertions(+), 49 deletions(-) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 25bc96bf326..6ab508a9c5e 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -26,6 +26,11 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): "KimiVLForConditionalGeneration"): pytest.skip("Avoid OOM") + if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"): + from vllm.model_executor.models.llama4 import Llama4ForCausalLM + from vllm.model_executor.models.registry import ModelRegistry + ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM) + # Avoid OOM and reduce initialization time by only using 1 layer def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: hf_config.update(model_info.hf_overrides) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index b603fabed5f..e11e527b804 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -123,27 +123,39 @@ def test_ngram_correctness( cleanup_dist_env_and_memory() -@pytest.mark.parametrize( - "model_setup,mm_enabled", [ - (("eagle", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), - (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), - pytest.param( - (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), False), - 
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), - pytest.param( - (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True), - marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), - ], - ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"]) +@pytest.mark.parametrize("model_setup,mm_enabled,prefill_shift", [ + (("eagle", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False, True), + (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False, True), + pytest.param( + (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), False, True), + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + pytest.param( + (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True, True), + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + pytest.param( + (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), False, False), + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + pytest.param( + (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True, False), + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), +], + ids=[ + "llama3_eagle", "llama3_eagle3", "llama4_eagle", + "llama4_eagle_mm", "llama4_eagle_no_shift", + "llama4_eagle_mm_no_shift" + ]) def test_eagle_correctness( monkeypatch: pytest.MonkeyPatch, sampling_config: SamplingParams, model_setup: tuple[str, str, str, int], mm_enabled: bool, + prefill_shift: bool, ): # Generate test prompts inside the function instead of using fixture test_prompts = get_test_prompts(mm_enabled) @@ -156,8 +168,9 @@ def test_eagle_correctness( m.setenv("VLLM_USE_V1", "1") method, model_name, spec_model_name, tp_size = model_setup + max_model_len = 2048 if not mm_enabled else 4096 ref_llm = LLM(model=model_name, - max_model_len=2048, + max_model_len=max_model_len, tensor_parallel_size=tp_size) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm @@ -172,9 +185,10 @@ def test_eagle_correctness( "method": method, "model": spec_model_name, "num_speculative_tokens": 3, - "max_model_len": 2048, + "max_model_len": max_model_len, + "prefill_token_shift": prefill_shift, }, - max_model_len=2048, + max_model_len=max_model_len, ) spec_outputs = spec_llm.chat(test_prompts, sampling_config) matches = 0 diff --git a/vllm/config.py b/vllm/config.py index 4301c1ac3d8..66b917e4593 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2557,7 +2557,7 @@ class SpeculativeConfig: # Config for kv sharing, map from base model layer to draft layer # Key is draft layer, value is base layer - kv_sharing_mapping: SkipValidation[dict[str, str]] = None + kv_sharing_mapping: SkipValidation[dict[str, str]] = None # type: ignore """KV copy mapping for prefill stage from base to draft""" def compute_hash(self) -> str: diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index f765ba81d83..d98ed14e662 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -104,6 +104,7 @@ def _prepare_adjusted_tensors( cu_num_tokens: torch.Tensor, decode_mask: torch.Tensor, full_prefill_mask: torch.Tensor, + partial_prefill_mask: torch.Tensor, prefill_first_hiddens: torch.Tensor, block_table: 
torch.Tensor, batch_size: int, @@ -131,6 +132,34 @@ def _prepare_adjusted_tensors( tuple: (target_positions, target_hidden_states, target_slot_mapping, cu_num_tokens, current_pos, partial_prefill_mask) + Algorithm design: + - Suppose target tokens are [1,2,3,...N], next token is N+1 + - Position is [0,1,2,...N-1] + - And hidden is [h1,h2,h3,...hN] + - Suppose partial prefill is [Nm, Nm+1, ...Nm+M-1] + -- For normal shifting: + --- draft prefill is [2,3,...N+1], position is same as target + --- Stacking hidden is [h1,h2,h3,...hN] + --- Decode tokens are [N+2, N+3, ...], hidden is [hN+1,hN+2,...] + --- Decode positions are [N,N+1,...] + --- draft partial prefill is [Nm+1, Nm+2, ...Nm+M] + -- For non-shifting: + --- draft full prefill is [1,2,3,...N+1], position is [0,1,2,...N] + --- Stacking hidden is [hN,h1,h2,h3,...hN] + --- Decode tokens are [N+2, N+3, ...], hidden is [hN+1,hN+2,...] + --- Decode positions are [N+1,N+2,...] + --- draft partial prefill is [Nm, Nm+1, ...Nm+M-1] + --- draft hidden is [hNm-1,hNm,...hNm+M] + (hNm-1 is the last round hidden) + -- For kv sharing(non-shifting required): + This means all target prefill tokens are not needed to be processed + in drafting prefill step as we don't need the kv from draft. + --- draft full prefill is [N+1], position is [N] + --- Stacking hidden is [hN] + --- Decode is the same as non-shifting decode + --- draft partial prefill is totally skipped + All other metadata like slot mapping, etc. should be based on + the positions and tokens to generate/manipulate again """ # Count total number of full prefill requests to determine the # size needed for adjusted tensors @@ -184,21 +213,6 @@ def _prepare_adjusted_tensors( # Create updated cumulative token counts updated_cu_num_tokens = torch.zeros_like(cu_num_tokens) - # Track which requests are partial prefill (no decode tokens) - partial_prefill_mask = torch.zeros_like(full_prefill_mask) - - # Create masks for each category - has_decode_mask = torch.zeros(batch_size, - dtype=torch.bool, - device=decode_mask.device) - for i in range(batch_size): - start_idx = cu_num_tokens[i].item() - end_idx = cu_num_tokens[i + 1].item() - has_decode_mask[i] = decode_mask[start_idx:end_idx].any().item() - - # Category 1: Partial prefill (no decode tokens) - partial_prefill_mask = ~has_decode_mask - # Process batched operations using masks current_pos = 0 cu_num_tokens_index = 0 @@ -368,6 +382,7 @@ def propose( mm_embeds: Optional[list[torch.Tensor]] = None, decode_mask: torch.Tensor = None, full_prefill_mask: torch.Tensor = None, + partial_prefill_mask: torch.Tensor = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] @@ -388,6 +403,17 @@ def propose( prefill_shift_tokens = False if not prefill_shift_tokens and has_prefill: + if (partial_prefill_mask.all() + and self.draft_prefill_kv_sharing_from_base): + # All requests are partial prefill and + # KV cache sharing is enabled + # Skip the rest of the function + # and return dummy draft tokens + return torch.zeros( + (batch_size, self.num_speculative_tokens), + dtype=target_token_ids.dtype, + device=target_token_ids.device, + ) # Adjust the tensors for full prefill requests ( target_positions, @@ -404,22 +430,12 @@ def propose( cu_num_tokens, decode_mask, full_prefill_mask, + partial_prefill_mask, prefill_first_hiddens, block_table, batch_size, num_tokens, ) - if (partial_prefill_mask.all() - and self.draft_prefill_kv_sharing_from_base): - # All requests are partial prefill and - # KV cache sharing is enabled - 
# Skip the rest of the function - # and return dummy draft tokens - return torch.zeros( - (batch_size, self.num_speculative_tokens), - dtype=target_token_ids.dtype, - device=target_token_ids.device, - ) batch_size = cu_num_tokens.shape[0] - 1 else: # Original behavior: shift all tokens by one @@ -451,6 +467,9 @@ def propose( if not prefill_shift_tokens and has_prefill: # Replace the last token with the next token under non-shifting, # but only for non-partial prefill requests + # For partial prefill in non-shifting, we just match the target + # prefill tokens as it would match the positions and hidden states + # so no need to add this next token from next round mask = ~partial_prefill_mask # if we enable copy kv then all of the partial prefills # are completely skipped so they won't be in last_token_indices diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ebd91f6f5b7..c886ac7776e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1303,8 +1303,7 @@ def execute_model( # Prepare the decoder inputs. (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, - num_scheduled_tokens_np, + spec_decode_metadata, num_scheduled_tokens_np, decode_mask) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph @@ -1687,6 +1686,7 @@ def propose_draft_token_ids( # This is used in non-shifted prefill for eagle draft. prefill_first_hiddens = [] full_prefill_mask = [] + partial_prefill_mask = [] for i, token_ids in enumerate(sampled_token_ids): req_id = self.input_batch.req_ids[i] req_state = self.requests[req_id] @@ -1695,7 +1695,7 @@ def propose_draft_token_ids( # works very well for init the first prefill hidden state. if req_state.prefill_hidden_states is None: req_state.prefill_hidden_states = target_hidden_states[ - cu_num_tokens[i]] + cu_num_tokens[i + 1] - 1] prefill_first_hiddens.append(req_state.prefill_hidden_states) num_prompt_tokens = req_state.num_prompt_tokens num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ @@ -1707,6 +1707,7 @@ def propose_draft_token_ids( if token_ids: # Common case. next_token_id = token_ids[-1] + partial_prefill_mask.append(False) else: # Partial prefill (rare case). # Get the next token id from the request state. @@ -1719,6 +1720,7 @@ def propose_draft_token_ids( # as the first prefill hidden for the next round req_state.prefill_hidden_states = target_hidden_states[ last_hidden_index] + partial_prefill_mask.append(True) next_token_ids.append(next_token_id) next_token_ids = torch.tensor(next_token_ids, dtype=torch.int32, @@ -1727,6 +1729,9 @@ def propose_draft_token_ids( full_prefill_mask = torch.tensor(full_prefill_mask, dtype=torch.bool, device=self.device) + partial_prefill_mask = torch.tensor(partial_prefill_mask, + dtype=torch.bool, + device=self.device) draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, @@ -1740,6 +1745,7 @@ def propose_draft_token_ids( prefill_first_hiddens=prefill_first_hiddens, decode_mask=decode_mask, full_prefill_mask=full_prefill_mask, + partial_prefill_mask=partial_prefill_mask, ) spec_token_ids = draft_token_ids.tolist() return spec_token_ids
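To make the non-shifting prefill algorithm described in the _prepare_adjusted_tensors docstring easier to follow, here is a small worked example for a single full-prefill request with N = 4 prompt tokens. It is illustrative plain Python following the docstring's notation, not part of the patch itself.

# Worked example of the _prepare_adjusted_tensors algorithm-design notes,
# for one full-prefill request with N = 4 prompt tokens (illustrative only).
N = 4
target_tokens = [1, 2, 3, 4]                 # prompt token ids
next_token = 5                               # token sampled by the target model
target_positions = [0, 1, 2, 3]
target_hiddens = ["h1", "h2", "h3", "h4"]    # target hidden states

# Shifted prefill (prefill_token_shift=True, original EAGLE behavior):
shifted_input = target_tokens[1:] + [next_token]        # [2, 3, 4, 5]
shifted_positions = target_positions                    # [0, 1, 2, 3]
shifted_hiddens = target_hiddens                        # [h1, h2, h3, h4]

# Non-shifted prefill (prefill_token_shift=False): one extra slot is appended
# for the sampled token, and the first slot reuses the cached "prefill first
# hidden", which is initialized from the last hidden state (hN).
non_shifted_input = target_tokens + [next_token]        # [1, 2, 3, 4, 5]
non_shifted_positions = target_positions + [N]          # [0, 1, 2, 3, 4]
non_shifted_hiddens = [target_hiddens[-1]] + target_hiddens  # [h4, h1, h2, h3, h4]

# KV sharing from the base model (kv_sharing_mapping + VLLM_DECODE_ONLY_ATTN=1,
# non-shifting required): the draft skips the prompt entirely and only
# processes the sampled token; partial prefills are skipped altogether.
kv_sharing_input = [next_token]                         # [5]
kv_sharing_positions = [N]                              # [4]
kv_sharing_hiddens = [target_hiddens[-1]]               # [h4]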
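As a usage reference, the sketch below shows how the knobs introduced in this series are expected to fit together in an offline run: prefill_token_shift, kv_sharing_mapping (keys are draft-layer indices, values are base-model layer indices to copy from), and the VLLM_DECODE_ONLY_ATTN gate on the KV-copy path. This is a minimal sketch rather than part of the patch; the model names and tensor_parallel_size mirror the e2e test, while the layer indices in the mapping are illustrative assumptions.

import os

# Assumption: the KV-copy path is only taken when this env var is set in
# addition to providing kv_sharing_mapping (see
# draft_prefill_kv_sharing_from_base in vllm/v1/spec_decode/eagle.py).
os.environ["VLLM_DECODE_ONLY_ATTN"] = "1"

from vllm import LLM, SamplingParams

# Keys are draft-model layer indices, values are the base-model layer whose
# prefill KV is copied into them. "0" -> "47" is an illustrative pairing,
# not a recommendation.
kv_sharing_mapping = {"0": "47"}

llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    tensor_parallel_size=4,              # mirrors the e2e test setup
    max_model_len=2048,
    speculative_config={
        "method": "eagle",
        "model": "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
        "num_speculative_tokens": 3,
        # KV sharing requires the non-shifting prefill path.
        "prefill_token_shift": False,
        "kv_sharing_mapping": kv_sharing_mapping,
    },
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)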