diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py
index 26e492fed25..ce735f3b27d 100644
--- a/examples/offline_inference/spec_decode.py
+++ b/examples/offline_inference/spec_decode.py
@@ -84,6 +84,7 @@ def main():
         gpu_memory_utilization=0.8,
         speculative_config=speculative_config,
         disable_log_stats=False,
+        max_model_len=16384,
     )
     sampling_params = SamplingParams(temperature=args.temp,
                                      max_tokens=args.output_len)
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 48302f9d664..b94376970ce 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -451,6 +451,11 @@ def check_available_online(
                                               trust_remote_code=True,
                                               speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
                                               tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
+    "EagleLlama4ForCausalLM": _HfExamplesInfo(
+        "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
+        trust_remote_code=True,
+        speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
+        tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"),  # noqa: E501
     "EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16",
                                                trust_remote_code=True,
                                                is_available_online=False,
@@ -500,4 +505,4 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo:
         raise ValueError(f"No example model defined for {model_id}")
 
 
-HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
\ No newline at end of file
+HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS)
diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py
index 25bc96bf326..6ab508a9c5e 100644
--- a/tests/models/test_initialization.py
+++ b/tests/models/test_initialization.py
@@ -26,6 +26,11 @@ def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
                       "KimiVLForConditionalGeneration"):
         pytest.skip("Avoid OOM")
 
+    if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"):
+        from vllm.model_executor.models.llama4 import Llama4ForCausalLM
+        from vllm.model_executor.models.registry import ModelRegistry
+        ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM)
+
     # Avoid OOM and reduce initialization time by only using 1 layer
     def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig:
         hf_config.update(model_info.hf_overrides)
diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py
index 93e7c12f3a0..2423f966acf 100644
--- a/tests/v1/e2e/test_spec_decode.py
+++ b/tests/v1/e2e/test_spec_decode.py
@@ -6,8 +6,10 @@
 from typing import Any
 
 import pytest
+import torch
 
 from vllm import LLM, SamplingParams
+from vllm.distributed import cleanup_dist_env_and_memory
 
 
 @pytest.fixture
@@ -53,14 +55,6 @@ def model_name():
     return "meta-llama/Llama-3.1-8B-Instruct"
 
 
-def eagle_model_name():
-    return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
-
-
-def eagle3_model_name():
-    return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
-
-
 def test_ngram_correctness(
     monkeypatch: pytest.MonkeyPatch,
     test_prompts: list[list[dict[str, Any]]],
@@ -77,6 +71,8 @@ def test_ngram_correctness(
         ref_llm = LLM(model=model_name, max_model_len=1024)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
 
         spec_llm = LLM(
             model=model_name,
@@ -103,34 +99,50 @@
         # Upon failure, inspect the outputs to check for inaccuracy.
         assert matches > int(0.7 * len(ref_outputs))
         del spec_llm
-
-
-@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+
+@pytest.mark.parametrize("model_setup", [
+    ("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+     "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
+    ("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+     "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
+    pytest.param(
+        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+],
+                         ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
 def test_eagle_correctness(
     monkeypatch: pytest.MonkeyPatch,
     test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
-    model_name: str,
-    use_eagle3: bool,
+    model_setup: tuple[str, str, str, int],
 ):
     '''
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
+    model_setup: (method, model_name, eagle_model_name, tp_size)
     '''
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
+        method, model_name, spec_model_name, tp_size = model_setup
 
-        ref_llm = LLM(model=model_name, max_model_len=2048)
+        ref_llm = LLM(model=model_name,
+                      max_model_len=2048,
+                      tensor_parallel_size=tp_size)
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
 
-        spec_model_name = eagle3_model_name(
-        ) if use_eagle3 else eagle_model_name()
         spec_llm = LLM(
             model=model_name,
             trust_remote_code=True,
+            tensor_parallel_size=tp_size,
             speculative_config={
-                "method": "eagle3" if use_eagle3 else "eagle",
+                "method": method,
                 "model": spec_model_name,
                 "num_speculative_tokens": 3,
                 "max_model_len": 2048,
@@ -152,3 +164,5 @@ def test_eagle_correctness(
         # Upon failure, inspect the outputs to check for inaccuracy.
         assert matches > int(0.66 * len(ref_outputs))
         del spec_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py
new file mode 100644
index 00000000000..222ab5dfaee
--- /dev/null
+++ b/vllm/model_executor/models/llama4_eagle.py
@@ -0,0 +1,214 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team.
+# All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Iterable
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_pp_group
+from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.layers.quantization.torchao import TorchAOConfig
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
+                                               Llama4ForCausalLM)
+from vllm.model_executor.models.utils import extract_layer_index
+
+from .utils import AutoWeightsLoader, maybe_prefix
+
+logger = init_logger(__name__)
+
+
+@support_torch_compile
+class LlamaModel(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        start_layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = (
+            vllm_config.speculative_config.draft_model_config.hf_config)
+        self.validate_and_update_config(start_layer_id, quant_config)
+        self.vocab_size = self.config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "embed_tokens"),
+        )
+
+        self.layers = nn.ModuleList([
+            Llama4DecoderLayer(
+                self.config,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
+            ) for i in range(self.config.num_hidden_layers)
+        ])
+        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
+                                  self.config.hidden_size,
+                                  bias=False)
+        self.norm = RMSNorm(self.config.hidden_size,
+                            eps=self.config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        input_embeds = self.embed_tokens(input_ids)
+        hidden_states = self.fc(
+            torch.cat((input_embeds, hidden_states), dim=-1))
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states, hidden_states
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            name = name.removeprefix("model.")
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # if PP disabled then draft will share embed with target
+                if get_pp_group().world_size == 1 and \
+                    "embed_tokens." in name:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        for name in params_dict:
+            # if PP disabled then draft will share embed with target
+            if get_pp_group().world_size == 1 and \
+                "embed_tokens." in name:
+                continue
+            assert name in loaded_params, f"{name} is not loaded!"
+        return loaded_params
+
+    def validate_and_update_config(
+            self,
+            start_layer_id: int,
+            quant_config: Optional[QuantizationConfig] = None) -> None:
+        # yoco and moe is not supported by draft model yet
+        assert self.config.yoco_global_kv_layer is None
+        assert self.config.yoco_local_kv_layer is None
+        assert len(self.config.moe_layers) == 0
+        # draft model layer index is increased by start_layer_id,
+        # so we need to pad relevant configs accordingly
+        self.config.no_rope_layers = [
+            0
+        ] * start_layer_id + self.config.no_rope_layers
+        # currently only TorchAO quantization is supported
+        if isinstance(quant_config, TorchAOConfig):
+
+            def pad_layer_name(layer: str) -> str:
+                layer_index = extract_layer_index(layer)
+                return layer.replace(str(layer_index),
+                                     str(layer_index + start_layer_id))
+
+            quant_config.torchao_config.module_fqn_to_config = {
+                pad_layer_name(layer): quantization
+                for layer, quantization in
+                quant_config.torchao_config.module_fqn_to_config.items()
+            }
+
+
+class EagleLlama4ForCausalLM(Llama4ForCausalLM):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        nn.Module.__init__(self)
+        self.config = (
+            vllm_config.speculative_config.draft_model_config.hf_config)
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config)
+        # draft model quantization config may differ from target model
+        quant_config = VllmConfig.get_quantization_config(
+            vllm_config.speculative_config.draft_model_config,
+            vllm_config.load_config)
+        self.model = LlamaModel(vllm_config=vllm_config,
+                                prefix="model",
+                                start_layer_id=target_layer_num,
+                                quant_config=quant_config)
+        logit_scale = getattr(self.config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.config.vocab_size,
+                                                scale=logit_scale)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.model(input_ids, positions, hidden_states)
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> None:
+        loader = AutoWeightsLoader(
+            self,
+            # lm_head is tied with target model (Llama4ForCausalLM)
+            skip_prefixes=(["lm_head."]),
+        )
+
+        model_weights = {}
+        weights = [
+            self.permute_qk_weight_for_rotary(name, loaded_weight)
+            for name, loaded_weight in weights
+        ]
+        for name, loaded_weight in weights:
+            if "lm_head" not in name:
+                name = "model." + name
+            model_weights[name] = loaded_weight
+
+        loader.load_weights(model_weights.items())
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 27d47692985..fd43ed836f2 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -239,6 +239,7 @@
     "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"),
     "EAGLEModel": ("eagle", "EAGLE"),
     "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
+    "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),