
[Meta] Official Eagle mm support, first enablement on llama4 #20788

Open · morgendave wants to merge 6 commits into main

Conversation

@morgendave (Collaborator) commented on Jul 10, 2025

Purpose

Enable multimodal (MM) inference for EAGLE speculative decoding, targeting mllama4 in this PR, with a design that should be easy to extend to other models.
Known issue with this PR:
MM chunked prefill needs to be disabled, or max-num-batched-tokens (mbnt) needs to be set to a large value. A follow-up PR using unshifted EAGLE prefill will fix this.

Test Plan

CUDA_VISIBLE_DEVICES=4,5,6,7 VLLM_USE_V1=1 python examples/offline_inference/spec_decode.py  --num_spec_tokens 7 --num_prompts 1 --method eagle  --model_dir /home/zhiweiz/local/models/scout_base_HF_20250605_201140 --eagle_dir /home/zhiweiz/local/models/scout_draft_HF_20250605_202942 --tp 4 --custom-mm-prompts
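
For reference, the same setup can be sketched through the vLLM Python API. This is a minimal, hypothetical sketch, not code from this PR: the model paths are the author's local paths from the command above, the prompt/image handling is schematic, and chunked prefill is disabled per the caveat in the Purpose section.

# Hypothetical offline sketch of EAGLE MM speculative decoding via the Python API.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(
    model="/home/zhiweiz/local/models/scout_base_HF_20250605_201140",
    tensor_parallel_size=4,
    enable_chunked_prefill=False,  # MM chunked prefill must be disabled for now
    speculative_config={
        "method": "eagle",
        "model": "/home/zhiweiz/local/models/scout_draft_HF_20250605_202942",
        "num_speculative_tokens": 7,
    },
)

image = Image.open("/path/to/sample.jpg")  # placeholder image path
outputs = llm.generate(
    {
        # The exact image-placeholder format depends on the model's chat
        # template; this prompt is illustrative only.
        "prompt": "<|image|> Describe this image.",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)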

vllm serve with benchmark testing
Server command:

#!/bin/bash
# Configuration of environment variables
export VLLM_USE_MODELSCOPE=False
export VLLM_TORCH_PROFILER_DIR=~/vllm_profile
export CUDA_VISIBLE_DEVICES=4,5,6,7
export VLLM_USE_V1=1
export SAFETENSORS_FAST_GPU=1
# Command to run the vllm server
spec_dec_config='{"method": "eagle", "model": "/home/zhiweiz/local/models/scout_draft_HF_20250605_202942", "prefill_token_shift": false, "num_speculative_tokens": 3, "draft_tensor_parallel_size": 4, "max_model_len": 32768}'
vllm serve /home/zhiweiz/local/models/scout_base_HF_20250605_201140 --disable-log-requests \
    -tp 4 \
    --max-num-seqs 128 \
    --max_num_batched_tokens=80000 \
    --max-model-len=32768 \
    --no-enable-prefix-caching \
    --trust-remote-code \
    --speculative-config="$spec_dec_config" \
    --num-lookahead-slots=3 \
    2>&1 | tee /data/users/$USER/logs/server/vllm_17b16e_vllm_serving$(date +%Y%m%d_%H%M%S).log

Baseline: the same command without the --speculative-config flag.

Benchmark command:

python benchmarks/benchmark_serving.py --backend openai-chat --model /home/zhiweiz/local/models/scout_base_HF_20250605_201140 --dataset-name hf --dataset-path lmarena-ai/VisionArena-Chat --seed 0 --max-concurrency 16 --endpoint /v1/chat/completions 2>&1 | tee /data/users/$USER/tmp/vllm_17b16e_vllm_loadgen$(date +%Y%m%d_%H%M%S).log
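
A single multimodal request can also be sent to the running server as a quick sanity check. This is a hypothetical sketch using the OpenAI Python client against the default vllm serve port; the image URL is a placeholder.

# Hypothetical single-request check against the server started above.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="/home/zhiweiz/local/models/scout_base_HF_20250605_201140",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/sample.jpg"}},  # placeholder URL
        ],
    }],
    max_tokens=128,
    temperature=0.0,
)
print(response.choices[0].message.content)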

Test Result

--------------------------------------------------
total_num_output_tokens: 155
num_drafts: 45
num_draft_tokens: 315
num_accepted_tokens: 110
mean acceptance length: 3.44
--------------------------------------------------
acceptance at token 0: 0.87
acceptance at token 1: 0.69
acceptance at token 2: 0.47
acceptance at token 3: 0.20
acceptance at token 4: 0.16
acceptance at token 5: 0.04
acceptance at token 6: 0.02

Eagle MM benchmark

Maximum request concurrency: 16
100%|██████████| 1000/1000 [03:16<00:00,  5.10it/s]
============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  196.05
Total input tokens:                      87321
Total generated tokens:                  117053
Request throughput (req/s):              5.10
Output token throughput (tok/s):         597.04
Total Token throughput (tok/s):          1042.43
---------------Time to First Token----------------
Mean TTFT (ms):                          201.32
Median TTFT (ms):                        164.53
P99 TTFT (ms):                           797.28
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          26.02
Median TPOT (ms):                        24.69
P99 TPOT (ms):                           48.73
---------------Inter-token Latency----------------
Mean ITL (ms):                           58.54
Median ITL (ms):                         40.44
P99 ITL (ms):                            160.04
==================================================
Maximum request concurrency: 16
100%|██████████| 1000/1000 [04:02<00:00,  4.12it/s]
============ Serving Benchmark Result ============
Successful requests:                     1000
Benchmark duration (s):                  242.87
Total input tokens:                      87321
Total generated tokens:                  116716
Request throughput (req/s):              4.12
Output token throughput (tok/s):         480.58
Total Token throughput (tok/s):          840.12
---------------Time to First Token----------------
Mean TTFT (ms):                          187.59
Median TTFT (ms):                        157.39
P99 TTFT (ms):                           826.26
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          32.00
Median TPOT (ms):                        31.85
P99 TPOT (ms):                           37.72
---------------Inter-token Latency----------------
Mean ITL (ms):                           31.55
Median ITL (ms):                         22.52
P99 ITL (ms):                            131.54
==================================================

Follow-ups

Need to make offline inference work with vision datasets


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @morgendave, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces comprehensive multimodal inference support for the EAGLE speculative decoding method, with an initial focus on Llama4 models. It enables the EAGLE draft model to process both text and image inputs, significantly enhancing its capability for multimodal large language models. The changes span from core model implementations to the speculative decoding pipeline and include new examples and robust test coverage for multimodal scenarios.

Highlights

  • Multimodal EAGLE Support: Enables multimodal (MM) inference capabilities for the EAGLE speculative decoding method, specifically targeting Llama4 models, with a design that allows for future extension to other multimodal models.
  • Llama4 EAGLE Draft Model: Introduces EagleLlama4ForCausalLM and LlamaModel to provide a dedicated draft model implementation for Llama4, capable of processing combined text and image embeddings for speculative decoding.
  • Speculative Decoding Pipeline Integration: Integrates multimodal input handling into the core EAGLE speculative decoding pipeline, allowing the draft model to receive and process image embeddings alongside text tokens for more accurate proposals.
  • Multimodal Testing & Examples: Adds a new example script (spec_decode.py) for running offline multimodal inference with custom image prompts and extends the end-to-end test suite to include multimodal EAGLE correctness tests for Llama4.
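
To make the draft-model highlight above concrete, here is a schematic sketch (not the PR's actual implementation) of a draft forward pass that consumes precomputed text+image embeddings alongside the target model's hidden states; the class and attribute names are illustrative, and the signature loosely follows the one quoted in the review comments below.

# Schematic sketch only; names are illustrative, not from the PR.
from typing import Optional

import torch
import torch.nn as nn


class EagleDraftModelSketch(nn.Module):

    def __init__(self, decoder: nn.Module, embed_tokens: nn.Embedding):
        super().__init__()
        self.model = decoder            # placeholder decoder stack
        self.embed_tokens = embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,                   # from the target model
        inputs_embeds: Optional[torch.Tensor] = None,  # text+image embeddings for MM
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if inputs_embeds is None:
            # Text-only path: embed the draft token ids.
            inputs_embeds = self.embed_tokens(input_ids)
        # MM path: inputs_embeds already has image embeddings merged at the
        # image-placeholder positions, so it is forwarded unchanged.
        return self.model(inputs_embeds, positions, hidden_states)
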
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@mergify (bot) added the documentation, llama, new-model, speculative-decoding, and v1 labels on Jul 10, 2025
@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces multi-modal (MM) support for EAGLE speculative decoding, with an initial implementation for the llama4 model architecture. The changes are comprehensive, touching upon example scripts, E2E tests, and core model execution logic.

Overall, the implementation for llama4_eagle looks solid. However, I've identified a few issues:

  • Critical: The E2E tests contain hardcoded local paths, which will break CI and prevent other developers from running the tests.
  • High: The llama_eagle and llama_eagle3 models have been updated to accept an inputs_embeds parameter, but they don't actually use it. This is misleading and should be fixed to either fully support MM or explicitly disallow it.

Once these issues are addressed, the PR should be in good shape.

Comment on lines 132 to 135
(("eagle", "/home/zhiweiz/local/models/scout_base_HF_20250605_201140",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), False),
(("eagle", "/home/zhiweiz/local/models/scout_base_HF_20250605_201140",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), True),

Severity: critical

The test test_eagle_correctness includes hardcoded local paths to a model (/home/zhiweiz/local/models/scout_base_HF_20250605_201140). This makes the test non-portable and will cause it to fail in CI environments and on other developers' machines. Please replace this with a model from the Hugging Face Hub or use a mechanism to download test-specific models.

Comment on lines 152 to 158
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.model(input_ids, positions, hidden_states)

Severity: high

The inputs_embeds parameter is added to the forward method's signature but is not used within the method body. The call to self.model() doesn't pass this parameter along, which means multimodal inputs will be ignored.

This is inconsistent with the implementation for llama4_eagle and suggests that multimodal support is incomplete for this model. If multimodal input is not supported for this model, it would be better to raise a NotImplementedError when inputs_embeds is provided. If it is intended to be supported, inputs_embeds should be passed to self.model and handled there.

        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if inputs_embeds is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support multimodal inputs yet.")
        return self.model(input_ids, positions, hidden_states)

Comment on lines 205 to 211
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.model(input_ids, positions, hidden_states)

Severity: high

Similar to llama_eagle.py, the inputs_embeds parameter is added to the forward method's signature but is not used. This makes the multimodal support for this model incomplete and potentially buggy if a user tries to use it with multimodal inputs.

Please either fully implement the handling of inputs_embeds or raise a NotImplementedError if it's not None to prevent silent failures.

        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if inputs_embeds is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support multimodal inputs yet.")
        return self.model(input_ids, positions, hidden_states)

@DarkLight1337 (Member):

Is #20591 supposed to be merged first?

@morgendave (Collaborator, Author):

> Is #20591 supposed to be merged first?

Yes, this will be rebased after that is merged.

Labels: documentation, llama, new-model, speculative-decoding, v1
3 participants