[Meta] Llama4 EAGLE Support #20591
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀 …
Summary of Changes
Hello @morgendave, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances vLLM's capabilities by integrating speculative decoding support for Llama4 models, specifically designed to work with a dense-only draft model. It introduces a new model implementation tailored for this purpose and updates an example script to demonstrate its usage, thereby expanding the range of supported models for efficient inference.
Highlights
- Llama4 EAGLE Support: This pull request introduces core support for Llama4 models within the EAGLE speculative decoding framework. This enables the use of a dense-only draft model specifically designed for Llama4, leveraging official Meta-based support.
- New Model Implementation: A new file, `vllm/model_executor/models/llama4_eagle.py`, has been added. This file defines `LlamaModel`, a specialized draft model for Llama4, and `EagleLlama4ForCausalLM`, which integrates this draft model for speculative decoding. It includes custom weight-loading logic for stacked parameters and handles configuration validation for features like YoCo and MoE.
- Model Registration: The newly implemented `EagleLlama4ForCausalLM` has been registered in `vllm/model_executor/models/registry.py` at line 242, making it discoverable and usable by the vLLM system.
- Example Script Updates: The `examples/offline_inference/spec_decode.py` script has been updated to facilitate testing of the new Llama4 EAGLE support. Specifically, `gpu_memory_utilization` was adjusted from 0.8 to 0.7 (line 84), and `max_model_len=16384` was added to the engine initialization (line 87); see the sketch after this list for how these settings fit together.
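To make the example-script changes concrete, here is a rough sketch of how these settings fit together with vLLM's offline API. This is not the actual `spec_decode.py` script: the target and draft model names are taken from elsewhere in this PR, while `num_speculative_tokens=3` and `tensor_parallel_size=8` are illustrative assumptions.

```python
# Illustrative sketch only -- not the exact contents of
# examples/offline_inference/spec_decode.py.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    speculative_config={
        "method": "eagle",
        # Draft model name taken from this PR's "Follow up" section.
        "model": "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
        "num_speculative_tokens": 3,  # assumed value
    },
    tensor_parallel_size=8,        # assumed; depends on available GPUs
    gpu_memory_utilization=0.7,    # lowered from 0.8 in this PR
    max_model_len=16384,           # added in this PR to avoid OOM
    enforce_eager=True,
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```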
Code Review
This pull request adds support for EAGLE speculative decoding for Llama4 models. The changes look good overall, with a new model file `vllm/model_executor/models/llama4_eagle.py` implementing the draft model logic, a corresponding entry in the model registry, and updates to an example script.
My review focuses on the new `llama4_eagle.py` file. I've identified a few areas for improvement regarding code style, maintainability, and memory efficiency. Specifically, I've suggested removing unnecessary line continuations, addressing a potentially fragile inheritance pattern, adding a missing type hint, and refactoring a weight loading function to be more memory-efficient. These changes should improve the quality and robustness of the new implementation.
```python
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
```
```python
class EagleLlama4ForCausalLM(Llama4ForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
```
The `__init__` method of `EagleLlama4ForCausalLM` calls `nn.Module.__init__(self)` directly, bypassing the initializer of its base class, `Llama4ForCausalLM`. This is a fragile design: if `Llama4ForCausalLM`'s `__init__` sets up important state that is used by inherited methods (like `permute_qk_weight_for_rotary`), this implementation could break.
To improve maintainability, consider either using composition over inheritance or ensuring a proper call to `super().__init__(...)`.
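To make the composition suggestion concrete, here is a minimal, hypothetical sketch. The class name, constructor arguments, and backbone wiring are illustrative assumptions, not the code in this PR:

```python
import torch.nn as nn


class EagleLlama4DraftModel(nn.Module):
    """Hypothetical composition-based variant of the draft-model wrapper."""

    def __init__(self, backbone: nn.Module, hf_config) -> None:
        super().__init__()        # initialize nn.Module normally
        self.config = hf_config   # draft model's HF config
        self.model = backbone     # composed submodule instead of inheritance

    def forward(self, input_ids, positions, hidden_states):
        # Delegate to the composed backbone; nothing here depends on state
        # that a skipped base-class __init__ would otherwise have set up.
        return self.model(input_ids, positions, hidden_states)
```

With composition, helpers such as the QK-permutation logic would need to be called on (or imported from) the target-model implementation explicitly rather than inherited.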
```python
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
```
```python
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
```
```python
        model_weights = {}
        weights = [
            self.permute_qk_weight_for_rotary(name, loaded_weight)
            for name, loaded_weight in weights
        ]
        for name, loaded_weight in weights:
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight

        loader.load_weights(model_weights.items())
```
This block of code for processing weights is not memory-efficient. It first creates a list of all permuted weights, and then a dictionary of these weights, both of which can consume a large amount of memory for large models.
A more memory-efficient approach is to use a generator to process the weights one by one. This avoids loading all weights into memory at once.
```python
def _processed_weights():
    for name, loaded_weight in weights:
        name, loaded_weight = self.permute_qk_weight_for_rotary(
            name, loaded_weight)
        if "lm_head" not in name:
            name = "model." + name
        yield name, loaded_weight

loader.load_weights(_processed_weights())
```
Yes, we are aware of this; that's why, from the Meta side, we want to upstream this first in order to avoid conflicts as we work on official support. Also, our official design for llama4_eagle is dense-based to get the best performance.
```diff
@@ -0,0 +1,199 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
```
Can you also add the copyright from Meta side?
Absolutely, thanks for the suggestion
```python
        speculative_config=speculative_config,
        disable_log_stats=False,
        max_model_len=16384,
```
Is this needed due to an OOM issue?
I think this might not be needed; it is just from our internal test, though I think the default is smaller than this?
Yes, this is for OOM, as the original length with BF16 will be too big.
I saw the test plan; we still need enforce_eager. Does that mean CUDA graphs or torch.compile still don't work yet?
```diff
@@ -0,0 +1,199 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
```
Nope, just copied from @zixi-qi's runbook. I can delete that and it should work.
Hi @morgendave, I'm the author of #18369. I noticed that our PRs are for the same purpose and have similar code, like the padding of no_rope_layers. The EAGLE head that I uploaded is also dense-only (for the best performance). I noticed that your EAGLE model has 3 decoder layers; is this usually the case, or does it just have better performance? I don't think that is compatible with my code. I also tried running your code straight from your PR (target model: Scout, draft model: Scout) and changing max_model_len to a smaller number to avoid OOM; it runs but doesn't seem to give me acceptance > 1. Maybe there is some code still missing? I'll try to run it again. I would also like to call out that my PR is almost 2 months old at this point, and I feel it is fair to merge my PR first and then maybe you can build on top of it (I'm pretty sure it works for standard EAGLE models). I recognize that you have added quantization support and QK permutation for the rotary embeddings, which will be great additions. Thanks.
Sorry, but this is Meta's official support; it is also going to be followed up with MM and other support for the next-generation model, so we have to merge this.
You can add the following to the commit message so the author is set correctly:
Co-authored-by: Zixi Qi <qizixi@meta.com>
Could you include the `vllm serve` command you use and add some TTFT/TTIT numbers to the original PR? What's the E2E speedup we see?
tests/v1/e2e/test_spec_decode.py (outdated)
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), | ||
("eagle3", "meta-llama/Llama-3.1-8B-Instruct", | ||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), | ||
("eagle", "/home/zhiweiz/local/models/scout_base_HF_20250605_201140", |
Update to a public model repo.
```diff
@@ -81,9 +81,10 @@ def main():
         tensor_parallel_size=args.tp,
         enable_chunked_prefill=args.enable_chunked_prefill,
         enforce_eager=args.enforce_eager,
-        gpu_memory_utilization=0.8,
+        gpu_memory_utilization=0.7,
```
Is this the issue fixed in #20628 when the max batch size is configured?
Force-pushed from 1cf4064 to 8894e8f
Signed-off-by: qizixi <qizixi@meta.com>
Force-pushed from 8894e8f to 4238b3a
tests/v1/e2e/test_spec_decode.py (outdated)
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), | ||
("eagle3", "meta-llama/Llama-3.1-8B-Instruct", | ||
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), | ||
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", |
Check whether this will be triggered in CI; it might OOM in CI jobs: https://github.com/vllm-project/vllm/blob/main/.buildkite/test-pipeline.yaml#L265 cc @WoosukKwon
Yes, the CI can't run this. I think we can keep the code, but `pytest.skip` the Llama4 case so that we can easily test it locally?
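A minimal sketch of one way to do this, assuming the test is parameterized by the target model name; the function signature and skip message are illustrative, not the actual test_spec_decode.py code:

```python
import pytest


def test_eagle_correctness(model_name: str) -> None:
    # Keep the Llama4 case in the parameter list so it can still be run
    # locally, but skip it where Scout does not fit in CI GPU memory.
    if "Llama-4" in model_name:
        pytest.skip("Llama4 EAGLE test skipped: exceeds CI GPU memory")
    ...  # rest of the test body
```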
Looks good.
Please fix the failing tests.
Head branch was pushed to by a user without write access
LGTM!
```python
    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
```
Will MM input support be added later?
Yep that's in #20591
…xt only model to be registered
Purpose
Support EAGLE speculative decoding with a dense-only draft model for Llama4, using official Meta-based support.
Original Author: @zixi-qi
Test Plan
Ran with an uploaded Scout-based EAGLE draft model to test E2E.
Example cmd
Unit test: `python -m pytest tests/v1/e2e/test_spec_decode.py`
vllm serve + benchmarking
EAGLE server cmd (see the illustrative sketch below)
Base cmd = EAGLE server cmd, removing:
`--speculative-config="$spec_dec_config" \`
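As a rough illustration only, not the exact command behind the numbers below: the draft-model name comes from the Follow up section, while the tensor-parallel size, `num_speculative_tokens`, and other values are assumptions.

```bash
# Illustrative sketch of an EAGLE serve command; values are assumptions.
spec_dec_config='{
  "method": "eagle",
  "model": "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
  "num_speculative_tokens": 3
}'

vllm serve meta-llama/Llama-4-Scout-17B-16E-Instruct \
    --tensor-parallel-size 8 \
    --max-model-len 16384 \
    --gpu-memory-utilization 0.7 \
    --speculative-config="$spec_dec_config"
```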
benchmarking
Test Result
Unit test passed
EAGLE Benchmark
Accepted tokens average: 2.75-2.95
Baseline
Follow up
Upload draft model to Hugging Face
Scout based Draft: morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct
Pending: Maverick Draft: morgendave/EAGLE-Llama-4-Maverick-17B-128E-Instruct