Spec decode support for V1 Engine #874
Conversation
@@ -0,0 +1,587 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
`rejection_sampler` and `eagle_proposer` are fully overwritten.
Can't this be implemented in a way that the main body uses vLLM code and only the problematic part is patched in vllm_ascend?
`AscendRejectionSampler` now inherits from `RejectionSampler`, and only the changed parts are overridden.
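A minimal sketch of that inheritance approach (the import path and the pass-through override are assumptions for illustration; the real class only needs to replace the Triton-dependent pieces):

```python
# Sketch only: assumes vLLM V1 exposes RejectionSampler at this path and
# that the Triton-dependent logic sits behind an overridable method.
from vllm.v1.sample.rejection_sampler import RejectionSampler


class AscendRejectionSampler(RejectionSampler):
    """Reuse vLLM's rejection sampler; override only the parts that rely
    on Triton with PyTorch equivalents that run on Ascend NPUs."""

    def forward(self, *args, **kwargs):
        # The Triton-dependent pieces would be replaced here with PyTorch
        # equivalents; everything else is inherited unchanged.
        return super().forward(*args, **kwargs)
```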
Spec decode is a key feature; can you add ngram and eagle e2e UTs in vllm-ascend?
Yes, `test_spec_decode.py` has already been added.
Please rebase all your commits into one commit.
Force-pushed from 9f5ae8a to ed32fcf
Force-pushed from 9e7bc8b to aa2bb74
@pytest.fixture
def model_name():
    return "meta-llama/Meta-Llama-3-8B-Instruct"
Let's use modelscope instead
return "meta-llama/Meta-Llama-3-8B-Instruct" | |
return "LLM-Research/Meta-Llama-3.1-8B-Instruct" |
@pytest.fixture
def eagle_model_name():
    return "yuhuili/EAGLE-LLaMA3-Instruct-8B"
ditto
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
It's not recommended to make a separate patch dir for spec-decode. Let's move these patches to `vllm_ascend/patch/platform` or `vllm_ascend/patch/worker`, depending on which phase is appropriate.
Most importantly, please add comments in `vllm_ascend/patch/__init__.py` to describe why we make this patch.
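For example, an entry could look roughly like this (wording is only a sketch; follow the format already used in that file):

```python
# In vllm_ascend/patch/__init__.py -- sketch only, not the final wording:
#
# Re-implement the `prepare_input_kernel` triton kernel in PyTorch,
# because Ascend doesn't support triton.
# Related PR (if no, explain why): Ascend-specific change, no upstream vLLM PR.
# - https://github.com/vllm-project/vllm-ascend/pull/874
# Future Plan: replace the PyTorch fallback with an Ascend C kernel.
```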
Already fixed.
# a, a + 1, ..., a + b - n2 - 1,
# a + b, a + b + 1, ..., a + b + c - n3 - 1]

# [0, a, a + b, a + b + c] -> [a, b, c]
Please remove the useless comments.
This is copied from vllm/eagle.py; why is it useless?
I mistakenly thought it was a CUDA-specific comment; please ignore it.
vllm_ascend/utils.py (outdated)
@@ -100,6 +100,7 @@ def adapt_patch(is_global_patch: bool = False):
    if is_global_patch:
        from vllm_ascend.patch import platform  # noqa: F401
    else:
        from vllm_ascend.patch import spec_decode  # noqa: F401
Let's remove this once the patch suggestions above are resolved.
already fixed
# [0, 1, 2, 5, 6, 9]
target_logits_indices += arange

# TODO: Optimize the CPU -> GPU copy.
-# TODO: Optimize the CPU -> GPU copy.
+# TODO: Optimize the CPU -> NPU copy.
already fixed
@@ -737,12 +888,92 @@ def execute_model(
        if max_gen_len == 1:
            # No spec decode tokens.
            valid_sampled_token_ids = sampled_token_ids.tolist()
        else:
            # Includes spec decode tokens.
Could we extract the model execution code into a separate function to make the code clearer?
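Something along these lines, for instance (the helper name and exact handling are illustrative, not the final implementation):

```python
import torch


def get_valid_sampled_token_ids(sampled_token_ids: torch.Tensor,
                                max_gen_len: int) -> list[list[int]]:
    """Illustrative helper: convert the sampler output into per-request
    token lists, dropping placeholder (-1) slots left by rejection
    sampling when spec decode is enabled."""
    if max_gen_len == 1:
        # No spec decode tokens.
        return sampled_token_ids.tolist()
    # Includes spec decode tokens: keep only the accepted tokens.
    return [row[row != -1].tolist() for row in sampled_token_ids]
```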
already fixed
Will squash everything into one commit in the end.
Force-pushed from b408e85 to a86f051
Basically I'm fine with the change, just some nits. Thanks.
vllm_ascend/patch/__init__.py (outdated)
# Re-implementation the `prepare_input_kernel` triton kernel by pytorch
# Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
# - https://github.com/vllm-project/vllm-ascend/pull/874
This field is meant for the related vLLM PR. The content can be changed to say that Ascend doesn't support Triton.
@@ -141,24 +141,22 @@ def reorder_batch(self, input_batch: "InputBatch",

    def build(self, num_reqs, num_actual_tokens, max_query_len,
              common_prefix_len):
        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
Needs a rebase; this change has already been merged to main: 7aa4f85
    return output_token_ids


def expand_batch_to_tokens(
Where is this function called?
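For context, the per-request to per-token expansion that function performs can be sketched in plain PyTorch like this (a sketch only; it assumes `cu_num_tokens` is a prefix sum with a leading zero, and the real signature may differ):

```python
import torch


def expand_batch_to_tokens_sketch(x: torch.Tensor,
                                  cu_num_tokens: torch.Tensor) -> torch.Tensor:
    """Expand one value per request into one value per draft token.

    x:             shape [num_reqs]
    cu_num_tokens: shape [num_reqs + 1], e.g. [0, n1, n1 + n2, ...]
    """
    # Number of draft tokens for each request.
    tokens_per_req = cu_num_tokens[1:] - cu_num_tokens[:-1]
    # Repeat each request's value once per draft token.
    return torch.repeat_interleave(x, tokens_per_req)
```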
            assert isinstance(self.drafter, NgramProposer)
            spec_token_ids = self.generate_draft_token_ids(
                valid_sampled_token_ids, sampling_metadata)
        elif self.speculative_config.method == "eagle":
Add a note/TODO here to highlight that `eagle` mode doesn't work currently. And since `eagle` doesn't work now, I think we'd better raise an error here directly and complete the function in a future PR, e.g.:
        elif self.speculative_config.method == "eagle":
            raise NotImplementedError("eagle method for spec decode doesn't work on vllm-ascend currently")
@@ -220,11 +218,11 @@ def forward(
        key: shape = [batch_size, seq_len, num_kv_heads * head_size]
        value: shape = [batch_size, seq_len, num_kv_heads * head_size]
        kv_cache: shape = [2, num_blocks, block_size,
-                          num_kv_heads * head_size]
+                          num_kv_heads, head_size]
@@ -687,6 +811,92 @@ def apply_grammar_bitmask(
        )
        return logits.to(self.device).to(logits_dtype)

    def get_spec_token_ids(
Rename to `_get_spec_token_ids`.
done
@@ -1083,3 +1344,35 @@ def capture_model(self) -> None:
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, npu_graph_size / (1 << 30))

    def generate_draft_token_ids(
ditto
done
Signed-off-by: ponix-j <657511300@qq.com>
    target_probs[token_idx, draft_token_id] = orig_prob


rs.expand_batch_to_tokens = expand_batch_to_tokens
This should be moved to the patch module.
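A sketch of what registering it there could look like (the file path and import locations are assumptions based on the patch layout suggested earlier in this review):

```python
# vllm_ascend/patch/worker/patch_common/patch_rejection_sampler.py
# Sketch only: the module paths below are assumptions.
import vllm.v1.sample.rejection_sampler as rs

from vllm_ascend.sample.rejection_sampler import expand_batch_to_tokens

# Swap the Triton-based helper for the PyTorch re-implementation so the
# rest of vLLM's rejection sampling code path stays untouched.
rs.expand_batch_to_tokens = expand_batch_to_tokens
```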
                valid_sampled_token_ids, sampling_metadata)
        elif self.speculative_config.method == "eagle":
            raise NotImplementedError(
                "eagle method for spec decode doesn't work on vllm-ascend currently"
let's add eagle support in the future
### What this PR does / why we need it?
Add basic v1 MTP features. Please merge it after #874 and #844.

### Does this PR introduce _any_ user-facing change?
We now support basic v1 MTP; only TP, eager mode, and k=1 are supported. We will continue to expand to more scenarios.

### How was this patch tested?
Tested locally.

Signed-off-by: XWFAlone <xuewenfei2@huawei.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: JC-ut0 <xuyexiong@huawei.com>
### What this PR does / why we need it?
Make spec decode support the V1 Engine.
- Currently, Ascend does not support the Triton kernel, so PyTorch is used to rewrite the `rejection_sampler.py` Triton kernel. However, PyTorch is not as performant as Triton, so Ascend C will be used to implement this function in the future.
- Currently, spec decode supports only the ngram algorithm. The eagle algorithm needs further adaptation.

### Does this PR introduce _any_ user-facing change?
No user-facing change.

### How was this patch tested?
Tested by `tests/singlecard/spec_decode/e2e/test_v1_spec_decode.py` and `tests/sample/test_rejection_sampler.py`, which cover the base function of the rejection sampler and the e2e function of spec decode.

Signed-off-by: ponix-j <657511300@qq.com>
Signed-off-by: wangxiaoxin (A) <w00664509@china.huawei.com>
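For reference, ngram speculative decoding as exercised by those e2e tests is typically enabled along these lines (a hedged sketch: the exact `speculative_config` field names depend on the vLLM version, and the model name is a placeholder):

```python
# Hedged sketch: field names may differ across vLLM versions; the model
# name and prompt are placeholders, not values taken from the tests.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    speculative_config={
        "method": "ngram",
        "num_speculative_tokens": 3,
        "prompt_lookup_max": 4,
    },
)
outputs = llm.generate(["The future of AI is"], SamplingParams(temperature=0.0))
print(outputs[0].outputs[0].text)
```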
…nc graph typo fix (#1121)

### What this PR does / why we need it?
1. The dependency was introduced by #874:
   - Move numba/quart from requirements-dev to requirements
   - Align pyproject.toml with requirements
2. This patch also fixes the DeepSeek accuracy baseline, which #1118 did not address. According to https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite, gsm8k is about `41.1`.
3. This also syncs the vLLM upstream changes: vllm-project/vllm@eaa2e51

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed: vllm ascend test (basic workflow), vllm longterm test (spec decode).

Closes: #1120

Signed-off-by: Yikun Jiang <yikunkero@gmail.com>