[V1][Spec Decode][Feature] Spec decode with probs #20459

Open · andylolu2 wants to merge 4 commits into main from andy/v1-sd-with-probs

Conversation

@andylolu2 andylolu2 commented Jul 4, 2025

Summary

This PR enables speculative decoding with draft probabilities in V1. It is a partial revert of #16899.

Implementation

The blocker for #16899 was that the draft probabilities aren't used immediately, so we need to keep them around for the next iteration. In this PR, I propose we store the draft probabilities as part of the CachedRequestState. On the next iteration, when the draft token ids of a request are used, we fetch the draft probs from the cached state. This greatly simplifies the matter (see the sketch after the list below):

  • CachedRequestState already encapsulates the fact that data related to a request isn't necessarily used immediately. For example, this handles preemption and the case where a request isn't scheduled on every step.
  • Since the draft probs are tied to the cached state, they are deallocated when the cached state is deleted, so there is little risk of a memory leak (as opposed to managing a separate cache).
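
A minimal sketch of the idea, with hypothetical field and helper names (the actual CachedRequestState in vLLM has more fields and different call sites):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class CachedRequestState:
    # ... existing per-request state (prompt/output token ids, etc.) ...
    req_id: str
    # Draft probs proposed on a previous step. They are kept here until the
    # request's draft tokens are actually verified, which may be delayed by
    # preemption or by the request not being scheduled on every step.
    draft_probs: Optional[torch.Tensor] = None


def cache_draft_probs(state: CachedRequestState, draft_probs: torch.Tensor) -> None:
    # Store the probs with the request; they are freed automatically when the
    # cached state itself is deleted, so there is no separate cache to manage.
    state.draft_probs = draft_probs


def take_draft_probs(state: CachedRequestState) -> Optional[torch.Tensor]:
    # On the iteration that consumes this request's draft token ids, hand the
    # cached probs to the rejection sampler and clear them.
    probs, state.draft_probs = state.draft_probs, None
    return probs
```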

Benchmark

Numbers obtained by running the following on current main (14601f5) vs. this branch:

VLLM_USE_V1=1 python3 examples/offline_inference/spec_decode.py \
    --method eagle \
    --num_spec_tokens 4 \
    --dataset-path mt_bench.jsonl \
    --num_prompts 100 \
    --temp <T>

(Figure: mean acceptance length vs. temperature)

Online benchmarks

Ran with (numbers averaged across 3 runs):

VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct \
  --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE(3)-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 4}'

vllm bench serve \
  --model meta-llama/Llama-3.1-8B-Instruct \
  --endpoint-type openai-chat \
  --endpoint /v1/chat/completions \
  --dataset-name hf \
  --dataset-path philschmid/mt-bench \
  --num-prompts 100 \
  --max-concurrency 16 \
  --temperature <T> \
  --top-p 1.0

EAGLE

| Temperature | TPOT w/o probs | TPOT w/ probs | Diff (%) |
|-------------|----------------|---------------|----------|
| 0           | 4.896666667    | 4.863333333   | -0.680735194 |
| 0.3         | 5.08           | 5.32          | 4.724409449  |
| 0.5         | 5.146666667    | 5.293333333   | 2.849740933  |
| 0.7         | 5.216666667    | 5.326666667   | 2.108626198  |
| 1           | 5.89           | 5.673333333   | -3.678551217 |
| 1.3         | 8.673333333    | 6.19          | -28.63182168 |

EAGLE3

| Temperature | TPOT w/o probs | TPOT w/ probs | Diff (%) |
|-------------|----------------|---------------|----------|
| 0           | 3.513333333    | 3.533333333   | 0.569259962  |
| 0.3         | 3.666666667    | 3.813333333   | 4            |
| 0.5         | 3.753333333    | 3.82          | 1.776198934  |
| 0.7         | 3.886666667    | 3.886666667   | 0            |
| 1           | 4.596666667    | 4.323333333   | -5.946337926 |
| 1.3         | 7.933333333    | 5.793333333   | -26.97478992 |

Explanation:

  • The overall acceptance rate is always higher after this PR, but the gain is only noticeable when the temperature is sufficiently large (> 0.7).
  • As the temperature grows, the acceptance rate also grows. This is because both the draft and target distributions become more and more uniform, hence more similar (see the sketch below).
  • Sampler overhead means the delta is not very large, but this can be optimized in a follow-up.
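
For intuition: under the standard speculative-sampling acceptance rule, a draft token x is accepted with probability min(1, p_target(x) / p_draft(x)), so the expected per-position acceptance rate is sum_x min(p(x), q(x)), which grows as the two distributions get closer. The toy sketch below (illustrative only, not vLLM code) shows this effect as the temperature increases:

```python
import torch


def acceptance_rate(target_logits: torch.Tensor, draft_logits: torch.Tensor,
                    temperature: float) -> float:
    """Expected acceptance probability sum_x min(p(x), q(x)) at one position."""
    p = torch.softmax(target_logits / temperature, dim=-1)
    q = torch.softmax(draft_logits / temperature, dim=-1)
    return torch.minimum(p, q).sum().item()


# Two different toy distributions become more similar as the temperature
# grows, so the expected acceptance rate increases.
target = torch.tensor([2.0, 1.0, 0.5, 0.1])
draft = torch.tensor([1.0, 2.0, 0.2, 0.4])
for t in (0.3, 0.7, 1.0, 1.3):
    print(f"T={t}: acceptance ~ {acceptance_rate(target, draft, t):.3f}")
```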

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @andylolu2, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the speculative decoding feature, particularly for the Eagle model, by introducing the capability to generate and utilize probability distributions alongside proposed draft tokens. This allows for more sophisticated and accurate sampling during the drafting phase, improving the overall efficiency and quality of the speculative decoding process. Additionally, new tests have been introduced to ensure the stability of speculative decoding under complex scheduling conditions.

Highlights

  • Speculative Decoding with Probabilities: The core change enables the Eagle speculative decoding method to not only propose draft tokens but also return their associated probability distributions. This moves beyond simple greedy (argmax) token generation to allow for more nuanced sampling strategies during drafting.
  • Integration of Draft Probabilities: The GPUModelRunner now stores these draft probabilities per request in the CachedRequestState and passes them to the rejection sampler. This is crucial for the main model to efficiently verify and accept/reject the drafted tokens based on their likelihoods.
  • Enhanced Testing for Speculative Decoding: New comprehensive tests have been added to ensure the robustness of speculative decoding under various scheduling scenarios, including partial scheduling and request preemption. This validates the stability of the new probability-aware drafting mechanism.
  • Refactored Draft Proposal Interface: The propose_draft_token_ids method in GPUModelRunner has been renamed to propose_draft and its signature updated to explicitly return both the drafted token IDs and their probabilities, standardizing the interface for different speculative decoding methods (see the illustrative sketch below).
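
For illustration only, a proposer whose propose_draft returns both token ids and probabilities might look roughly like the sketch below; the class name, arguments, and shapes are assumptions for the example, not the actual signature in this PR:

```python
from typing import Optional

import torch


class DraftProposerSketch:
    """Toy stand-in for a proposer whose propose_draft returns ids + probs."""

    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature

    def propose_draft(
        self, draft_logits: torch.Tensor  # [num_reqs, num_spec_tokens, vocab]
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Instead of returning only argmax token ids, also return the full
        # probability distributions so the rejection sampler can use them.
        probs = torch.softmax(draft_logits / self.temperature, dim=-1)
        token_ids = torch.multinomial(
            probs.flatten(0, 1), num_samples=1
        ).view(probs.shape[:-1])
        return token_ids, probs
```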

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request adds support for returning probabilities along with tokens in speculative decoding, which is a key requirement for rejection sampling. The changes are well-implemented across the gpu_model_runner and eagle proposer, with corresponding updates to tests. A new integration test for scheduling with speculative decoding is also a great addition. I have one minor suggestion to improve comment clarity for future maintainability.

andylolu2 force-pushed the andy/v1-sd-with-probs branch from cbf0f7f to f79b62b on July 4, 2025 00:24
andylolu2 and others added 2 commits on July 8, 2025
andylolu2 force-pushed the andy/v1-sd-with-probs branch from f79b62b to 5937e7b on July 8, 2025 00:36
@andylolu2
Contributor Author

@WoosukKwon PR is ready for review

ekagra-ranjan commented Jul 8, 2025

Hi @andylolu2 - thanks for this PR! The AL improvement looks good.

One reason for using argmax was to reduce the overhead of sampling from the draft model. TRTLLM and TGI also use argmax. Could you also run the online benchmark on MTBench to see what the e2e gains in the TPOT metric from EAGLE look like with and without this PR? Example cmd: #18847

@andylolu2
Contributor Author

> Hi @andylolu2 - thanks for this PR! The AL improvement looks good.
>
> One reason for using argmax was to reduce the overhead of sampling from the draft model. TRTLLM and TGI also use argmax. Could you also run the online benchmark on MTBench to see what the e2e gains in the TPOT metric from EAGLE look like with and without this PR? Example cmd: #18847

@ekagra-ranjan I've added online benchmark numbers. The difference is not massive at low temperature (the results are a bit noisy in general). I think it would still be good to merge this since:

  1. Sampling can be improved quite easily (e.g. with a Triton kernel).
  2. The overall profile is more robust to changes in temperature.

@ekagra-ranjan
Contributor

@andylolu2 - thanks for adding the plot with TPOT. Can you also share the absolute numbers and the relative gain/degradation for different temperatures?

@andylolu2
Contributor Author

> @andylolu2 - thanks for adding the plot with TPOT. Can you also share the absolute numbers and the relative gain/degradation for different temperatures?

The absolute numbers are in the "Raw numbers" collapsible section under the plot.

ekagra-ranjan commented Jul 8, 2025

The drop in gains at a low temp like 0.3 is ~10%, which is a lot. Cohere, for example, uses 0.3 as the default temperature. Coding and reasoning tasks usually use lower temperatures, which becomes even more important with thinking/reasoning models.

Can we add an if-else to use the argmax method when the engine is using a temp < 0.75, to preserve the perf for these scenarios?

@andylolu2
Contributor Author

> The drop in gains at a low temp like 0.3 is ~10%, which is a lot. Cohere, for example, uses 0.3 as the default temperature. Coding and reasoning tasks usually use lower temperatures, which becomes even more important with thinking/reasoning models.
>
> Can we add an if-else to use the argmax method when the engine is using a temp < 0.75, to preserve the perf for these scenarios?

I think in general it's a highly model-specific choice what sampling temperature you should use for the draft model. Sometimes you want to match the temperature of the target, sometimes you want to go higher/lower.

I suggest we make the threshold T configurable, with the following heuristic (rough sketch below):

  1. When the target temperature < T, sample the draft model with temperature=0 (argmax).
  2. When the target temperature >= T, sample the draft model with temperature=T.
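
Roughly like this (a sketch only; the function below is hypothetical, no such config option exists in vLLM today, and 0.75 is used as an example default taken from the earlier suggestion):

```python
def draft_sampling_temperature(target_temperature: float,
                               threshold: float = 0.75) -> float:
    """Map the target sampling temperature to a draft sampling temperature.

    Below the threshold T, fall back to greedy (argmax) drafting; at or
    above T, cap the draft sampling temperature at T.
    """
    if target_temperature < threshold:
        return 0.0  # temperature=0, i.e. argmax drafting
    return threshold
```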

Does that sound reasonable to you?

ekagra-ranjan commented Jul 8, 2025

Yes, we can have the threshold T as a parameter. Perhaps the default value should be 0.75, based on your results, instead of making it a required param.

> I think in general it's a highly model-specific choice what sampling temperature you should use for the draft model.
> Sometimes you want to match the temperature of the target, sometimes you want to go higher/lower.

Oh, maybe I missed it, but in your experiment are you using a different temp for the draft than for the target?

andylolu2 commented Jul 8, 2025

> Oh, maybe I missed it, but in your experiment are you using a different temp for the draft than for the target?

I'm using the same temperature for both.

andylolu2 commented Jul 9, 2025

@ekagra-ranjan I realised it's actually quite difficult to make my proposal work. The problem is that the rejection sampler does not allow both:

  1. Partially "unset draft probs" (i.e. some draft probs are filler values); and
  2. Greedy draft sampling when target sampling is not greedy.

Making this work would require a fairly large change to the rejection sampler, and it would be a rabbit hole to make sure I don't introduce unwanted overheads.

Instead, I've optimized the draft model sampler a bit; the overhead is now ~4% in the worst case. New numbers are updated in the PR description.

ekagra-ranjan commented Jul 9, 2025

> The problem is that the rejection sampler does not allow both:

Do you think it's simpler to select the old argmax approach only if all of the requests in the batch have temps below T, or instead select the new approach if all requests in the batch have temps above T?

> Instead, I've optimized the draft model sampler a bit; the overhead is now ~4% in the worst case. New numbers are updated in the PR description.

Nice! Could you please share which line of code/commit does this?

andylolu2 commented Jul 9, 2025

> Do you think it's simpler to select the old argmax approach only if all of the requests in the batch have temps below T, or instead select the new approach if all requests in the batch have temps above T?

I don't see an easy way. Because drafts are not used immediately (usually on the step right after, but in theory arbitrarily later due to preemption), even if we fall back to the argmax approach during drafting, we might still end up with some requests with and some without draft probs during verification. A change to the rejection sampler would be needed.

> Nice! Could you please share which line of code/commit does this?

Forgot to push, haha. It's this commit: 20e43fd.

I'd also like to add that the sampling overhead will be negligible for larger models (e.g. DeepSeekV3/R1), so they should benefit a lot more from the increase in AL.

@andylolu2
Contributor Author

@ekagra-ranjan I've updated the PR

@andylolu2
Contributor Author

cc @WoosukKwon
