
Commit 739728c

Eagle 1 support and e2e test.

Signed-off-by: umeiko <umeko@stu.xmu.edu.cn>

1 parent e112317 · commit 739728c

2 files changed (+71 lines, -83 lines)


tests/e2e/long_term/spec_decode_v1/test_v1_spec_decode.py

Lines changed: 0 additions & 2 deletions

@@ -117,8 +117,6 @@ def test_eagle_correctness(
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
     '''
-    if not use_eagle3:
-        pytest.skip("Not current support for the test.")
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
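With the skip removed, test_eagle_correctness now runs for both values of its use_eagle3 parameter, so Eagle-1 is exercised end to end alongside Eagle-3. A minimal sketch of the shape such a test takes (the parametrize decorator and fixture wiring are assumptions inferred from the hunk above, not shown in this diff; the body is elided):

import pytest

@pytest.mark.parametrize("use_eagle3", [False, True])
def test_eagle_correctness(use_eagle3: bool,
                           monkeypatch: pytest.MonkeyPatch):
    # use_eagle3=False previously hit pytest.skip; it now covers Eagle-1.
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        ...  # build the reference and speculative LLMs, compare outputs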

vllm_ascend/worker/model_runner_v1.py

Lines changed: 71 additions & 81 deletions

@@ -15,7 +15,6 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
-#

 import gc
 import os

@@ -924,8 +923,8 @@ def _process_reqs(
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
-        if (self.use_aclgraph and
-                total_num_scheduled_tokens <= self.aclgraph_batch_sizes[-1]):
+        if (self.use_aclgraph and total_num_scheduled_tokens
+                <= self.aclgraph_batch_sizes[-1]):
             # Add padding to the batch size.
             num_input_tokens = self.vllm_config.pad_for_cudagraph(
                 total_num_scheduled_tokens)
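The rewrapped condition is behavior-preserving: the batch is padded only when ACL graphs are in use and the scheduled token count does not exceed the largest captured graph size. A standalone sketch of the padding idea (pad_for_cudagraph is the real vLLM helper; pad_to_graph_size below is an illustrative stand-in, not this repo's code):

import bisect

def pad_to_graph_size(num_tokens: int, captured_sizes: list[int]) -> int:
    # Round num_tokens up to the smallest captured graph batch size.
    # Assumes captured_sizes is sorted ascending and
    # num_tokens <= captured_sizes[-1], mirroring the guard above.
    idx = bisect.bisect_left(captured_sizes, num_tokens)
    return captured_sizes[idx]

# With captured sizes [1, 2, 4, 8, 16], a batch of 5 tokens is padded
# to 8 so the pre-captured size-8 graph can be replayed.
assert pad_to_graph_size(5, [1, 2, 4, 8, 16]) == 8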
@@ -1367,86 +1366,77 @@ def _get_spec_token_ids(
             assert isinstance(self.drafter, NgramProposer)
             spec_token_ids = self._generate_draft_token_ids(
                 valid_sampled_token_ids, sampling_metadata)
-        elif self.speculative_config.method == "eagle":
-            raise NotImplementedError("Eagle Is Not Supported Yet.")
-        elif self.speculative_config.method == "eagle3":
+        elif self.use_eagle:
             assert isinstance(self.drafter, EagleProposer)
-            if self.speculative_config.use_eagle():
-                next_token_ids: list[int] = []
-                for i, token_ids in enumerate(valid_sampled_token_ids):
-                    if token_ids:
-                        # Common case.
-                        next_token_id = token_ids[-1]
-                    else:
-                        # Partial prefill (rare case).
-                        # Get the next token id from the request state.
-                        req_id = self.input_batch.req_ids[i]
-                        req_state = self.requests[req_id]
-                        seq_len = (
-                            req_state.num_computed_tokens +
-                            scheduler_output.num_scheduled_tokens[req_id])
-
-                        next_token_id = req_state.get_token_id(seq_len)
-                    next_token_ids.append(next_token_id)
-                next_token_ids = torch.tensor(next_token_ids,
-                                              dtype=torch.int32,
-                                              device=self.device)
-                eagle_attn_metadata = attn_metadata[
-                    self.drafter.attn_layer_name]
-                num_input_tokens = scheduler_output.total_num_scheduled_tokens
-                if spec_decode_metadata is None:
-                    # input_ids can be None for multimodal models.
-                    target_token_ids = self.input_ids[:num_scheduled_tokens]
-                    target_positions = positions[:num_scheduled_tokens]
-                    if self.use_aux_hidden_state_outputs:
-                        target_hidden_states = torch.cat([
-                            h[:num_scheduled_tokens] for h in aux_hidden_states
-                        ],
-                                                         dim=-1)
-                    else:
-                        target_hidden_states = hidden_states[:
-                                                             num_scheduled_tokens]
-                    target_slot_mapping = eagle_attn_metadata.slot_mapping
-                    cu_num_tokens = eagle_attn_metadata.query_start_loc
+            next_token_ids: list[int] = []
+            for i, token_ids in enumerate(valid_sampled_token_ids):
+                if token_ids:
+                    # Common case.
+                    next_token_id = token_ids[-1]
                 else:
-                    num_draft_tokens = spec_decode_metadata.num_draft_tokens
-                    num_rejected_tokens = [
-                        n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
-                        for i, n in enumerate(num_draft_tokens)
-                    ]
-                    num_rejected_tokens = torch.tensor(
-                        num_rejected_tokens,
-                        dtype=torch.int32,
-                        device=self.device,
-                    )
-                    num_tokens = num_scheduled_tokens - sum(
-                        num_rejected_tokens)
-                    cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                        eagle_attn_metadata.query_start_loc,
-                        num_rejected_tokens, num_tokens)
-                    target_token_ids = self.input_ids[token_indices]
-                    target_positions = positions[token_indices]
-                    if self.use_aux_hidden_state_outputs:
-                        target_hidden_states = torch.cat(
-                            [h[token_indices] for h in aux_hidden_states],
-                            dim=-1)
-                    else:
-                        target_hidden_states = hidden_states[token_indices]
-                    target_slot_mapping = eagle_attn_metadata.slot_mapping[
-                        token_indices]
-
-                positions = self.positions[:num_input_tokens]
-                draft_token_ids = self.drafter.propose(
-                    target_token_ids=target_token_ids,
-                    target_positions=target_positions,
-                    target_hidden_states=target_hidden_states,
-                    target_slot_mapping=target_slot_mapping,
-                    next_token_ids=next_token_ids,
-                    cu_num_tokens=cu_num_tokens,
-                    block_table=eagle_attn_metadata.block_tables,
-                    sampling_metadata=sampling_metadata,
+                    # Partial prefill (rare case).
+                    # Get the next token id from the request state.
+                    req_id = self.input_batch.req_ids[i]
+                    req_state = self.requests[req_id]
+                    seq_len = (req_state.num_computed_tokens +
+                               scheduler_output.num_scheduled_tokens[req_id])
+
+                    next_token_id = req_state.get_token_id(seq_len)
+                next_token_ids.append(next_token_id)
+            next_token_ids = torch.tensor(next_token_ids,
+                                          dtype=torch.int32,
+                                          device=self.device)
+            eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
+            num_input_tokens = scheduler_output.total_num_scheduled_tokens
+            if spec_decode_metadata is None:
+                # input_ids can be None for multimodal models.
+                target_token_ids = self.input_ids[:num_scheduled_tokens]
+                target_positions = positions[:num_scheduled_tokens]
+                if self.use_aux_hidden_state_outputs:
+                    target_hidden_states = torch.cat(
+                        [h[:num_scheduled_tokens] for h in aux_hidden_states],
+                        dim=-1)
+                else:
+                    target_hidden_states = hidden_states[:num_scheduled_tokens]
+                target_slot_mapping = eagle_attn_metadata.slot_mapping
+                cu_num_tokens = eagle_attn_metadata.query_start_loc
+            else:
+                num_draft_tokens = spec_decode_metadata.num_draft_tokens
+                num_rejected_tokens = [
+                    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
+                    for i, n in enumerate(num_draft_tokens)
+                ]
+                num_rejected_tokens = torch.tensor(
+                    num_rejected_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
                 )
-                spec_token_ids = draft_token_ids.tolist()
+                num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
+                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
+                    eagle_attn_metadata.query_start_loc, num_rejected_tokens,
+                    num_tokens)
+                target_token_ids = self.input_ids[token_indices]
+                target_positions = positions[token_indices]
+                if self.use_aux_hidden_state_outputs:
+                    target_hidden_states = torch.cat(
+                        [h[token_indices] for h in aux_hidden_states], dim=-1)
+                else:
+                    target_hidden_states = hidden_states[token_indices]
+                target_slot_mapping = eagle_attn_metadata.slot_mapping[
+                    token_indices]
+
+            positions = self.positions[:num_input_tokens]
+            draft_token_ids = self.drafter.propose(
+                target_token_ids=target_token_ids,
+                target_positions=target_positions,
+                target_hidden_states=target_hidden_states,
+                target_slot_mapping=target_slot_mapping,
+                next_token_ids=next_token_ids,
+                cu_num_tokens=cu_num_tokens,
+                block_table=eagle_attn_metadata.block_tables,
+                sampling_metadata=sampling_metadata,
+            )
+            spec_token_ids = draft_token_ids.tolist()
         elif self.speculative_config.method == 'deepseek_mtp':
             assert isinstance(self.drafter, MtpProposer)
             spec_token_ids = self._generate_mtp_token_ids(
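The rejection arithmetic in the hunk above can be checked by hand: a request that was given n draft tokens has n + 1 sampled positions to account for (the n drafts plus one bonus token), and valid_sampled_token_ids holds only the accepted ones, so n + 1 - len(accepted) positions were rejected. A self-contained sketch of that list comprehension (plain Python; the torch.tensor conversion is elided):

def count_rejected(num_draft_tokens: list[int],
                   valid_sampled_token_ids: list[list[int]]) -> list[int]:
    # n drafts yield n + 1 sampled positions (drafts + bonus token);
    # whatever was not accepted was rejected. No drafts, no rejections.
    return [
        n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
        for i, n in enumerate(num_draft_tokens)
    ]

# Request 0: 3 drafts, 3 tokens kept (2 accepted drafts + bonus) -> 1 rejected.
# Request 1: no drafts -> 0 rejected.
assert count_rejected([3, 0], [[11, 22, 33], [44]]) == [1, 0]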
@@ -1832,7 +1822,7 @@ def load_model(self) -> None:
         self.model = get_model(vllm_config=self.vllm_config)
         if hasattr(self, "drafter"):
             logger.info("Loading drafter model...")
-            if self.use_aux_hidden_state_outputs:
+            if self.use_eagle:
                 self.drafter.load_model(self.model)
             else:
                 self.drafter.load_model()
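Gating drafter loading on self.use_eagle instead of use_aux_hidden_state_outputs means the Eagle-1 drafter now also receives the target model, which Eagle-style proposers typically need in order to share weights with it. A hypothetical sketch of how the two flags could relate (the actual initializer is not shown in this diff; eagle_flags is an illustrative helper, not this repo's code):

from typing import Optional

def eagle_flags(method: Optional[str]) -> tuple[bool, bool]:
    # Hypothetical: use_eagle covers both Eagle variants, while only
    # Eagle-3 consumes auxiliary hidden states from intermediate layers.
    use_eagle = method in ("eagle", "eagle3")
    use_aux_hidden_state_outputs = method == "eagle3"
    return use_eagle, use_aux_hidden_state_outputs

assert eagle_flags("eagle") == (True, False)
assert eagle_flags("eagle3") == (True, True)

Under that reading, the old gate skipped passing the target model for Eagle-1; the new gate passes it for both variants.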
