Skip to content

[V1][Speculative Decoding] Fix DeepSeek MTP #20022

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 25, 2025
Merged

Conversation

cjackal
Copy link
Contributor

@cjackal cjackal commented Jun 24, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Fix #20021.

I need better understanding on multiple KV cache group scenario, so leave it as a future work.

Test Plan

Launch options:

vllm serve /app/model/DEEPSEEK-R1/ \
  --served-model-name deepseek-ai/deepseek-r1 \
  --gpu-memory-utilization 0.95 \
  --tensor-parallel-size 16 \
  --max-model-len 65536 \
  --max-num-batched-tokens 8192 \
  --reasoning-parser deepseek_r1 \
  --speculative-config '{"method":"deepseek_mtp","num_speculative_tokeens":1}'

Test request:

curl -XPOST http://localhost:8080/v1/chat/completions -H 'Content-Type: application/json' -d '{"model":"deepseek-ai/deepseek-r1","messages":[{"role":"user","content":"안녕?"}],"stream":true}'

Test Result

Server runs successfully and return:

data: {"id":"chatcmpl-763b14f9564f4048ac5388bc0d7e66c6","object":"chat.completion.chunk","created":1750763944,"model":"deepseek-ai/deepseek-r1","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-763b14f9564f4048ac5388bc0d7e66c6","object":"chat.completion.chunk","created":1750763944,"model":"deepseek-ai/deepseek-r1","choices":[{"index":0,"delta":{"reasoning_content":"Okay"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-763b14f9564f4048ac5388bc0d7e66c6","object":"chat.completion.chunk","created":1750763944,"model":"deepseek-ai/deepseek-r1","choices":[{"index":0,"delta":{"reasoning_content":", the"},"logprobs":null,"finish_reason":null}]}

data: {"id":"chatcmpl-763b14f9564f4048ac5388bc0d7e66c6","object":"chat.completion.chunk","created":1750763944,"model":"deepseek-ai/deepseek-r1","choices":[{"index":0,"delta":{"reasoning_content":" user said"},"logprobs":null,"finish_reason":null}]}

...

(Optional) Documentation Update

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jun 24, 2025
cjackal added 2 commits June 24, 2025 14:40
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @cjackal, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on fixing a specific issue with the DeepSeek MTP model in the context of speculative decoding. It includes changes to the model's embedding layer and attention metadata handling to resolve the reported problem. The PR also includes a test plan and results to demonstrate the fix.

Highlights

  • DeepSeek MTP Fix: Addresses issue #20021 related to DeepSeek MTP (Multiple KV cache group scenario) in speculative decoding. A more complete fix for multiple KV cache groups is planned for future work.
  • Embedding Layer: Added VocabParallelEmbedding to the DeepSeek MTP model to handle vocabulary embeddings.
  • Attention Metadata: Modified the propose function in eagle.py to correctly access attention metadata builders using index 0, which is a temporary fix for multiple KV cache groups.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

The pull request fixes an issue with DeepSeek MTP and addresses the multiple KV cache group scenario by selecting the first available builder. The FIXME comment is still present, and there's a potential redundancy in defining self.embed_tokens in DeepSeekMultiTokenPredictor.

Comment on lines +115 to +118
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Adding self.embed_tokens here seems redundant, as it's already defined in the DeepSeekMultiTokenPredictorLayer class. Consider if this is truly necessary or if it can be removed to avoid duplication.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this related to the bug? Does the MTP module have a separate vocab embedding?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lack of vocab embedding module raises the 'DeepSeekMultiTokenPredictor' object has no attribute 'embed_tokens' attribute error, the first traceback in the linked issue.

Architecture-wise all the vocab embeddings are of the same shape as the target model, but we do need to keep vocab embeddings for each mtp layers if the target model has been trained with multiple mtp layers (not the case for official deepseek R1/V3 families though) and the user launches server with PP > 1. There is a similar condition check in EAGLE weight loading step.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cjackal Thanks for the explanation! Can we use the target model's embedding when PP=1 and only allocate the weights when PP > 1?

Copy link
Contributor Author

@cjackal cjackal Jun 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WoosukKwon Another thought after the discussion: the sole purpose of speculative decoding is leveraging the small draft model to get faster generation speed, so allowing pipeline parallelism on draft model weights is rare and kind of contradictory.

We may simply assume that all the MTP layers are on the same (last) pipeline component and always share the vocab embedding of the MTP layers with that of the target model. NVM, even if MTP module is not split, there is no guarantee that target model's embedding is on the same component. Let me just move the vocab embeddings from DeepSeekMultiTokenPredictorLayer to DeepSeekMultiTokenPredictor to share them among MTP layers and leave the sharing between target and draft embedding to the EAGLE draft model loading stage that I linked before.

Copy link
Collaborator

@WoosukKwon WoosukKwon Jun 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cjackal Got it. Thanks! Could you please re-run the test locally again?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need more refactor on draft weight loading part; let me ping again when ready.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cjackal Got it. Thanks!

cjackal added 3 commits June 25, 2025 00:55
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
@cjackal
Copy link
Contributor Author

cjackal commented Jun 25, 2025

According to the deepseek v3 tech report, the vocab embeddings are shared among target module and every MTP layers; indeed the weights 'model.layers.61.embed_tokens.weight' and 'model.embed_tokens.weight' coincide for official deepseek-r1 checkpoint. Thus I simply load the first MTP vocab embedding and when PP=1 leave it to be dropped by EAGLE weight loader.

I have checked that the model server functions okay and prints the following server log:

INFO 06-25 14:52:32 [logger.py:43] Received request chatcmpl-888c628f25024ed1a218733dedee24ff: prompt: '<|begin▁of▁sentence|><|User|>안녕?<|Assistant|><think>\n', params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.6, top_p=0.95, top_k=0, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=131063, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None, extra_args=None), prompt_token_ids: None, prompt_embeds shape: None, lora_request: None, prompt_adapter_request: None.
INFO: 127.0.0.1:46406 - "POST /v1/chat/completions HTTP/1.1" 200 OK
INFO 06-25 14:52:32 [async_llm.py:270] Added request chatcmpl-888c628f25024ed1a218733dedee24ff.
INFO 06-25 14:52:35 [loggers.py:118] Engine 000: Avg prompt throughput: 0.9 tokens/s, Avg generation throughput: 12.4 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.1%, Prefix cache hit rate: 0.0%
INFO 06-25 14:52:35 [metrics.py:87] SpecDecoding metrics: Draft acceptance rate: 78.3%, Mean acceptance length: 1.78, Accepted: 54 tokens, Drafted: 69 tokens, Per-position acceptance rate: 0.783
INFO 06-25 14:52:45 [loggers.py:118] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 12.8 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
INFO 06-25 14:52:45 [metrics.py:87] SpecDecoding metrics: Draft acceptance rate: 82.9%, Mean acceptance length: 1.83, Accepted: 58 tokens, Drafted: 70 tokens, Per-position acceptance rate: 0.829

@WoosukKwon PTAL

@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Jun 25, 2025
@WoosukKwon WoosukKwon merged commit 8359f4c into vllm-project:main Jun 25, 2025
13 of 14 checks passed
m-misiura pushed a commit to m-misiura/vllm that referenced this pull request Jun 26, 2025
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
gmarinho2 pushed a commit to gmarinho2/vllm that referenced this pull request Jun 26, 2025
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
@cjackal
Copy link
Contributor Author

cjackal commented Jun 27, 2025

For better visibility I leave the issue referenced above as a comment as well; it seems the fix in this PR is not working after #19717 , which I'd detailed in issue #20186.

xjpang pushed a commit to xjpang/vllm that referenced this pull request Jun 30, 2025
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
wseaton pushed a commit to wseaton/vllm that referenced this pull request Jun 30, 2025
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
Signed-off-by: Will Eaton <weaton@redhat.com>
wseaton pushed a commit to wseaton/vllm that referenced this pull request Jun 30, 2025
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
wwl2755-google pushed a commit to wwl2755-google/vllm that referenced this pull request Jul 1, 2025
Signed-off-by: cjackal <44624812+cjackal@users.noreply.github.com>
@makabaka6338
Copy link

lanuch options:
python3 -m vllm.entrypoints.openai.api_server --model /model/DeepSeek-R1 --host 0.0.0.0 --port 8000 --max-num-seqs 1024 --tensor-parallel-size 8 --speculative-config '{"method":"deepseek_mtp","num_speculative_tokeens":1}'
I am testing the scenario of Chinese-English translation, with each sentence not exceeding 32 tokens, testing under different concurrency levels. However, when my concurrency reaches 64 or higher, the latency of vllm becomes very high, and the throughput is extremely low.
machine: H20-3e 141G * 8
image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug]: [V1] DeepSeek MTP is broken since 0.9.0
3 participants