Spec decode support for V1 Engine #874
base: main
Conversation
@@ -0,0 +1,587 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
rejection_sampler and eagle_proposer are fully overwritten here. Can't this be implemented in a way that the main body reuses vLLM code and only the problematic part is patched in vllm_ascend?
We can use a class AscendRejectionSampler that inherits from RejectionSampler and extract only the changed part.
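For illustration, the inheritance approach could look roughly like this; the import path and the overridden method are assumptions, not necessarily the merged code:

from vllm.v1.sample.rejection_sampler import RejectionSampler


class AscendRejectionSampler(RejectionSampler):
    """Reuse the upstream sampler; override only the NPU-specific part."""

    def forward(self, *args, **kwargs):
        # Swap CUDA-specific kernels for NPU-friendly ops here, then
        # delegate everything else to the upstream implementation.
        return super().forward(*args, **kwargs)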
Spec decode is a key feature. Could you add ngram and EAGLE e2e UTs in vllm-ascend?
Yes, already added in test_spec_decode.py.
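For reference, an e2e ngram spec-decode test could be sketched as below; the model name, the speculative_config fields, and the greedy-equality check are illustrative placeholders, not the exact contents of test_spec_decode.py:

import pytest
from vllm import LLM, SamplingParams


@pytest.mark.parametrize("model_name",
                         ["LLM-Research/Meta-Llama-3.1-8B-Instruct"])
def test_ngram_spec_decode_e2e(model_name):
    prompts = ["The capital of France is"]
    sampling_params = SamplingParams(temperature=0, max_tokens=32)

    # Reference run without speculative decoding.
    ref_llm = LLM(model=model_name, max_model_len=1024)
    ref_outputs = ref_llm.generate(prompts, sampling_params)
    del ref_llm

    # Same prompts with ngram speculative decoding enabled.
    spec_llm = LLM(model=model_name,
                   max_model_len=1024,
                   speculative_config={
                       "method": "ngram",
                       "num_speculative_tokens": 3,
                       "prompt_lookup_max": 3,
                   })
    spec_outputs = spec_llm.generate(prompts, sampling_params)

    # With greedy sampling, speculation should not change the output text.
    for ref, spec in zip(ref_outputs, spec_outputs):
        assert ref.outputs[0].text == spec.outputs[0].text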
Please rebase all your commits into one commit.
@pytest.fixture
def model_name():
    return "meta-llama/Meta-Llama-3-8B-Instruct"
Let's use modelscope instead
return "meta-llama/Meta-Llama-3-8B-Instruct" | |
return "LLM-Research/Meta-Llama-3.1-8B-Instruct" |
@pytest.fixture
def eagle_model_name():
    return "yuhuili/EAGLE-LLaMA3-Instruct-8B"
ditto
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
It's not recommended to make a separate patch dir for spec-decode. Let's move these patches to vllm_ascend/patch/platform or vllm_ascend/patch/worker, depending on which stage is appropriate. Most importantly, please add comments in vllm_ascend/patch/__init__.py describing why we make this patch.
Already fixed.
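As an illustration of the requested documentation, a patch entry in vllm_ascend/patch/__init__.py might read roughly like this; the module path and wording are examples only, not the actual patch list:

# vllm_ascend/patch/__init__.py (illustrative excerpt)
#
# Patched: vllm.v1.sample.rejection_sampler.RejectionSampler
#   Why: the upstream sampler depends on CUDA-only kernels; the Ascend
#        version replaces them with NPU-compatible ops.
#   Remove when: vLLM exposes a pluggable rejection sampler, at which
#        point this patch can be dropped.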
#  a, a + 1, ..., a + b - n2 - 1,
#  a + b, a + b + 1, ..., a + b + c - n3 - 1]

# [0, a, a + b, a + b + c] -> [a, b, c]
Please remove the useless comments.
This is copied from vllm/eagle.py; why is it useless?
I mistakenly thought it was a CUDA-specific comment; please ignore it.
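For context, the comment describes turning cumulative offsets into per-request lengths, roughly like this (illustrative numbers, not the exact eagle.py code):

import torch

# Cumulative offsets [0, a, a + b, a + b + c] for three requests ...
cu_num_tokens = torch.tensor([0, 4, 7, 12])

# ... become per-request lengths [a, b, c].
num_tokens_per_req = cu_num_tokens[1:] - cu_num_tokens[:-1]
print(num_tokens_per_req.tolist())  # [4, 3, 5]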
vllm_ascend/utils.py
@@ -100,6 +100,7 @@ def adapt_patch(is_global_patch: bool = False):
    if is_global_patch:
        from vllm_ascend.patch import platform  # noqa: F401
    else:
        from vllm_ascend.patch import spec_decode  # noqa: F401
Let's remove this once the above suggestion on patches is resolved.
Already fixed.
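After moving the spec-decode patches under the existing dirs, adapt_patch could look roughly like this (assuming the worker patch package is the right home; not the exact final code):

def adapt_patch(is_global_patch: bool = False):
    # Platform-level patches are applied once at startup; worker-level
    # patches (now including the spec-decode ones) are applied per worker.
    if is_global_patch:
        from vllm_ascend.patch import platform  # noqa: F401
    else:
        from vllm_ascend.patch import worker  # noqa: F401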
# [0, 1, 2, 5, 6, 9]
target_logits_indices += arange

# TODO: Optimize the CPU -> GPU copy.
Suggested change:
-    # TODO: Optimize the CPU -> GPU copy.
+    # TODO: Optimize the CPU -> NPU copy.
Already fixed.
@@ -737,12 +888,92 @@ def execute_model(
    if max_gen_len == 1:
        # No spec decode tokens.
        valid_sampled_token_ids = sampled_token_ids.tolist()
    else:
        # Includes spec decode tokens.
Could we extract the model execution code into a separate function to make the code clearer?
Already fixed.
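The extraction the reviewer asks for could take roughly this shape; the function name, the -1 placeholder convention, and the masking details are illustrative, not the merged code:

import torch


def get_valid_sampled_token_ids(sampled_token_ids: torch.Tensor,
                                max_gen_len: int) -> list[list[int]]:
    # Hypothetical helper pulled out of execute_model.
    if max_gen_len == 1:
        # No spec decode tokens: every sampled token is valid as-is.
        return sampled_token_ids.tolist()
    # Spec decode tokens present: drop entries the rejection sampler
    # marked as rejected (assumed here to be -1 placeholders).
    valid_mask = sampled_token_ids != -1
    return [
        row[mask].tolist()
        for row, mask in zip(sampled_token_ids, valid_mask)
    ]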
Signed-off-by: ponix-j <657511300@qq.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
Will use one commit finally.
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?