diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 26e492fed25..6c20ffb6fa1 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -13,6 +13,38 @@ from argparse import ArgumentParser as FlexibleArgumentParser +QUESTION = "What is the content of each image?" +IMAGE_URLS = [ + "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg", + "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg", + "https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG", + "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg", + "https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg", + "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg", + "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg", + "https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg", + "https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg", + "https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg", + "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg", + "https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg", +] + + +def get_custom_mm_prompts(num_prompts): + prompts = [] + for url in IMAGE_URLS: + prompts.append( + [ + {"type": "image_url", "image_url": {"url": url}}, + {"type": "text", "text": QUESTION}, + ] + ) + if num_prompts > len(IMAGE_URLS): + prompts = prompts * (num_prompts // len(IMAGE_URLS) + 1) + + return [[{"role": "user", "content": prompt}] for prompt in prompts[:num_prompts]] + + def parse_args(): parser = FlexibleArgumentParser() add_dataset_parser(parser) @@ -35,6 +67,20 @@ def parse_args(): parser.add_argument("--output-len", type=int, default=256) parser.add_argument("--model-dir", type=str, default=None) parser.add_argument("--eagle-dir", type=str, default=None) + parser.add_argument("--custom-mm-prompts", action="store_true") + parser.add_argument( + "--no-prefill-token-shift", + dest="prefill_token_shift", + action="store_false", + help="Disable prefill token shift (default: enabled)", + ) + parser.add_argument("--target_kv_layer_copy_from", type=int, default=-1) + parser.add_argument( + "--draft_kv_layer_copy_to", + type=str, + default="", + help="comma separated list of layer indices to copy to", + ) return parser.parse_args() @@ -46,12 +92,18 @@ def main(): if args.model_dir is None: model_dir = "meta-llama/Llama-3.1-8B-Instruct" tokenizer = AutoTokenizer.from_pretrained(model_dir) - - prompts = get_samples(args, tokenizer) - # add_special_tokens is False to avoid adding bos twice when using chat templates - prompt_ids = [ - tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts - ] + args.custom_skip_chat_template = True + + if not args.custom_mm_prompts: + prompts = get_samples(args, tokenizer) 
+ # add_special_tokens is False to avoid adding bos twice + # when using chat templates + prompt_ids = [ + tokenizer.encode(prompt.prompt, add_special_tokens=False) + for prompt in prompts + ] + else: + prompts = get_custom_mm_prompts(args.num_prompts) if args.method == "eagle" or args.method == "eagle3": eagle_dir = args.eagle_dir @@ -60,10 +112,24 @@ def main(): elif args.method == "eagle3" and eagle_dir is None: eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" + target_kv_layer_copy_from = args.target_kv_layer_copy_from + draft_kv_layers_copy_to = ( + [int(layer) for layer in args.draft_kv_layer_copy_to.split(",")] + if args.draft_kv_layer_copy_to + else None + ) + kv_sharing_mapping = None + if args.target_kv_layer_copy_from >= 0 and draft_kv_layers_copy_to: + kv_sharing_mapping = { + f"{layer}": f"{target_kv_layer_copy_from}" + for layer in draft_kv_layers_copy_to + } speculative_config = { "method": args.method, "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, + "prefill_token_shift": args.prefill_token_shift, + "kv_sharing_mapping": kv_sharing_mapping, } elif args.method == "ngram": speculative_config = { @@ -84,10 +150,18 @@ def main(): gpu_memory_utilization=0.8, speculative_config=speculative_config, disable_log_stats=False, + max_model_len=16384, + limit_mm_per_prompt={"image": 5}, + disable_chunked_mm_input=True, ) sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) - outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) + if not args.custom_mm_prompts: + outputs = llm.generate( + prompt_token_ids=prompt_ids, sampling_params=sampling_params + ) + else: + outputs = llm.chat(prompts, sampling_params=sampling_params) # print the generated text if args.print_output: diff --git a/tests/models/registry.py b/tests/models/registry.py index 48302f9d664..b94376970ce 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -451,6 +451,11 @@ def check_available_online( trust_remote_code=True, speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", tokenizer="meta-llama/Llama-3.1-8B-Instruct"), + "EagleLlama4ForCausalLM": _HfExamplesInfo( + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", + trust_remote_code=True, + speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", + tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501 "EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16", trust_remote_code=True, is_available_online=False, @@ -500,4 +505,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo: raise ValueError(f"No example model defined for {model_id}") -HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) \ No newline at end of file +HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 25bc96bf326..6ab508a9c5e 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -26,6 +26,11 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): "KimiVLForConditionalGeneration"): pytest.skip("Avoid OOM") + if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"): + from vllm.model_executor.models.llama4 import Llama4ForCausalLM + from vllm.model_executor.models.registry import ModelRegistry + ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM) + # Avoid OOM and reduce initialization time by only using 1 layer def hf_overrides(hf_config: PretrainedConfig) -> 
PretrainedConfig: hf_config.update(model_info.hf_overrides) diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 93e7c12f3a0..e11e527b804 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -6,24 +6,31 @@ from typing import Any import pytest +import torch from vllm import LLM, SamplingParams +from vllm.assets.base import VLLM_S3_BUCKET_URL +from vllm.assets.image import VLM_IMAGES_DIR +from vllm.distributed import cleanup_dist_env_and_memory -@pytest.fixture -def test_prompts(): +def get_test_prompts(mm_enabled: bool): prompt_types = ["repeat", "sentence"] - num_prompts = 100 + if mm_enabled: + prompt_types.append("mm") + num_prompts = 10 prompts = [] random.seed(0) random_prompt_type_choices = random.choices(prompt_types, k=num_prompts) + print(f"Prompt types: {random_prompt_type_choices}") # Generate a mixed batch of prompts, some of which can be easily # predicted by n-gram matching and some which likely cannot. for kind in random_prompt_type_choices: word_choices = ["test", "temp", "hello", "where"] word = random.choice(word_choices) + prompt: str | list[dict[str, Any]] = "" if kind == "repeat": prompt = f""" please repeat the word '{word}' 10 times. @@ -36,6 +43,21 @@ def test_prompts(): uses the word {word} at least once. give no other output than that simple sentence without quotes. """ + elif kind == "mm": + placeholders = [{ + "type": "image_url", + "image_url": { + "url": + f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg" + }, + }] + prompt = [ + *placeholders, + { + "type": "text", + "text": "The meaning of the image is" + }, + ] else: raise ValueError(f"Unknown prompt type: {kind}") prompts.append([{"role": "user", "content": prompt}]) @@ -53,14 +75,6 @@ def model_name(): return "meta-llama/Llama-3.1-8B-Instruct" -def eagle_model_name(): - return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" - - -def eagle3_model_name(): - return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" - - def test_ngram_correctness( monkeypatch: pytest.MonkeyPatch, test_prompts: list[list[dict[str, Any]]], @@ -77,6 +91,8 @@ def test_ngram_correctness( ref_llm = LLM(model=model_name, max_model_len=1024) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() spec_llm = LLM( model=model_name, @@ -103,39 +119,76 @@ def test_ngram_correctness( # Upon failure, inspect the outputs to check for inaccuracy. 
assert matches > int(0.7 * len(ref_outputs)) del spec_llm - - -@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"]) + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + +@pytest.mark.parametrize("model_setup,mm_enabled,prefill_shift", [ + (("eagle", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False, True), + (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False, True), + pytest.param( + (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), False, True), + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + pytest.param( + (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True, True), + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + pytest.param( + (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), False, False), + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + pytest.param( + (("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True, False), + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), +], + ids=[ + "llama3_eagle", "llama3_eagle3", "llama4_eagle", + "llama4_eagle_mm", "llama4_eagle_no_shift", + "llama4_eagle_mm_no_shift" + ]) def test_eagle_correctness( monkeypatch: pytest.MonkeyPatch, - test_prompts: list[list[dict[str, Any]]], sampling_config: SamplingParams, - model_name: str, - use_eagle3: bool, + model_setup: tuple[str, str, str, int], + mm_enabled: bool, + prefill_shift: bool, ): + # Generate test prompts inside the function instead of using fixture + test_prompts = get_test_prompts(mm_enabled) ''' Compare the outputs of a original LLM and a speculative LLM should be the same when using eagle speculative decoding. + model_setup: (method, model_name, eagle_model_name, tp_size) ''' with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + method, model_name, spec_model_name, tp_size = model_setup - ref_llm = LLM(model=model_name, max_model_len=2048) + max_model_len = 2048 if not mm_enabled else 4096 + ref_llm = LLM(model=model_name, + max_model_len=max_model_len, + tensor_parallel_size=tp_size) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() - spec_model_name = eagle3_model_name( - ) if use_eagle3 else eagle_model_name() spec_llm = LLM( model=model_name, trust_remote_code=True, + tensor_parallel_size=tp_size, speculative_config={ - "method": "eagle3" if use_eagle3 else "eagle", + "method": method, "model": spec_model_name, "num_speculative_tokens": 3, - "max_model_len": 2048, + "max_model_len": max_model_len, + "prefill_token_shift": prefill_shift, }, - max_model_len=2048, + max_model_len=max_model_len, ) spec_outputs = spec_llm.chat(test_prompts, sampling_config) matches = 0 @@ -152,3 +205,5 @@ def test_eagle_correctness( # Upon failure, inspect the outputs to check for inaccuracy. 
assert matches > int(0.66 * len(ref_outputs)) del spec_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() diff --git a/vllm/config.py b/vllm/config.py index bac18e8175d..66b917e4593 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2551,6 +2551,15 @@ class SpeculativeConfig: ParallelConfig] = None # type: ignore """The parallel configuration for the draft model initialized internal.""" + # Shift prefill tokens for draft, only used in eagle + prefill_token_shift: bool = True + """Shift tokens during draft prefill or not""" + + # Config for kv sharing, map from base model layer to draft layer + # Key is draft layer, value is base layer + kv_sharing_mapping: SkipValidation[dict[str, str]] = None # type: ignore + """KV copy mapping for prefill stage from base to draft""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -2937,6 +2946,11 @@ def num_lookahead_slots(self) -> int: def use_eagle(self) -> bool: return self.method in ("eagle", "eagle3", "deepseek_mtp") + def eagle_shift_prefill_token(self) -> bool: + if self.use_eagle(): + return self.prefill_token_shift + return False + def __repr__(self) -> str: method = self.method model = None if method == "ngram" else self.draft_model_config.model diff --git a/vllm/envs.py b/vllm/envs.py index 0cc6792d72b..c7d06bb8d5b 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -138,6 +138,7 @@ VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE" VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None + VLLM_DECODE_ONLY_ATTN: bool = False def get_default_cache_root(): @@ -953,7 +954,9 @@ def get_vllm_port() -> Optional[int]: # generations on machines < 100 for compressed-tensors # models "VLLM_USE_NVFP4_CT_EMULATIONS": - lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))) + lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))), + "VLLM_DECODE_ONLY_ATTN": + lambda: os.environ.get("VLLM_DECODE_ONLY_ATTN", "0") == "1" } # --8<-- [end:env-vars-definition] diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 0c9baab1f2e..442301e4e22 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -255,6 +255,7 @@ def __init__( super().__init__() self.layer_idx = extract_layer_index(prefix) + self.global_layer = config.no_rope_layers[self.layer_idx] == 0 self.hidden_size = config.hidden_size rope_theta = config.rope_theta rope_scaling = config.rope_scaling diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py new file mode 100644 index 00000000000..ece490ff2f2 --- /dev/null +++ b/vllm/model_executor/models/llama4_eagle.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team. +# All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
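+
+# EAGLE draft model for Llama 4: a short stack of Llama4DecoderLayer blocks
+# appended after the target model's layers (layer indices are offset by the
+# target layer count), with a linear `fc` projection that fuses the draft
+# token embedding with the target model's hidden state before the layers.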
+ +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.torchao import TorchAOConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama4 import (Llama4DecoderLayer, + Llama4ForCausalLM) +from vllm.model_executor.models.utils import extract_layer_index +from vllm.multimodal.inputs import NestedTensors + +from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings + +logger = init_logger(__name__) + + +@support_torch_compile +class LlamaModel(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + start_layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = ( + vllm_config.speculative_config.draft_model_config.hf_config) + self.validate_and_update_config(start_layer_id, quant_config) + self.vocab_size = self.config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + + self.layers = nn.ModuleList([ + Llama4DecoderLayer( + self.config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + ) for i in range(self.config.num_hidden_layers) + ]) + self.fc = torch.nn.Linear(self.config.hidden_size * 2, + self.config.hidden_size, + bias=False) + self.norm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + ) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + hidden_states = self.fc( + torch.cat((inputs_embeds, hidden_states), dim=-1)) + residual = None + for layer in self.layers: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states, hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + name = name.removeprefix("model.") + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + 
break + else: + # if PP disabled then draft will share embed with target + if get_pp_group().world_size == 1 and \ + "embed_tokens." in name: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + for name in params_dict: + # if PP disabled then draft will share embed with target + if get_pp_group().world_size == 1 and \ + "embed_tokens." in name: + continue + assert name in loaded_params, f"{name} is not loaded!" + return loaded_params + + def validate_and_update_config( + self, + start_layer_id: int, + quant_config: Optional[QuantizationConfig] = None) -> None: + # yoco and moe is not supported by draft model yet + assert self.config.yoco_global_kv_layer is None + assert self.config.yoco_local_kv_layer is None + assert len(self.config.moe_layers) == 0 + # draft model layer index is increased by start_layer_id, + # so we need to pad relevant configs accordingly + self.config.no_rope_layers = [ + 0 + ] * start_layer_id + self.config.no_rope_layers + # currently only TorchAO quantization is supported + if isinstance(quant_config, TorchAOConfig): + + def pad_layer_name(layer: str) -> str: + layer_index = extract_layer_index(layer) + return layer.replace(str(layer_index), + str(layer_index + start_layer_id)) + + quant_config.torchao_config.module_fqn_to_config = { + pad_layer_name(layer): quantization + for layer, quantization in + quant_config.torchao_config.module_fqn_to_config.items() + } + + +class EagleLlama4ForCausalLM(Llama4ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + self.config = ( + vllm_config.speculative_config.draft_model_config.hf_config) + target_layer_num = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config) + # draft model quantization config may differ from target model + quant_config = VllmConfig.get_quantization_config( + vllm_config.speculative_config.draft_model_config, + vllm_config.load_config) + self.model = LlamaModel(vllm_config=vllm_config, + prefix="model", + start_layer_id=target_layer_num, + quant_config=quant_config) + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.config.vocab_size, + scale=logit_scale) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.model(input_ids, positions, hidden_states, inputs_embeds) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> None: + loader = AutoWeightsLoader( + self, + # lm_head is tied with target model (Llama4ForCausalLM) + skip_prefixes=(["lm_head."]), + ) + + model_weights = {} + weights = [ + self.permute_qk_weight_for_rotary(name, loaded_weight) + for name, loaded_weight in weights + ] + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." 
+ name + model_weights[name] = loaded_weight + + loader.load_weights(model_weights.items()) + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings(input_ids) + + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + self.config.image_token_index, + ) + + return inputs_embeds diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index c7690604c1d..a4933b77e3a 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable +from typing import Optional import torch import torch.nn as nn @@ -148,7 +149,12 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is not None: + raise NotImplementedError( + f"{type(self).__name__} does not support multimodal inputs yet." + ) return self.model(input_ids, positions, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/llama_eagle3.py b/vllm/model_executor/models/llama_eagle3.py index 7fc9fe2ebb6..71275f0d585 100644 --- a/vllm/model_executor/models/llama_eagle3.py +++ b/vllm/model_executor/models/llama_eagle3.py @@ -202,7 +202,12 @@ def forward( input_ids: torch.Tensor, positions: torch.Tensor, hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if inputs_embeds is not None: + raise NotImplementedError( + f"{type(self).__name__} does not support multimodal inputs yet." 
+ ) return self.model(input_ids, positions, hidden_states) def compute_logits( diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 27d47692985..fd43ed836f2 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -239,6 +239,7 @@ "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"), "EAGLEModel": ("eagle", "EAGLE"), "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"), + "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"), "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 6661d984a77..d98ed14e662 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + import torch import torch.nn as nn +import vllm.envs as envs from vllm.attention.layer import Attention from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) @@ -12,11 +15,13 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.model_executor.models.utils import extract_layer_index from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel +from vllm.v1.worker.utils import copy_kv_cache_for_layers logger = init_logger(__name__) @@ -34,28 +39,39 @@ def __init__( self.vllm_config = vllm_config self.speculative_config = vllm_config.speculative_config self.draft_model_config = self.speculative_config.draft_model_config + self.kv_sharing_mapping = self.speculative_config.kv_sharing_mapping self.method = self.speculative_config.method self.runner = runner self.dtype = vllm_config.model_config.dtype + self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size self.num_speculative_tokens = ( self.speculative_config.num_speculative_tokens) self.max_num_tokens = ( vllm_config.scheduler_config.max_num_batched_tokens) + # For non-shifting case, consider full prefills each would add + # one more token + if not self.speculative_config.prefill_token_shift: + self.max_num_tokens = self.max_num_tokens * 2 # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's # hidden size (e.g., Llama 3.3 70B). 
self.hidden_size = self.draft_model_config.get_hidden_size() + self.is_multimodal_model = vllm_config.model_config \ + .is_multimodal_model + self.use_cuda_graph = (self.vllm_config.compilation_config.level == CompilationLevel.PIECEWISE and not self.vllm_config.model_config.enforce_eager) self.cudagraph_batch_sizes = list( reversed( self.vllm_config.compilation_config.cudagraph_capture_sizes)) + self.draft_prefill_kv_sharing_from_base = ( + self.kv_sharing_mapping is not None and envs.VLLM_DECODE_ONLY_ATTN) # persistent buffers for cuda graph self.input_ids = torch.zeros(self.max_num_tokens, @@ -74,6 +90,276 @@ def __init__( 1, device=device, dtype=torch.int32) + self.inputs_embeds = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=device) + + def _prepare_adjusted_tensors( + self, + target_token_ids: torch.Tensor, + target_positions: torch.Tensor, + target_hidden_states: torch.Tensor, + target_slot_mapping: torch.Tensor, + cu_num_tokens: torch.Tensor, + decode_mask: torch.Tensor, + full_prefill_mask: torch.Tensor, + partial_prefill_mask: torch.Tensor, + prefill_first_hiddens: torch.Tensor, + block_table: torch.Tensor, + batch_size: int, + num_tokens: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, + torch.Tensor]: + """ + Prepare adjusted tensors for different request types + (partial prefill, full prefill, full decode). + + Args: + target_token_ids: Input token IDs tensor + target_positions: Input position IDs tensor + target_hidden_states: Input hidden states tensor + target_slot_mapping: Input slot mapping tensor + cu_num_tokens: Cumulative number of tokens per request + decode_mask: Mask indicating which tokens are for decoding + full_prefill_mask: Mask indicating which requests are full prefill + prefill_first_hiddens: First hidden states for prefill requests + block_table: Block table for KV cache mapping + batch_size: Number of requests in the batch + num_tokens: Total number of tokens + + Returns: + tuple: (target_positions, target_hidden_states, target_slot_mapping, + cu_num_tokens, current_pos, partial_prefill_mask) + + Algorithm design: + - Suppose target tokens are [1,2,3,...N], next token is N+1 + - Position is [0,1,2,...N-1] + - And hidden is [h1,h2,h3,...hN] + - Suppose partial prefill is [Nm, Nm+1, ...Nm+M-1] + -- For normal shifting: + --- draft prefill is [2,3,...N+1], position is same as target + --- Stacking hidden is [h1,h2,h3,...hN] + --- Decode tokens are [N+2, N+3, ...], hidden is [hN+1,hN+2,...] + --- Decode positions are [N,N+1,...] + --- draft partial prefill is [Nm+1, Nm+2, ...Nm+M] + -- For non-shifting: + --- draft full prefill is [1,2,3,...N+1], position is [0,1,2,...N] + --- Stacking hidden is [hN,h1,h2,h3,...hN] + --- Decode tokens are [N+2, N+3, ...], hidden is [hN+1,hN+2,...] + --- Decode positions are [N+1,N+2,...] + --- draft partial prefill is [Nm, Nm+1, ...Nm+M-1] + --- draft hidden is [hNm-1,hNm,...hNm+M] + (hNm-1 is the last round hidden) + -- For kv sharing(non-shifting required): + This means all target prefill tokens are not needed to be processed + in drafting prefill step as we don't need the kv from draft. + --- draft full prefill is [N+1], position is [N] + --- Stacking hidden is [hN] + --- Decode is the same as non-shifting decode + --- draft partial prefill is totally skipped + All other metadata like slot mapping, etc. 
should be based on + the positions and tokens to generate/manipulate again + """ + # Count total number of full prefill requests to determine the + # size needed for adjusted tensors + num_full_prefill = full_prefill_mask.sum().item() + + # Create tensors with extra space for the additional + # positions from full prefill requests + adjusted_token_ids = torch.zeros( + num_tokens + num_full_prefill, + dtype=target_token_ids.dtype, + device=target_token_ids.device, + ) + adjusted_positions = torch.zeros( + num_tokens + num_full_prefill, + dtype=target_positions.dtype, + device=target_positions.device, + ) + adjusted_slot_mapping = torch.zeros( + num_tokens + num_full_prefill, + dtype=target_slot_mapping.dtype, + device=target_slot_mapping.device, + ) + adjusted_hidden_states = torch.zeros( + num_tokens + num_full_prefill, + self.hidden_size, + dtype=target_hidden_states.dtype, + device=target_hidden_states.device, + ) + if self.draft_prefill_kv_sharing_from_base: + # Get the KV caches from the forward context + attentions = get_layers_from_vllm_config(self.vllm_config, + Attention) + kv_caches = { + layer: att.kv_cache[0] + for layer, att in attentions.items() + if layer in self.kv_sharing_mapping + or layer in self.kv_sharing_mapping.values() + } + copy_positions_mask = ~decode_mask + full_prefill_last_pos = cu_num_tokens[1:][full_prefill_mask] - 1 + copy_positions_mask[full_prefill_last_pos] = True + + # Call the function to copy KV cache values + copy_kv_cache_for_layers( + kv_caches=kv_caches, + kv_sharing_layers_mapping=self.kv_sharing_mapping, + copy_positions_mask=copy_positions_mask, + slot_mapping=target_slot_mapping, + ) + + # Create updated cumulative token counts + updated_cu_num_tokens = torch.zeros_like(cu_num_tokens) + + # Process batched operations using masks + current_pos = 0 + cu_num_tokens_index = 0 + + # Process each request in the batch + # Process all requests in batch order but with optimized operations + # Create arrays to track request properties + req_starts = cu_num_tokens[:-1] + req_ends = cu_num_tokens[1:] + req_lens = req_ends - req_starts + + # Process each request in order + for i in range(batch_size): + # Get the start and end indices for this request + start_idx = req_starts[i].item() + end_idx = req_ends[i].item() + req_len = req_lens[i].item() + + # Check category + is_partial_prefill = partial_prefill_mask[i].item() + is_full_prefill = full_prefill_mask[i].item() + + if is_partial_prefill: + # Category 1: Partial prefill - just copy all tokens + # if we enable copy kv then all of the tokens are skipped + if not self.draft_prefill_kv_sharing_from_base: + adjusted_token_ids[current_pos:current_pos + + req_len].copy_( + target_token_ids[start_idx:end_idx]) + + adjusted_positions[current_pos:current_pos + + req_len].copy_( + target_positions[start_idx:end_idx]) + + adjusted_slot_mapping[current_pos:current_pos + + req_len].copy_(target_slot_mapping[ + start_idx:end_idx]) + + # Put the first prefill hidden state in the first position + # and shift all the other ones, this matches the sequence + # as non-shifting will include the first prefill token + adjusted_hidden_states[current_pos + 1:current_pos + + req_len].copy_( + target_hidden_states[start_idx + + 1:end_idx]) + + adjusted_hidden_states[ + current_pos] = prefill_first_hiddens[i] + current_pos += req_len + cu_num_tokens_index += 1 + + elif is_full_prefill: + # Category 2: Full prefill with decode: + # copy tokens and add one position + pos = target_positions[end_idx - 1] + 1 + block_number = pos // 
self.block_size + block_number = block_table[i][block_number].item() + block_offset = pos % self.block_size + adjusted_slot = (block_number * self.block_size + block_offset) + + if not self.draft_prefill_kv_sharing_from_base: + # copy the original and adjust the one additional token + # for position, slot mapping and hidden state + adjusted_token_ids[current_pos:current_pos + + req_len].copy_( + target_token_ids[start_idx:end_idx]) + + adjusted_positions[current_pos:current_pos + + req_len].copy_( + target_positions[start_idx:end_idx]) + adjusted_positions[current_pos + req_len] = pos + + adjusted_slot_mapping[current_pos:current_pos + + req_len].copy_(target_slot_mapping[ + start_idx:end_idx]) + adjusted_slot_mapping[current_pos + + req_len] = (adjusted_slot) + + adjusted_hidden_states[ + current_pos + 1:current_pos + req_len + 1].copy_( + target_hidden_states[start_idx:end_idx]) + + adjusted_hidden_states[ + current_pos] = prefill_first_hiddens[i] + current_pos += req_len + 1 + else: + # if we enable copy kv then all of the prefill tokens + # are skipped. Only keep the prefill output token + adjusted_positions[current_pos] = pos + adjusted_slot_mapping[current_pos] = adjusted_slot + adjusted_hidden_states[current_pos] = ( + target_hidden_states[end_idx - 1]) + current_pos += 1 + + cu_num_tokens_index += 1 + + else: + # Category 3: Full decode - shift tokens + # Due to additional token in full prefill already, + # all the corresponding decode rounds will shift one tokens + adjusted_token_ids[current_pos:current_pos + req_len - + 1].copy_(target_token_ids[start_idx + + 1:end_idx]) + + adjusted_positions[current_pos:current_pos + req_len].copy_( + target_positions[start_idx:end_idx] + 1) + + adjusted_slot_mapping[current_pos:current_pos + req_len - + 1].copy_(target_slot_mapping[start_idx + + 1:end_idx]) + + pos = adjusted_positions[current_pos + req_len - 1] + block_number = pos // self.block_size + block_number = block_table[i][block_number].item() + block_offset = pos % self.block_size + adjusted_slot_mapping[current_pos + req_len - + 1] = (block_number * self.block_size + + block_offset) + + adjusted_hidden_states[current_pos:current_pos + + req_len].copy_(target_hidden_states[ + start_idx:end_idx]) + + current_pos += req_len + cu_num_tokens_index += 1 + + # Update the cumulative token count for this request + updated_cu_num_tokens[cu_num_tokens_index] = current_pos + + # using current_pos to cap the actual number of tokens + # Copy the adjusted tensors to the input buffers + self.input_ids[:current_pos] = adjusted_token_ids[:current_pos] + + # Update the variables used by the rest of the function + target_positions = adjusted_positions[:current_pos] + target_hidden_states = adjusted_hidden_states[:current_pos] + target_slot_mapping = adjusted_slot_mapping[:current_pos] + cu_num_tokens = updated_cu_num_tokens + + return ( + target_positions, + target_hidden_states, + target_slot_mapping, + cu_num_tokens, + current_pos, + partial_prefill_mask, + ) def propose( self, @@ -92,10 +378,14 @@ def propose( # [batch_size, max_num_blocks_per_req] block_table: torch.Tensor, sampling_metadata: SamplingMetadata, + prefill_first_hiddens: torch.Tensor, + mm_embeds: Optional[list[torch.Tensor]] = None, + decode_mask: torch.Tensor = None, + full_prefill_mask: torch.Tensor = None, + partial_prefill_mask: torch.Tensor = None, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] - last_token_indices = cu_num_tokens[1:] - 1 if self.method == "eagle3": assert 
isinstance(self.model, Eagle3LlamaForCausalLM) @@ -103,12 +393,95 @@ def propose( target_hidden_states) assert target_hidden_states.shape[-1] == self.hidden_size - # Shift the input ids by one token. - # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3] - self.input_ids[:num_tokens - 1] = target_token_ids[1:] - # Replace the last token with the next token. - # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] - self.input_ids[last_token_indices] = next_token_ids + prefill_shift_tokens = True + has_prefill = decode_mask is not None and ( + ~decode_mask.bool()).any().item() + if not self.speculative_config.eagle_shift_prefill_token() and ( + self.method in ["eagle", "eagle3"]): + assert decode_mask is not None + assert full_prefill_mask is not None + prefill_shift_tokens = False + + if not prefill_shift_tokens and has_prefill: + if (partial_prefill_mask.all() + and self.draft_prefill_kv_sharing_from_base): + # All requests are partial prefill and + # KV cache sharing is enabled + # Skip the rest of the function + # and return dummy draft tokens + return torch.zeros( + (batch_size, self.num_speculative_tokens), + dtype=target_token_ids.dtype, + device=target_token_ids.device, + ) + # Adjust the tensors for full prefill requests + ( + target_positions, + target_hidden_states, + target_slot_mapping, + cu_num_tokens, + num_tokens, + partial_prefill_mask, + ) = self._prepare_adjusted_tensors( + target_token_ids, + target_positions, + target_hidden_states, + target_slot_mapping, + cu_num_tokens, + decode_mask, + full_prefill_mask, + partial_prefill_mask, + prefill_first_hiddens, + block_table, + batch_size, + num_tokens, + ) + batch_size = cu_num_tokens.shape[0] - 1 + else: + # Original behavior: shift all tokens by one + self.input_ids[:num_tokens - 1] = target_token_ids[1:] + partial_prefill_mask = torch.zeros_like(full_prefill_mask) + if not prefill_shift_tokens: + # For pure decode in non-shifting prefill case + # Due to one additional token in prefill, all the decode + # rounds will shift one token + target_positions += 1 + max_num_blocks_per_req = block_table.shape[1] + segment_indices = torch.arange(len(target_positions), + device=target_positions.device) + segment_indices = (segment_indices.unsqueeze(0) + >= cu_num_tokens[:-1].unsqueeze(1)).sum( + dim=0) - 1 + # Calculate the block table indices + block_table_indices = ( + target_positions // self.block_size + + segment_indices * max_num_blocks_per_req) + block_numbers = block_table.flatten()[block_table_indices] + block_offsets = target_positions % self.block_size + target_slot_mapping = (block_numbers * self.block_size + + block_offsets) + + # Use the original last token indices + last_token_indices = cu_num_tokens[1:] - 1 + + if not prefill_shift_tokens and has_prefill: + # Replace the last token with the next token under non-shifting, + # but only for non-partial prefill requests + # For partial prefill in non-shifting, we just match the target + # prefill tokens as it would match the positions and hidden states + # so no need to add this next token from next round + mask = ~partial_prefill_mask + # if we enable copy kv then all of the partial prefills + # are completely skipped so they won't be in last_token_indices + input_indices = ( + last_token_indices[mask] + if not self.draft_prefill_kv_sharing_from_base else + last_token_indices[:batch_size - + partial_prefill_mask.sum().item()]) + self.input_ids[input_indices] = next_token_ids[mask] + else: + # Original behavior: apply to all requests + 
self.input_ids[last_token_indices] = next_token_ids # FA requires seq_len to have dtype int32. seq_lens = (target_positions[last_token_indices] + 1).int() @@ -160,22 +533,35 @@ def propose( per_layer_attn_metadata = {} for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata - if self.use_cuda_graph and \ - num_tokens <= self.cudagraph_batch_sizes[-1]: + if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) else: num_input_tokens = num_tokens # copy inputs to buffer for cudagraph self.positions[:num_tokens] = target_positions self.hidden_states[:num_tokens] = target_hidden_states + if self.is_multimodal_model: + input_ids = self.input_ids[:num_tokens] + if mm_embeds: + inputs_embeds = self.model.get_input_embeddings( + input_ids, mm_embeds) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + self.inputs_embeds[:num_tokens] = inputs_embeds + inputs_embeds = self.inputs_embeds[:num_input_tokens] + input_ids = None + else: + inputs_embeds = None + input_ids = self.input_ids[:num_input_tokens] with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens): ret_hidden_states = self.model( - self.input_ids[:num_input_tokens], - self.positions[:num_input_tokens], - self.hidden_states[:num_input_tokens], + input_ids=input_ids, + positions=self.positions[:num_input_tokens], + hidden_states=self.hidden_states[:num_input_tokens], + inputs_embeds=inputs_embeds, ) if self.method == "deepseek_mtp": last_hidden_states = ret_hidden_states @@ -187,6 +573,21 @@ def propose( # Early exit if there is only one draft token to be generated. if self.num_speculative_tokens == 1: + if (self.draft_prefill_kv_sharing_from_base + and partial_prefill_mask.any().item()): + # if we do kv sharing and has partial prefill + # the original position for partial prefill will not + # have token output thus we need to pad the draft tokens + # with the correct positions + padded_draft_token_ids = torch.zeros( + partial_prefill_mask.shape[0], + dtype=draft_token_ids.dtype, + device=draft_token_ids.device) + draft_token_ids = draft_token_ids[:batch_size - + partial_prefill_mask.sum( + ).item()] + padded_draft_token_ids[~partial_prefill_mask] = draft_token_ids + draft_token_ids = padded_draft_token_ids # [batch_size, 1] return draft_token_ids.view(-1, 1) @@ -199,8 +600,7 @@ def propose( positions = target_positions[last_token_indices] hidden_states = hidden_states[last_token_indices] - if self.use_cuda_graph and \ - batch_size <= self.cudagraph_batch_sizes[-1]: + if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]: input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size) else: input_batch_size = batch_size @@ -253,15 +653,24 @@ def propose( self.input_ids[:batch_size] = input_ids self.positions[:batch_size] = clamped_positions self.hidden_states[:batch_size] = hidden_states + if self.is_multimodal_model: + inputs_embeds = self.model.get_input_embeddings(input_ids) + self.inputs_embeds[:batch_size] = inputs_embeds + inputs_embeds = self.inputs_embeds[:input_batch_size] + input_ids = None + else: + inputs_embeds = None + input_ids = self.input_ids[:input_batch_size] # Run the model. 
with set_forward_context(per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size): last_hidden_states, hidden_states = self.model( - self.input_ids[:input_batch_size], - self.positions[:input_batch_size], - self.hidden_states[:input_batch_size], + input_ids=input_ids, + positions=self.positions[:input_batch_size], + hidden_states=self.hidden_states[:input_batch_size], + inputs_embeds=inputs_embeds, ) hidden_states = hidden_states[:batch_size] logits = self.model.compute_logits(last_hidden_states[:batch_size], @@ -273,6 +682,21 @@ def propose( # [batch_size, num_speculative_tokens] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) + if (self.draft_prefill_kv_sharing_from_base + and partial_prefill_mask.any().item()): + # if we do kv sharing and has partial prefill + # the original position for partial prefill will not + # have token output thus we need to pad the draft tokens + # with the correct positions + padded_draft_token_ids = torch.zeros( + (partial_prefill_mask.shape[0], self.num_speculative_tokens), + dtype=draft_token_ids.dtype, + device=draft_token_ids.device) + draft_token_ids = draft_token_ids[:batch_size - + partial_prefill_mask.sum().item( + )] + padded_draft_token_ids[~partial_prefill_mask] = draft_token_ids + draft_token_ids = padded_draft_token_ids return draft_token_ids @staticmethod @@ -292,8 +716,7 @@ def prepare_inputs( # a + b, a + b + 1, ..., a + b + c - n3 - 1] # [0, a, a + b, a + b + c] -> [a, b, c] - query_len_per_req = (cu_target_query_lens[1:] - - cu_target_query_lens[:-1]) + query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1] # [a, b, c] -> [a - n1, b - n2, c - n3] num_tokens_per_req = query_len_per_req - num_rejected_tokens @@ -317,8 +740,8 @@ def prepare_inputs( return cu_num_tokens, token_indices def load_model(self, target_model: nn.Module) -> None: - draft_model_config = \ - self.vllm_config.speculative_config.draft_model_config + draft_model_config = ( + self.vllm_config.speculative_config.draft_model_config) target_attn_layer_names = set( get_layers_from_vllm_config(self.vllm_config, Attention).keys()) @@ -333,6 +756,30 @@ def load_model(self, target_model: nn.Module) -> None: self.attn_layer_names = list(draft_attn_layer_names) + if envs.VLLM_DECODE_ONLY_ATTN: + logger.info("Using half prefill for decoding only attention," + " Setting up KV sharing from base to draft.") + assert self.kv_sharing_mapping is not None + target_attn_layer_indices_dict = dict( + (str(extract_layer_index(target_layer)), target_layer) + for target_layer in target_attn_layer_names) + draft_attn_layer_indices_dict = dict( + (str(extract_layer_index(draft_layer)), draft_layer) + for draft_layer in draft_attn_layer_names) + updated_kv_sharing_mapping = { + draft_attn_layer_indices_dict[draft_index]: + target_attn_layer_indices_dict[target_index] + for draft_index, target_index in + self.kv_sharing_mapping.items() + } + logger.info("Updated KV sharing mapping: %s", + updated_kv_sharing_mapping) + assert len(updated_kv_sharing_mapping) == len( + self.kv_sharing_mapping), ( + "KV sharing mapping should be a subset of draft and" + " target attn layer indices") + self.kv_sharing_mapping = updated_kv_sharing_mapping + if supports_multimodal(target_model): # handle multimodality self.model.config.image_token_index = ( @@ -372,10 +819,18 @@ def dummy_run( ) -> None: with set_forward_context(None, self.vllm_config, num_tokens=num_tokens): + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + 
input_ids = self.input_ids[:num_tokens] + inputs_embeds = None + self.model( - self.input_ids[:num_tokens], - self.positions[:num_tokens], - self.hidden_states[:num_tokens], + input_ids=input_ids, + positions=self.positions[:num_tokens], + hidden_states=self.hidden_states[:num_tokens], + inputs_embeds=inputs_embeds, ) def validate_same_kv_cache_group(self, @@ -390,12 +845,12 @@ def validate_same_kv_cache_group(self, for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups): for layer_name in kv_cache_group.layer_names: kv_cache_groups[layer_name] = id - assert len( + assert (len( set([ kv_cache_groups[layer_name] for layer_name in self.attn_layer_names - ]) - ) == 1, "All eagle layers should belong to the same kv cache group" + ])) == 1 + ), "All eagle layers should belong to the same kv cache group" # NOTE(woosuk): Currently, the below code is not used and we always use argmax diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1a79d72be0a..6d5ae95e5da 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -44,6 +44,10 @@ class CachedRequestState: lora_request: Optional[LoRARequest] = None + # Bootstrap eagle and MTP related hidden states + # support caching last hidden states for partial prefill + prefill_hidden_states: Optional[torch.Tensor] = None + def __post_init__(self): self.num_prompt_tokens = len(self.prompt_token_ids) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5a26e88db1f..c886ac7776e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -579,12 +579,12 @@ def _prepare_inputs( self, scheduler_output: "SchedulerOutput", ) -> tuple[dict[str, Any], bool, torch.Tensor, - Optional[SpecDecodeMetadata], np.ndarray]: + Optional[SpecDecodeMetadata], np.ndarray, torch.Tensor]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, attention_cuda_graphs: whether attention can run in cudagraph - logits_indices, spec_decode_metadata + logits_indices, spec_decode_metadata, decode_mask ] """ total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -601,11 +601,18 @@ def _prepare_inputs( tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] num_scheduled_tokens = np.array(tokens, dtype=np.int32) max_num_scheduled_tokens = max(tokens) + req_last_prompt_index = np.array([ + self.requests[req_id].num_prompt_tokens - 1 + for req_id in self.input_batch.req_ids + ], + dtype=np.int32) # Get request indices. # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens) + last_prompt_indices_np = np.repeat(req_last_prompt_index, + num_scheduled_tokens) # cu_num_tokens: [2, 5, 3] -> [2, 7, 10] # arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] @@ -618,6 +625,8 @@ def _prepare_inputs( arange, out=positions_np) + decode_mask_np = positions_np >= last_prompt_indices_np + # Calculate M-RoPE positions. 
# Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -765,8 +774,10 @@ def _prepare_inputs( if self.lora_config: self.set_active_loras(self.input_batch, num_scheduled_tokens) + decode_mask = torch.tensor(decode_mask_np, device=self.device) + return (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, num_scheduled_tokens) + spec_decode_metadata, num_scheduled_tokens, decode_mask) def _compute_cascade_attn_prefix_len( self, @@ -1044,13 +1055,15 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): def _gather_mm_embeddings( self, scheduler_output: "SchedulerOutput", + shift_computed_tokens: int = 0, ) -> list[torch.Tensor]: mm_embeds: list[torch.Tensor] = [] for req_id in self.input_batch.req_ids: num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] req_state = self.requests[req_id] - num_computed_tokens = req_state.num_computed_tokens + num_computed_tokens = \ + req_state.num_computed_tokens + shift_computed_tokens mm_positions = req_state.mm_positions for i, pos_info in enumerate(mm_positions): start_pos = pos_info.offset @@ -1290,8 +1303,8 @@ def execute_model( # Prepare the decoder inputs. (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, - num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output)) + spec_decode_metadata, num_scheduled_tokens_np, + decode_mask) = (self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1544,6 +1557,8 @@ def execute_model( aux_hidden_states, spec_decode_metadata, attn_metadata, + mm_embeds, + decode_mask, ) # Clear KVConnector state after all KVs are generated. @@ -1575,6 +1590,8 @@ def propose_draft_token_ids( aux_hidden_states: Optional[torch.Tensor], spec_decode_metadata: Optional[SpecDecodeMetadata], attn_metadata: dict[str, Any], + mm_embeds: list[torch.Tensor], + decode_mask: torch.Tensor, ) -> list[list[int]]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": @@ -1603,24 +1620,6 @@ def propose_draft_token_ids( ) elif self.speculative_config.use_eagle(): assert isinstance(self.drafter, EagleProposer) - # TODO(woosuk): Refactor the loop. - next_token_ids: list[int] = [] - for i, token_ids in enumerate(sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = self.input_batch.req_ids[i] - req_state = self.requests[req_id] - seq_len = (req_state.num_computed_tokens + - scheduler_output.num_scheduled_tokens[req_id]) - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.device) # At this moment, we assume all eagle layers belong to the same KV # cache group, thus using the same attention metadata. eagle_attn_metadata = attn_metadata[ @@ -1664,8 +1663,8 @@ def propose_draft_token_ids( num_tokens, ) target_token_ids = self.input_ids[token_indices] - # TODO(woosuk): Support M-RoPE. 
target_positions = self.positions[token_indices] + decode_mask = decode_mask[token_indices] if self.use_aux_hidden_state_outputs: target_hidden_states = torch.cat( [h[token_indices] for h in aux_hidden_states], dim=-1) @@ -1673,6 +1672,66 @@ def propose_draft_token_ids( target_hidden_states = hidden_states[token_indices] target_slot_mapping = eagle_attn_metadata.slot_mapping[ token_indices] + draft_mm_embeds = mm_embeds if mm_embeds else None + if envs.VLLM_DECODE_ONLY_ATTN: + draft_mm_embeds = None + elif (self.is_multimodal_model + and self.speculative_config.eagle_shift_prefill_token()): + draft_mm_embeds = self._gather_mm_embeddings( + scheduler_output, shift_computed_tokens=1) + + # TODO(woosuk): Refactor the loop. + next_token_ids: list[int] = [] + # Get the first token's hidden state from cached for each request. + # This is used in non-shifted prefill for eagle draft. + prefill_first_hiddens = [] + full_prefill_mask = [] + partial_prefill_mask = [] + for i, token_ids in enumerate(sampled_token_ids): + req_id = self.input_batch.req_ids[i] + req_state = self.requests[req_id] + # Initialize the prefill hidden state if not set., + # from experimental results, the last token hidden state + # works very well for init the first prefill hidden state. + if req_state.prefill_hidden_states is None: + req_state.prefill_hidden_states = target_hidden_states[ + cu_num_tokens[i + 1] - 1] + prefill_first_hiddens.append(req_state.prefill_hidden_states) + num_prompt_tokens = req_state.num_prompt_tokens + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + full_prefill_mask.append( + req_state.num_computed_tokens < num_prompt_tokens + and req_state.num_computed_tokens + num_scheduled_tokens + >= num_prompt_tokens) + if token_ids: + # Common case. + next_token_id = token_ids[-1] + partial_prefill_mask.append(False) + else: + # Partial prefill (rare case). + # Get the next token id from the request state. 
+ seq_len = (req_state.num_computed_tokens + + scheduler_output.num_scheduled_tokens[req_id]) + next_token_id = req_state.get_token_id(seq_len) + last_hidden_index = cu_num_tokens[i + 1] - 1 + # For non-shifting prefill + partial prefill case, + # the current round last hidden state will be used + # as the first prefill hidden for the next round + req_state.prefill_hidden_states = target_hidden_states[ + last_hidden_index] + partial_prefill_mask.append(True) + next_token_ids.append(next_token_id) + next_token_ids = torch.tensor(next_token_ids, + dtype=torch.int32, + device=self.device) + prefill_first_hiddens = torch.cat(prefill_first_hiddens, dim=0) + full_prefill_mask = torch.tensor(full_prefill_mask, + dtype=torch.bool, + device=self.device) + partial_prefill_mask = torch.tensor(partial_prefill_mask, + dtype=torch.bool, + device=self.device) draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, @@ -1682,6 +1741,11 @@ def propose_draft_token_ids( cu_num_tokens=cu_num_tokens, block_table=block_table, sampling_metadata=sampling_metadata, + mm_embeds=draft_mm_embeds, + prefill_first_hiddens=prefill_first_hiddens, + decode_mask=decode_mask, + full_prefill_mask=full_prefill_mask, + partial_prefill_mask=partial_prefill_mask, ) spec_token_ids = draft_token_ids.tolist() return spec_token_ids diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 70339ff2f00..9a58854c2f3 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -110,3 +110,79 @@ def initialize_kv_cache_for_kv_sharing( kv_caches[layer_name] = kv_caches[target_layer_name] group_idx = layer_to_kv_cache_group_idx[target_layer_name] kv_cache_groups[group_idx].layer_names.append(layer_name) + + +def copy_kv_cache_for_layers( + kv_caches: dict[str, torch.Tensor], + kv_sharing_layers_mapping: dict[str, str], + copy_positions_mask: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + """ + Copies values from one set of layers to another + based on a mapping and a mask. + + This function is primarily used when KV sharing + is enabled especially for spec decoding to copy + target model's cache for the specified positions into + corresponding draft model layer's KV cache. + + Args: + kv_caches: + Dictionary mapping layer names to their cache tensors. + kv_sharing_layers_mapping: + Mapping the target layers to the source layers. + copy_positions_mask: + Boolean mask where True indicates positions to copy. + slot_mapping: + Mapping tensor that maps positions to cache slots. 
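+
+    Illustrative example (the layer names below assume the usual
+    "model.layers.<i>.self_attn.attn" naming and are not taken from this
+    patch): with kv_sharing_layers_mapping =
+    {"model.layers.48.self_attn.attn": "model.layers.47.self_attn.attn"},
+    the base model's layer-47 KV entries at the masked slots are copied
+    into the draft model's layer-48 cache.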
+ """ + # Get the positions to copy + positions = torch.nonzero(copy_positions_mask, as_tuple=True)[0] + + if positions.numel() == 0: + # No positions to copy + return + + # Get the corresponding slot mappings for the positions + slots = slot_mapping[positions] + + # Copy KV cache values from source layers to target layers + for target_layer, source_layer in kv_sharing_layers_mapping.items(): + if target_layer not in kv_caches or source_layer not in kv_caches: + continue + + target_kv_cache = kv_caches[target_layer] + source_kv_cache = kv_caches[source_layer] + + block_size = source_kv_cache.shape[2] + + kv_dim = 2 + # Process in smaller batches to reduce memory overhead + batch_size = 8192 + num_positions = positions.size(0) + + for start_idx in range(0, num_positions, batch_size): + end_idx = min(start_idx + batch_size, num_positions) + + # Get batch of slots + batch_slots = slots[start_idx:end_idx] + batch_block_indices = batch_slots // block_size + batch_block_offsets = batch_slots % block_size + + # Create batch-sized indexing tensors + batch_block_indices_expanded = batch_block_indices.view( + 1, -1, 1, 1, 1) + batch_block_offsets_expanded = batch_block_offsets.view( + 1, 1, -1, 1, 1) + + # Copy values for this batch + for kv_idx in range(kv_dim): + target_kv_cache[ + kv_idx, + batch_block_indices_expanded.squeeze(), + batch_block_offsets_expanded.squeeze(), :, :] = ( + source_kv_cache[ + kv_idx, + batch_block_indices_expanded.squeeze(), + batch_block_offsets_expanded.squeeze(), :, :])
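
For reference, a minimal offline sketch of how the new speculative_config fields
("prefill_token_shift", "kv_sharing_mapping") and the VLLM_DECODE_ONLY_ATTN
environment variable introduced above fit together. The model names mirror the
tests; the tensor_parallel_size, max_model_len, and layer indices are
illustrative assumptions, not values required by this change.

    import os
    os.environ["VLLM_DECODE_ONLY_ATTN"] = "1"  # enable the draft prefill KV-copy path

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
        trust_remote_code=True,
        tensor_parallel_size=4,
        max_model_len=4096,
        speculative_config={
            "method": "eagle",
            "model": "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
            "num_speculative_tokens": 3,
            # KV sharing requires the non-shifted prefill path.
            "prefill_token_shift": False,
            # Draft layer index -> base layer index whose prefill KV is copied.
            "kv_sharing_mapping": {"48": "47"},
        },
    )
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0, max_tokens=32))
    print(outputs[0].outputs[0].text)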