Skip to content

[V1][Spec-decode] First Stage support of Eagle 1 #1128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions .github/workflows/vllm_ascend_test_long_term.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,12 @@ jobs:
run: |
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
# v0 spec decode test
VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode_v0/e2e/test_mtp_correctness.py # it needs a clean process
pytest -sv tests/e2e/long_term/spec_decode_v0 --ignore=tests/e2e/long_term/spec_decode_v0/e2e/test_mtp_correctness.py
# TODO: revert when test_v1_mtp_correctness.py is fixed
# VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode_v0/e2e/test_mtp_correctness.py # it needs a clean process
# pytest -sv tests/e2e/long_term/spec_decode_v0 --ignore=tests/e2e/long_term/spec_decode_v0/e2e/test_mtp_correctness.py
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

deepseek_mtp CI need to be skip.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! you can just remove the comment to enable it

# v1 spec decode test
VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode_v1/test_v1_mtp_correctness.py
# TODO: revert when test_v1_mtp_correctness.py is fixed
# VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode_v1/test_v1_mtp_correctness.py
# TODO: revert me when test_v1_spec_decode.py::test_ngram_correctness is fixed
VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/long_term/spec_decode_v1/test_v1_spec_decode.py
# accuracy test single card
Expand Down
2 changes: 0 additions & 2 deletions tests/e2e/long_term/spec_decode_v1/test_v1_spec_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,6 @@ def test_eagle_correctness(
Compare the outputs of a original LLM and a speculative LLM
should be the same when using eagle speculative decoding.
'''
if not use_eagle3:
pytest.skip("Not current support for the test.")
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tests for eagle1 was opened here

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

Expand Down
153 changes: 72 additions & 81 deletions vllm_ascend/worker/model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
#

import gc
import os
Expand Down Expand Up @@ -1410,86 +1409,77 @@ def _get_spec_token_ids(
assert isinstance(self.drafter, NgramProposer)
spec_token_ids = self._generate_draft_token_ids(
valid_sampled_token_ids, sampling_metadata)
elif self.speculative_config.method == "eagle":
raise NotImplementedError("Eagle Is Not Supported Yet.")
elif self.speculative_config.method == "eagle3":
elif self.use_eagle:
assert isinstance(self.drafter, EagleProposer)
if self.speculative_config.use_eagle():
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_state = self.requests[req_id]
seq_len = (
req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])

next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
eagle_attn_metadata = attn_metadata[
self.drafter.attn_layer_name]
num_input_tokens = scheduler_output.total_num_scheduled_tokens
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat([
h[:num_scheduled_tokens] for h in aux_hidden_states
],
dim=-1)
else:
target_hidden_states = hidden_states[:
num_scheduled_tokens]
target_slot_mapping = eagle_attn_metadata.slot_mapping
cu_num_tokens = eagle_attn_metadata.query_start_loc
next_token_ids: list[int] = []
for i, token_ids in enumerate(valid_sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
num_tokens = num_scheduled_tokens - sum(
num_rejected_tokens)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
eagle_attn_metadata.query_start_loc,
num_rejected_tokens, num_tokens)
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]

positions = self.positions[:num_input_tokens]
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=eagle_attn_metadata.block_tables,
sampling_metadata=sampling_metadata,
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = self.input_batch.req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])

next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
num_input_tokens = scheduler_output.total_num_scheduled_tokens
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
target_slot_mapping = eagle_attn_metadata.slot_mapping
cu_num_tokens = eagle_attn_metadata.query_start_loc
else:
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(
num_rejected_tokens,
dtype=torch.int32,
device=self.device,
)
spec_token_ids = draft_token_ids.tolist()
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
eagle_attn_metadata.query_start_loc, num_rejected_tokens,
num_tokens)
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
if self.use_aux_hidden_state_outputs:
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = eagle_attn_metadata.slot_mapping[
token_indices]

positions = self.positions[:num_input_tokens]
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=eagle_attn_metadata.block_tables,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
elif self.speculative_config.method == 'deepseek_mtp':
assert isinstance(self.drafter, MtpProposer)
spec_token_ids = self._generate_mtp_token_ids(
Expand Down Expand Up @@ -2001,10 +1991,11 @@ def load_model(self) -> None:
pass
if self.drafter:
logger.info("Loading drafter model...")
if self.use_aux_hidden_state_outputs:
if self.use_eagle:
self.drafter.load_model(self.model)
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())
if self.use_aux_hidden_state_outputs:
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers())
else:
self.drafter.load_model()
if self.lora_config:
Expand Down
Loading