Add Eagle-3 Qwen support (follow-up to #20436) #2

Merged (5 commits) on Jul 17, 2025

Changes from all commits

27 changes: 23 additions & 4 deletions vllm/config.py
@@ -2769,10 +2769,10 @@
pass
elif hasattr(self.draft_model_config.hf_config,
"speculators_model_type") and \
self.draft_model_config.hf_config.speculators_model_type in ("eagle", "eagle3"):

Check failure (GitHub Actions / pre-commit), Ruff (E501): vllm/config.py:2772:81 Line too long (104 > 80)
self.method = self.draft_model_config.hf_config.speculators_model_type

Check failure (GitHub Actions / pre-commit), Ruff (E501): vllm/config.py:2773:81 Line too long (90 > 80)
elif "eagle-" in self.draft_model_config.model.lower() or \
"eagle3-" in self.draft_model_config.model.lower():

Check failure (GitHub Actions / pre-commit), Ruff (E501): vllm/config.py:2775:81 Line too long (134 > 80)
self.method = "eagle"
elif self.draft_model_config.hf_config.model_type == "eagle":
self.method = "eagle"
@@ -2992,14 +2992,33 @@
"speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}")

if self.method == "eagle3" and self.target_model_config and \
"llama" not in self.target_model_config.hf_text_config.model_type:
if (
self.method == "eagle3"
and self.target_model_config
and self.draft_model_config
and hasattr(self.draft_model_config.hf_text_config, "speculators_version")
):
# Speculators model detected
if ("llama" not in self.target_model_config.hf_text_config.model_type
and "qwen" not in self.target_model_config.hf_text_config.model_type):
raise ValueError(
"Eagle3 is only supported for Llama and Qwen models "
"in speculators format. "
f"Got {self.target_model_config.hf_text_config.model_type=}"
)
return self

if (
self.method == "eagle3"
and self.target_model_config
and "llama" not in self.target_model_config.hf_text_config.model_type
):
raise ValueError(
"Eagle3 is only supported for Llama models. "
f"Got {self.target_model_config.hf_text_config.model_type=}")
f"Got {self.target_model_config.hf_text_config.model_type=}"
)

return self

@property
def num_lookahead_slots(self) -> int:
"""The number of additional slots the scheduler should allocate per
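For reference, a minimal sketch of the method-resolution order that the config.py hunk above implements. `resolve_spec_method`, `hf_config`, and `model_name` are hypothetical stand-ins for the draft model's HF config object and repo name, not vLLM's actual API:

def resolve_spec_method(hf_config, model_name: str) -> str:
    # 1. Speculators-format checkpoints name the method in their config.
    if getattr(hf_config, "speculators_model_type", None) in ("eagle", "eagle3"):
        return hf_config.speculators_model_type
    # 2. Fall back to name-based detection; both spellings map to "eagle" here,
    #    matching the branch shown in the diff.
    if "eagle-" in model_name.lower() or "eagle3-" in model_name.lower():
        return "eagle"
    # 3. Finally, honor an explicit model_type.
    if getattr(hf_config, "model_type", None) == "eagle":
        return "eagle"
    raise ValueError(f"cannot infer speculative method for {model_name}")
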
44 changes: 44 additions & 0 deletions vllm/engine/arg_utils.py
@@ -41,7 +41,7 @@
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
GiB_bytes, get_ip, is_in_ray_actor)
from vllm.transformers_utils.configs.speculators_eagle import is_speculators_eagle_config

Check failure (GitHub Actions / pre-commit), Ruff (E501): vllm/engine/arg_utils.py:44:81 Line too long (89 > 80)

# yapf: enable

@@ -394,6 +394,7 @@
str] = ModelConfig.logits_processor_pattern

speculative_config: Optional[Dict[str, Any]] = None
draft_tensor_parallel_size: Optional[int] = None

show_hidden_metrics_for_version: Optional[str] = \
ObservabilityConfig.show_hidden_metrics_for_version
@@ -767,6 +768,13 @@
default=None,
help="The configurations for speculative decoding. Should be a "
"JSON string.")
speculative_group.add_argument(
"--draft-tensor-parallel-size",
type=int,
default=None,
help="Number of tensor parallel replicas for the draft model. "
"Only used with speculative decoding. "
"Note: draft_tensor_parallel_size > 1 is not supported at the moment.")

Check failure (GitHub Actions / pre-commit), Ruff (E501): vllm/engine/arg_utils.py:777:81 Line too long (82 > 80)

# Observability arguments
observability_kwargs = get_kwargs(ObservabilityConfig)
@@ -874,6 +882,42 @@

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
# Auto-detect speculators format models
if args.model and not args.speculative_config:
from vllm.transformers_utils.configs import extract_speculators_info
from vllm.logger import init_logger
logger = init_logger(__name__)

speculators_info = extract_speculators_info(args.model)
if speculators_info:
# Log what we're doing
logger.info("🦅 Auto-detected Eagle speculators format model")
logger.info(f" Target model: {speculators_info['target_model']}")
logger.info(f" Draft model: {args.model}")
logger.info(f" Method: {speculators_info['method']}")

Check failure (GitHub Actions / pre-commit), Ruff (G004): vllm/engine/arg_utils.py:897:21 Logging statement uses f-string
logger.info(f" Speculative tokens: {speculators_info['num_tokens']}")

Check failure (GitHub Actions / pre-commit), Ruff (G004): vllm/engine/arg_utils.py:898:29 Logging statement uses f-string

Check failure (GitHub Actions / pre-commit), Ruff (G004): vllm/engine/arg_utils.py:899:29 Logging statement uses f-string
# Build speculative config
spec_config = {

Check failure (GitHub Actions / pre-commit), Ruff (G004): vllm/engine/arg_utils.py:901:21 Logging statement uses f-string
"method": speculators_info["method"],
"model": args.model, # Original model becomes draft
"num_speculative_tokens": speculators_info["num_tokens"],
}

# Add draft tensor parallel size if specified
if hasattr(args, 'draft_tensor_parallel_size') and args.draft_tensor_parallel_size is not None:
spec_config["draft_tensor_parallel_size"] = args.draft_tensor_parallel_size

# Set the speculative config directly (it's already parsed by argparse)
args.speculative_config = spec_config

# Swap the model to target

Check failure (GitHub Actions / pre-commit), Ruff (E501): vllm/engine/arg_utils.py:914:81 Line too long (87 > 80)
args.model = speculators_info["target_model"]

# Also update tokenizer if not explicitly set
if not hasattr(args, 'tokenizer') or args.tokenizer is None:
args.tokenizer = speculators_info["target_model"]

# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
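To make the from_cli_args() auto-detection above concrete, here is a hedged before/after sketch. The checkpoint name is a placeholder, and the resulting keys mirror the spec_config dict built in the diff:

import argparse

# Hypothetical speculators-format Eagle-3 draft checkpoint (placeholder name).
args = argparse.Namespace(
    model="org/eagle3-qwen3-8b-draft",
    speculative_config=None,
    tokenizer=None,
    draft_tensor_parallel_size=None,
)

# After EngineArgs.from_cli_args(args), per the detection block above:
#   args.model              -> the target (verifier) model read from the config
#   args.tokenizer          -> the same target model, since none was given
#   args.speculative_config -> {"method": "eagle3",
#                               "model": "org/eagle3-qwen3-8b-draft",
#                               "num_speculative_tokens": <from the checkpoint>}
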
26 changes: 19 additions & 7 deletions vllm/model_executor/models/qwen2.py
@@ -284,7 +284,7 @@ def __init__(self,
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.aux_hidden_state_layers: tuple[int] = tuple()
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
@@ -351,18 +351,30 @@ def forward(
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
-        for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(
-                positions,
-                hidden_states,
-                residual,
-            )
+
+        # Initialized unconditionally so the aux check after norm() below is
+        # always defined, even when no aux layers are configured.
+        aux_hidden_states = []
+        if len(self.aux_hidden_state_layers) > 0:
+            for idx, layer in enumerate(
+                    self.layers[self.start_layer:self.end_layer]):
+                if idx in self.aux_hidden_state_layers:
+                    aux_hidden_states.append(hidden_states + residual)
+                hidden_states, residual = layer(positions, hidden_states, residual)
+        else:
+            for layer in self.layers[self.start_layer:self.end_layer]:
+                hidden_states, residual = layer(
+                    positions,
+                    hidden_states,
+                    residual,
+                )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
        return hidden_states

def load_weights(self, weights: Iterable[tuple[str,
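The aux-hidden-state capture above taps the residual stream (hidden_states + residual) just before selected layers run. A self-contained toy version of that pattern, with a made-up ToyLayer standing in for vLLM's decoder layers:

import torch

class ToyLayer(torch.nn.Module):
    """Mirrors the (hidden_states, residual) contract of vLLM decoder layers."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, hidden, residual):
        return self.proj(hidden), hidden + residual

def run_layers(layers, hidden, residual, aux_idx):
    aux = []
    for idx, layer in enumerate(layers):
        if idx in aux_idx:
            # Snapshot the residual stream before this layer, as Eagle-3 expects.
            aux.append(hidden + residual)
        hidden, residual = layer(hidden, residual)
    return hidden, residual, aux

layers = torch.nn.ModuleList(ToyLayer(16) for _ in range(8))
h, r = torch.zeros(1, 16), torch.zeros(1, 16)
h, r, aux = run_layers(layers, h, r, aux_idx=(2, 4, 5))
assert len(aux) == 3  # one residual-stream snapshot per tapped layer
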
13 changes: 11 additions & 2 deletions vllm/model_executor/models/qwen3.py
@@ -271,7 +271,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.quant_config = quant_config
self.model = Qwen3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
@@ -302,6 +302,15 @@ def forward(
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
        return hidden_states

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
        self.model.aux_hidden_state_layers = layers

def compute_logits(
self,
@@ -322,4 +331,4 @@ def load_weights(self, weights: Iterable[tuple[str,
return loader.load_weights(weights)


Qwen3ForSequenceClassification = as_seq_cls_model(Qwen3ForCausalLM)
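get_eagle3_aux_hidden_state_layers() above picks one early, one middle, and one late layer to tap. A quick worked example, assuming a 36-layer stack (the count is illustrative, not a specific Qwen3 config):

num_layers = 36  # illustrative layer count
default_taps = (2, num_layers // 2, num_layers - 3)
print(default_taps)  # (2, 18, 33): early, middle, late
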
4 changes: 3 additions & 1 deletion vllm/transformers_utils/configs/__init__.py
@@ -7,7 +7,8 @@
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
from vllm.transformers_utils.configs.eagle import EAGLEConfig
from vllm.transformers_utils.configs.exaone import ExaoneConfig
-from vllm.transformers_utils.configs.speculators_eagle import SpeculatorsEagleConfig
+from vllm.transformers_utils.configs.speculators_eagle import (
+    SpeculatorsEagleConfig, extract_speculators_info)
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
@@ -42,6 +43,7 @@
"EAGLEConfig",
"ExaoneConfig",
"SpeculatorsEagleConfig",
"extract_speculators_info",
"MiniMaxText01Config",
"MiniMaxVL01Config",
"MllamaConfig",
58 changes: 57 additions & 1 deletion vllm/transformers_utils/configs/speculators_eagle.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
-from typing import Any, Union
+from typing import Any, Optional, Union

from transformers import PretrainedConfig

@@ -282,3 +282,59 @@ def is_speculators_eagle_config(config_path: Union[str, os.PathLike]) -> bool:
return model_type in SUPPORTED_SPECULATORS_TYPES
except Exception:
return False


def extract_speculators_info(model_path: Union[str, os.PathLike]) -> Optional[dict[str, Any]]:
"""
Extract target model and config from speculators format model.

Returns dict with:
- target_model: str - The target model name/path
- method: str - The speculative method (eagle/eagle3)
- num_tokens: int - Number of speculative tokens

Returns None if not speculators format or missing target model.
"""
try:
# Check if it's speculators format
if not is_speculators_eagle_config(model_path):
return None

# Load the config
config_dict, _ = PretrainedConfig.get_config_dict(model_path)

# Extract method
method = config_dict.get("speculators_model_type", "eagle")

# Extract num tokens
num_tokens = DEFAULT_NUM_LOOKAHEAD_TOKENS # default
speculators_cfg = config_dict.get("speculators_config", {})
proposal_methods = speculators_cfg.get("proposal_methods", [])
if proposal_methods:
num_tokens = proposal_methods[0].get("speculative_tokens", DEFAULT_NUM_LOOKAHEAD_TOKENS)

# Extract target model - try multiple possible locations
target_model = None

# Try target_config.model_name (original format)
target_config = speculators_cfg.get("target_config", {})
target_model = target_config.get("model_name")

# Try verifier.name_or_path (new format)
if not target_model:
verifier_config = speculators_cfg.get("verifier", {})
target_model = verifier_config.get("name_or_path")

# If no target model in config, return None
# This will require user to specify target model explicitly
if not target_model:
return None

return {
"target_model": target_model,
"method": method,
"num_tokens": num_tokens
}
except Exception:
# If any error occurs, treat as not speculators format
return None
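
A hedged example of the two config shapes extract_speculators_info() handles; all field values here are invented for illustration:

# New format: target under speculators_config.verifier.name_or_path.
config_dict = {
    "speculators_model_type": "eagle3",
    "speculators_config": {
        "proposal_methods": [{"speculative_tokens": 5}],
        "verifier": {"name_or_path": "org/qwen3-8b"},  # placeholder name
    },
}
# -> {"target_model": "org/qwen3-8b", "method": "eagle3", "num_tokens": 5}

# Original format: target under speculators_config.target_config.model_name.
config_dict = {
    "speculators_model_type": "eagle",
    "speculators_config": {
        "target_config": {"model_name": "org/llama-3.1-8b"},  # placeholder name
    },
}
# -> {"target_model": "org/llama-3.1-8b", "method": "eagle",
#     "num_tokens": DEFAULT_NUM_LOOKAHEAD_TOKENS}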