Spec decode support for V1 Engine #874


Merged: 1 commit merged into vllm-project:main from spec_v0.8.5rc1 on May 23, 2025

Conversation

@ponix-j (Contributor) commented May 15, 2025

What this PR does / why we need it?

Make spec decode support for V1 Engine

  • Ascend currently does not support Triton kernels, so the Triton kernel in rejection_sampler.py is rewritten in PyTorch. Since the PyTorch version does not match Triton's performance, the function will be reimplemented in Ascend C in the future (a sketch of the PyTorch approach is shown below).
  • Spec decode currently supports only the ngram algorithm; the eagle algorithm still needs further adaptation.
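
A minimal sketch of the standard rejection-sampling acceptance rule in plain PyTorch (the function name and shapes here are illustrative, not the PR's actual code):

```python
import torch

def rejection_sample(draft_token_ids: torch.Tensor,
                     draft_probs: torch.Tensor,
                     target_probs: torch.Tensor) -> torch.Tensor:
    """Return a boolean mask over draft tokens: True = accepted.

    draft_token_ids: [num_tokens]         tokens proposed by the draft model
    draft_probs:     [num_tokens, vocab]  draft-model probabilities
    target_probs:    [num_tokens, vocab]  target-model probabilities
    """
    idx = torch.arange(draft_token_ids.numel())
    p_target = target_probs[idx, draft_token_ids]
    p_draft = draft_probs[idx, draft_token_ids]
    # Accept each draft token with probability min(1, p_target / p_draft).
    accept = torch.rand_like(p_target) <= (p_target / p_draft).clamp(max=1.0)
    # Spec decode keeps only a contiguous accepted prefix: once a token is
    # rejected, all later draft tokens are dropped as well.
    return accept.int().cumprod(dim=0).bool()
```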

Does this PR introduce any user-facing change?

No user-facing change.

How was this patch tested?

Tested by tests/singlecard/spec_decode/e2e/test_v1_spec_decode.py and tests/sample/test_rejection_sampler.py, which cover the base functionality of the rejection sampler and the end-to-end spec decode flow.

@ponix-j changed the title from "Spec v0.8.5rc1" to "Spec decode v0.8.5rc1" on May 15, 2025
@@ -0,0 +1,587 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
Collaborator

rejection_sampler and eagle_proposer are fully overwritten.
Can't this be implemented so that the main body uses vllm code and only the problematic parts are patched in vllm_ascend?

Contributor Author

Done: class AscendRejectionSampler now inherits from class RejectionSampler, extracting only the changed parts.
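
A minimal sketch of that inheritance pattern (import path per vllm's v1 code at the time; the actual method split lives in the PR):

```python
from vllm.v1.sample.rejection_sampler import RejectionSampler


class AscendRejectionSampler(RejectionSampler):
    """Reuses vllm's RejectionSampler wholesale, overriding only the
    Triton-backed pieces with PyTorch equivalents that run on Ascend
    NPUs. Everything not overridden is inherited unchanged."""
```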

@mengwei805 (Collaborator)

Spec decode is a key feature, can you add ngram and eagle e2e UTs in vllm-ascend?

@ponix-j (Contributor Author) commented May 16, 2025

Spec decode is a key feature, can you add ngram and eagle e2e UTs in vllm-ascend?

Yes, already added test_spec_decode.py.

@mengwei805 (Collaborator)

Please rebase all your commits into 1 commit.

@ponix-j ponix-j force-pushed the spec_v0.8.5rc1 branch 2 times, most recently from 9f5ae8a to ed32fcf Compare May 19, 2025 04:04
@github-actions github-actions bot added documentation Improvements or additions to documentation module:ops module:quantization labels May 19, 2025
@github-actions github-actions bot removed documentation Improvements or additions to documentation module:ops module:quantization labels May 19, 2025
@ponix-j ponix-j force-pushed the spec_v0.8.5rc1 branch 2 times, most recently from 9e7bc8b to aa2bb74 Compare May 20, 2025 01:16

@pytest.fixture
def model_name():
return "meta-llama/Meta-Llama-3-8B-Instruct"
Collaborator

Let's use modelscope instead

Suggested change:
-    return "meta-llama/Meta-Llama-3-8B-Instruct"
+    return "LLM-Research/Meta-Llama-3.1-8B-Instruct"


@pytest.fixture
def eagle_model_name():
return "yuhuili/EAGLE-LLaMA3-Instruct-8B"
Collaborator

ditto

@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
Collaborator

It's not recommended to make a separate patch dir for spec-decode. Let's move these patches to vllm_ascend/patch/platform or vllm_ascend/patch/worker, depending on which phase is appropriate.

Most importantly, please add comments in vllm_ascend/patch/__init__.py describing why we make this patch.

Contributor Author

Already fixed.

# a, a + 1, ..., a + b - n2 - 1,
# a + b, a + b + 1, ..., a + b + c - n3 - 1]

# [0, a, a + b, a + b + c] -> [a, b, c]
Collaborator

Please remove the useless comments.

Contributor Author

This is copied from vllm/eagle.py; why is it useless?

Collaborator

I mistakenly thought it was a CUDA-specific comment, please ignore it.
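
For context, the quoted comment describes turning cumulative query offsets plus per-request rejection counts into flattened token indices; a hedged PyTorch illustration (names and signature assumed, not the PR's code):

```python
import torch

def build_token_indices(cu_query_lens: torch.Tensor,
                        num_rejected: torch.Tensor) -> torch.Tensor:
    # [0, a, a + b, a + b + c] -> [a, b, c]
    query_lens = cu_query_lens[1:] - cu_query_lens[:-1]
    # Keep only each request's accepted prefix: [a - n1, b - n2, c - n3]
    kept = query_lens - num_rejected
    starts = cu_query_lens[:-1]
    # [0, 1, ..., a - n1 - 1,  a, a + 1, ..., a + b - n2 - 1, ...]
    return torch.cat([torch.arange(s, s + k)
                      for s, k in zip(starts.tolist(), kept.tolist())])
```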

@@ -100,6 +100,7 @@ def adapt_patch(is_global_patch: bool = False):
if is_global_patch:
from vllm_ascend.patch import platform # noqa: F401
else:
from vllm_ascend.patch import spec_decode # noqa: F401
Collaborator

Let's remove this once the above suggestion on the patch dir is resolved.

Contributor Author

Already fixed.

# [0, 1, 2, 5, 6, 9]
target_logits_indices += arange

# TODO: Optimize the CPU -> GPU copy.
Collaborator

Suggested change:
- # TODO: Optimize the CPU -> GPU copy.
+ # TODO: Optimize the CPU -> NPU copy.

Contributor Author

Already fixed.

@@ -737,12 +888,92 @@ def execute_model(
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = sampled_token_ids.tolist()
else:
# Includes spec decode tokens.
Collaborator

Could we extract the model execution code into a separate function to make the code clearer?

Contributor Author

Already fixed.

@ponix-j (Contributor Author) commented May 21, 2025

Please rebase all your commits into 1 commit.

Will squash into a single commit in the end.

@wangxiyuan changed the title from "Spec decode v0.8.5rc1" to "Spec decode support for V1 Engine" on May 21, 2025
@ponix-j ponix-j force-pushed the spec_v0.8.5rc1 branch 3 times, most recently from b408e85 to a86f051 Compare May 22, 2025 12:13
@wangxiyuan (Collaborator) left a comment

Basically I'm fine with the change, just some nits. Thanks.

# Re-implementation the `prepare_input_kernel` triton kernel by pytorch
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# - https://github.com/vllm-project/vllm-ascend/pull/874
Collaborator

This is the PR itself, not a vllm PR; the content can be changed to say that Ascend doesn't support Triton.
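
A hedged example of how the corrected comment could read (wording assumed, not the merged text):

```python
# In vllm_ascend/patch/__init__.py:
#
# Why: Ascend does not support Triton, so the `prepare_input_kernel`
#      Triton kernel is re-implemented in PyTorch.
# Related PR: https://github.com/vllm-project/vllm-ascend/pull/874
```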

@@ -141,24 +141,22 @@ def reorder_batch(self, input_batch: "InputBatch",

def build(self, num_reqs, num_actual_tokens, max_query_len,
common_prefix_len):
if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
Collaborator

Needs rebase; this change is already merged to main: 7aa4f85.

return output_token_ids


def expand_batch_to_tokens(
Collaborator

Where is this function called?
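
For reference, `expand_batch_to_tokens` expands per-request values (e.g., sampling temperatures) to per-draft-token values; a hedged PyTorch equivalent of what the Triton-backed version computes (shapes assumed):

```python
import torch

def expand_batch_to_tokens(x: torch.Tensor,
                           cu_num_tokens: torch.Tensor) -> torch.Tensor:
    # Recover per-request token counts from cumulative offsets:
    # [0, a, a + b, a + b + c] -> [a, b, c]
    counts = cu_num_tokens[1:] - cu_num_tokens[:-1]
    # Repeat each request's value once per draft token.
    return torch.repeat_interleave(x, counts)
```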

assert isinstance(self.drafter, NgramProposer)
spec_token_ids = self.generate_draft_token_ids(
valid_sampled_token_ids, sampling_metadata)
elif self.speculative_config.method == "eagle":
Collaborator

Add a note/TODO here to highlight that eagle mode doesn't work currently. And if eagle doesn't work now, I think we'd better raise an error here directly and complete the function in a future PR:

 elif self.speculative_config.method == "eagle":
    raise NotImplementedError("eagle method for spec decode doesn't work on vllm-ascend currently")

@@ -220,11 +218,11 @@ def forward(
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
kv_cache: shape = [2, num_blocks, block_size,
num_kv_heads * head_size]
num_kv_heads, head_size]
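
The docstring fix keeps the KV-head and head-size dimensions separate, matching the Ascend cache layout; an illustrative shape check (all sizes assumed):

```python
import torch

# Illustrative sizes only: 2 (K and V), 4 blocks, block_size 16,
# 8 KV heads, head_size 128.
kv_cache = torch.zeros(2, 4, 16, 8, 128)
key_cache, value_cache = kv_cache.unbind(0)
assert key_cache.shape == (4, 16, 8, 128)
```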

@@ -687,6 +811,92 @@ def apply_grammar_bitmask(
)
return logits.to(self.device).to(logits_dtype)

def get_spec_token_ids(
Collaborator

rename to _get_spec_token_ids

Contributor Author

done

@@ -1083,3 +1344,35 @@ def capture_model(self) -> None:
# This usually takes 5~20 seconds.
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, npu_graph_size / (1 << 30))

def generate_draft_token_ids(
Collaborator

ditto

Contributor Author

done

Signed-off-by: ponix-j <657511300@qq.com>
target_probs[token_idx, draft_token_id] = orig_prob


rs.expand_batch_to_tokens = expand_batch_to_tokens
Collaborator

This should be moved to the patch module.
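
A hedged sketch of what moving this into the patch module could look like (the module path and import are assumptions, not the merged layout):

```python
# e.g. vllm_ascend/patch/worker/patch_common/patch_rejection_sampler.py
import vllm.v1.sample.rejection_sampler as rs

from vllm_ascend.sample.rejection_sampler import expand_batch_to_tokens

# Swap the Triton-backed helper for the PyTorch implementation,
# since Triton kernels are not supported on Ascend.
rs.expand_batch_to_tokens = expand_batch_to_tokens
```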

valid_sampled_token_ids, sampling_metadata)
elif self.speculative_config.method == "eagle":
raise NotImplementedError(
"eagle method for spec decode doesn't work on vllm-ascend currently"
Collaborator

let's add eagle support in the future

@ganyi1996ppo ganyi1996ppo merged commit df58fb8 into vllm-project:main May 23, 2025
16 checks passed
wangxiyuan pushed a commit that referenced this pull request May 30, 2025
### What this PR does / why we need it?
Add basic v1 MTP features. Please merge it after
#874 and
#844.

### Does this PR introduce _any_ user-facing change?
Now we support basic v1 MTP; only TP, eager mode, and k=1 are supported for now.
We will continue to expand to more scenarios.

### How was this patch tested?
Tested locally.

Signed-off-by: XWFAlone <xuewenfei2@huawei.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: JC-ut0 <xuyexiong@huawei.com>
momo609 pushed a commit to momo609/vllm-ascend that referenced this pull request May 30, 2025
### What this PR does / why we need it?
Make spec decode support for V1 Engine
- Ascend currently does not support Triton kernels, so the Triton kernel in
`rejection_sampler.py` is rewritten in PyTorch. Since the PyTorch version does
not match Triton's performance, the function will be reimplemented in Ascend C
in the future.
- Spec decode currently supports only the ngram algorithm; the eagle algorithm
still needs further adaptation.
### Does this PR introduce _any_ user-facing change?
No user-facing change.

### How was this patch tested?
Tested by `tests/singlecard/spec_decode/e2e/test_v1_spec_decode.py` and
`tests/sample/test_rejection_sampler.py`, which cover the base functionality
of the rejection sampler and the end-to-end spec decode flow.

Signed-off-by: ponix-j <657511300@qq.com>
Signed-off-by: wangxiaoxin (A) <w00664509@china.huawei.com>
momo609 pushed a commit to momo609/vllm-ascend that referenced this pull request Jun 3, 2025
### What this PR does / why we need it?
Add basic v1 MTP features. Please merge it after
vllm-project#874 and
vllm-project#844.

### Does this PR introduce _any_ user-facing change?
Now we support basic v1 MTP; only TP, eager mode, and k=1 are supported for now.
We will continue to expand to more scenarios.

### How was this patch tested?
Tested locally.

Signed-off-by: XWFAlone <xuewenfei2@huawei.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: JC-ut0 <xuyexiong@huawei.com>
Signed-off-by: wangxiaoxin (A) <w00664509@china.huawei.com>
David9857 pushed a commit to David9857/vllm-ascend that referenced this pull request Jun 3, 2025
### What this PR does / why we need it?
Add basic v1 MTP features. Please merge it after
vllm-project#874 and
vllm-project#844.

### Does this PR introduce _any_ user-facing change?
Now we support basic v1 MTP; only TP, eager mode, and k=1 are supported for now.
We will continue to expand to more scenarios.

### How was this patch tested?
Tested locally.

Signed-off-by: XWFAlone <xuewenfei2@huawei.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: JC-ut0 <xuyexiong@huawei.com>
Yikun added a commit that referenced this pull request Jun 8, 2025
…nc graph typo fix (#1121)

### What this PR does / why we need it?
1. The dependency was introduced by
#874:
- Move numba/quart from requirements-dev to requirements
- Align pyproject.toml with requirements

2. This patch also fixes the deepseek accuracy baseline, which
#1118 did not address.
According to https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite, the
gsm8k score is about `41.1`.

3. This also syncs the vLLM upstream changes:
vllm-project/vllm@eaa2e51

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed
vllm ascend test (basic workflow)
vllm longterm test (spec decode)

Closes: #1120

---------

Signed-off-by: Yikun Jiang <yikunkero@gmail.com>