Skip to content

Enable auto-detection for Eagle speculators format models #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
GiB_bytes, get_ip, is_in_ray_actor)
from vllm.transformers_utils.configs.speculators_eagle import is_speculators_eagle_config

Check failure on line 44 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/engine/arg_utils.py:44:81: E501 Line too long (89 > 80)

# yapf: enable

Expand Down Expand Up @@ -394,6 +394,7 @@
str] = ModelConfig.logits_processor_pattern

speculative_config: Optional[Dict[str, Any]] = None
draft_tensor_parallel_size: Optional[int] = None

show_hidden_metrics_for_version: Optional[str] = \
ObservabilityConfig.show_hidden_metrics_for_version
Expand Down Expand Up @@ -767,6 +768,13 @@
default=None,
help="The configurations for speculative decoding. Should be a "
"JSON string.")
speculative_group.add_argument(
"--draft-tensor-parallel-size",
type=int,
default=None,
help="Number of tensor parallel replicas for the draft model. "
"Only used with speculative decoding. "
"Note: draft_tensor_parallel_size > 1 is not supported at the moment.")

Check failure on line 777 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/engine/arg_utils.py:777:81: E501 Line too long (82 > 80)

# Observability arguments
observability_kwargs = get_kwargs(ObservabilityConfig)
Expand Down Expand Up @@ -874,6 +882,42 @@

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
# Auto-detect speculators format models
if args.model and not args.speculative_config:
from vllm.transformers_utils.configs import extract_speculators_info
from vllm.logger import init_logger
logger = init_logger(__name__)

speculators_info = extract_speculators_info(args.model)
if speculators_info:
# Log what we're doing
logger.info("🦅 Auto-detected Eagle speculators format model")
logger.info(f" Target model: {speculators_info['target_model']}")
logger.info(f" Draft model: {args.model}")
logger.info(f" Method: {speculators_info['method']}")

Check failure on line 897 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (G004)

vllm/engine/arg_utils.py:897:21: G004 Logging statement uses f-string
logger.info(f" Speculative tokens: {speculators_info['num_tokens']}")

Check failure on line 898 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (G004)

vllm/engine/arg_utils.py:898:29: G004 Logging statement uses f-string

Check failure on line 899 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (G004)

vllm/engine/arg_utils.py:899:29: G004 Logging statement uses f-string
# Build speculative config
spec_config = {

Check failure on line 901 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (G004)

vllm/engine/arg_utils.py:901:21: G004 Logging statement uses f-string
"method": speculators_info["method"],
"model": args.model, # Original model becomes draft
"num_speculative_tokens": speculators_info["num_tokens"],
}

# Add draft tensor parallel size if specified
if hasattr(args, 'draft_tensor_parallel_size') and args.draft_tensor_parallel_size is not None:
spec_config["draft_tensor_parallel_size"] = args.draft_tensor_parallel_size

# Set the speculative config directly (it's already parsed by argparse)
args.speculative_config = spec_config

# Swap the model to target

Check failure on line 914 in vllm/engine/arg_utils.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/engine/arg_utils.py:914:81: E501 Line too long (87 > 80)
args.model = speculators_info["target_model"]

# Also update tokenizer if not explicitly set
if not hasattr(args, 'tokenizer') or args.tokenizer is None:
args.tokenizer = speculators_info["target_model"]

# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
Expand Down
4 changes: 3 additions & 1 deletion vllm/transformers_utils/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config
from vllm.transformers_utils.configs.eagle import EAGLEConfig
from vllm.transformers_utils.configs.exaone import ExaoneConfig
from vllm.transformers_utils.configs.speculators_eagle import SpeculatorsEagleConfig
from vllm.transformers_utils.configs.speculators_eagle import (
SpeculatorsEagleConfig, extract_speculators_info)
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
Expand Down Expand Up @@ -42,6 +43,7 @@
"EAGLEConfig",
"ExaoneConfig",
"SpeculatorsEagleConfig",
"extract_speculators_info",
"MiniMaxText01Config",
"MiniMaxVL01Config",
"MllamaConfig",
Expand Down
62 changes: 61 additions & 1 deletion vllm/transformers_utils/configs/speculators_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from typing import Any, Union
from typing import Any, Optional, Union

from transformers import PretrainedConfig

Expand Down Expand Up @@ -282,3 +282,63 @@ def is_speculators_eagle_config(config_path: Union[str, os.PathLike]) -> bool:
return model_type in SUPPORTED_SPECULATORS_TYPES
except Exception:
return False


def extract_speculators_info(
        model_path: Union[str, os.PathLike]) -> Optional[dict[str, Any]]:
    """Extract target model and speculation settings from a speculators
    format model config.

    Args:
        model_path: Local path or HF repo id of the candidate draft model.

    Returns:
        A dict with keys:
            - ``target_model`` (str): the target (verifier) model name/path
            - ``method`` (str): the speculative method (e.g. eagle/eagle3)
            - ``num_tokens`` (int): number of speculative tokens
        or ``None`` if the model is not in speculators format or the config
        does not name a target model.
    """
    try:
        # Bail out early for non-speculators configs.
        if not is_speculators_eagle_config(model_path):
            return None

        config_dict, _ = PretrainedConfig.get_config_dict(model_path)

        # "eagle" is the default when the model type field is absent.
        method = config_dict.get("speculators_model_type", "eagle")

        # The number of speculative tokens lives in the first proposal
        # method; fall back to the global default when missing.
        speculators_cfg = config_dict.get("speculators_config", {})
        proposal_methods = speculators_cfg.get("proposal_methods", [])
        num_tokens = DEFAULT_NUM_LOOKAHEAD_TOKENS
        if proposal_methods:
            num_tokens = proposal_methods[0].get(
                "speculative_tokens", DEFAULT_NUM_LOOKAHEAD_TOKENS)

        # The target model may be recorded in either of two layouts:
        # target_config.model_name (original format) or
        # verifier.name_or_path (new format). Try both in order.
        target_config = speculators_cfg.get("target_config", {})
        target_model = target_config.get("model_name")
        if not target_model:
            verifier_config = speculators_cfg.get("verifier", {})
            target_model = verifier_config.get("name_or_path")

        # Without a target model the caller must supply one explicitly.
        if not target_model:
            return None

        return {
            "target_model": target_model,
            "method": method,
            "num_tokens": num_tokens,
        }
    except Exception as e:
        # Best-effort helper used for auto-detection: failures here must
        # never be fatal, so log at debug level and report "not found".
        from vllm.logger import init_logger
        logger = init_logger(__name__)
        logger.debug("Failed to extract speculators info from %s.",
                     model_path,
                     exc_info=e)
        return None
Loading