Commit c3f1605 (parent: ca884ef)

EAGLE1 SUPPORT

Signed-off-by: umeiko <umeko@stu.xmu.edu.cn>
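In short, this commit removes the EAGLE-1 dead end in _get_spec_token_ids (the old method == "eagle" arm just raised NotImplementedError) and routes both EAGLE-1 and EAGLE-3 through a single self.use_eagle flag, de-indenting the former eagle3-only body into the shared path. A minimal sketch of how such a flag could be derived; only the names use_eagle and speculative_config.method come from the diff, the helper itself is hypothetical:

def derive_use_eagle(speculative_config) -> bool:
    # Hypothetical helper, not vllm-ascend API: treat both EAGLE variants
    # as one family so the runner can branch on a single flag.
    return (speculative_config is not None
            and speculative_config.method in ("eagle", "eagle3"))
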
File tree

tests/e2e/long_term/spec_decode_v1/test_v1_spec_decode.py
vllm_ascend/worker/model_runner_v1.py

2 files changed: +69, -81 lines
tests/e2e/long_term/spec_decode_v1/test_v1_spec_decode.py

Lines changed: 0 additions & 2 deletions

@@ -117,8 +117,6 @@ def test_eagle_correctness(
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using eagle speculative decoding.
     '''
-    if not use_eagle3:
-        pytest.skip("Not current support for the test.")
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")

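With the eagle3-only skip removed, the correctness test above now also runs for EAGLE-1. A hedged sketch of the shape that test plausibly has; only use_eagle3, monkeypatch, the docstring, and VLLM_USE_V1 are visible in the hunk, the parametrization and body are assumptions:

import pytest

@pytest.mark.parametrize("use_eagle3", [False, True])
def test_eagle_correctness(monkeypatch: pytest.MonkeyPatch,
                           use_eagle3: bool) -> None:
    '''Outputs of the original and the speculative LLM should match
    when eagle speculative decoding is enabled.'''
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        ...  # build baseline and speculative engines, compare outputs
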
vllm_ascend/worker/model_runner_v1.py

Lines changed: 69 additions & 79 deletions
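
The second hunk below is the core change: the former eagle3-only body becomes the shared EAGLE path. It seeds the drafter with each request's most recently accepted token; a runnable restatement of that selection with made-up token ids (the partial-prefill fallback for empty lists is handled in the diff itself):

import torch

# Each request's last accepted token seeds the next drafter step.
valid_sampled_token_ids = [[5, 9], [3], [8, 1, 4]]   # hypothetical ids
next_token_ids = torch.tensor([ids[-1] for ids in valid_sampled_token_ids],
                              dtype=torch.int32)
print(next_token_ids)  # tensor([9, 3, 4], dtype=torch.int32)
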
@@ -15,7 +15,6 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
-#
 
 import gc
 import os
@@ -1334,86 +1333,77 @@ def _get_spec_token_ids(
             assert isinstance(self.drafter, NgramProposer)
             spec_token_ids = self._generate_draft_token_ids(
                 valid_sampled_token_ids, sampling_metadata)
-        elif self.speculative_config.method == "eagle":
-            raise NotImplementedError("Eagle Is Not Supported Yet.")
-        elif self.speculative_config.method == "eagle3":
+        elif self.use_eagle:
             assert isinstance(self.drafter, EagleProposer)
-            if self.speculative_config.use_eagle():
-                next_token_ids: list[int] = []
-                for i, token_ids in enumerate(valid_sampled_token_ids):
-                    if token_ids:
-                        # Common case.
-                        next_token_id = token_ids[-1]
-                    else:
-                        # Partial prefill (rare case).
-                        # Get the next token id from the request state.
-                        req_id = self.input_batch.req_ids[i]
-                        req_state = self.requests[req_id]
-                        seq_len = (
-                            req_state.num_computed_tokens +
-                            scheduler_output.num_scheduled_tokens[req_id])
-
-                        next_token_id = req_state.get_token_id(seq_len)
-                    next_token_ids.append(next_token_id)
-                next_token_ids = torch.tensor(next_token_ids,
-                                              dtype=torch.int32,
-                                              device=self.device)
-                eagle_attn_metadata = attn_metadata[
-                    self.drafter.attn_layer_name]
-                num_input_tokens = scheduler_output.total_num_scheduled_tokens
-                if spec_decode_metadata is None:
-                    # input_ids can be None for multimodal models.
-                    target_token_ids = self.input_ids[:num_scheduled_tokens]
-                    target_positions = positions[:num_scheduled_tokens]
-                    if self.use_aux_hidden_state_outputs:
-                        target_hidden_states = torch.cat([
-                            h[:num_scheduled_tokens] for h in aux_hidden_states
-                        ],
-                                                         dim=-1)
-                    else:
-                        target_hidden_states = hidden_states[:
                                                             num_scheduled_tokens]
-                    target_slot_mapping = eagle_attn_metadata.slot_mapping
-                    cu_num_tokens = eagle_attn_metadata.query_start_loc
+            next_token_ids: list[int] = []
+            for i, token_ids in enumerate(valid_sampled_token_ids):
+                if token_ids:
+                    # Common case.
+                    next_token_id = token_ids[-1]
                 else:
-                    num_draft_tokens = spec_decode_metadata.num_draft_tokens
-                    num_rejected_tokens = [
-                        n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
-                        for i, n in enumerate(num_draft_tokens)
-                    ]
-                    num_rejected_tokens = torch.tensor(
-                        num_rejected_tokens,
-                        dtype=torch.int32,
-                        device=self.device,
-                    )
-                    num_tokens = num_scheduled_tokens - sum(
-                        num_rejected_tokens)
-                    cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                        eagle_attn_metadata.query_start_loc,
-                        num_rejected_tokens, num_tokens)
-                    target_token_ids = self.input_ids[token_indices]
-                    target_positions = positions[token_indices]
-                    if self.use_aux_hidden_state_outputs:
-                        target_hidden_states = torch.cat(
-                            [h[token_indices] for h in aux_hidden_states],
-                            dim=-1)
-                    else:
-                        target_hidden_states = hidden_states[token_indices]
-                    target_slot_mapping = eagle_attn_metadata.slot_mapping[
-                        token_indices]
-
-                positions = self.positions[:num_input_tokens]
-                draft_token_ids = self.drafter.propose(
-                    target_token_ids=target_token_ids,
-                    target_positions=target_positions,
-                    target_hidden_states=target_hidden_states,
-                    target_slot_mapping=target_slot_mapping,
-                    next_token_ids=next_token_ids,
-                    cu_num_tokens=cu_num_tokens,
-                    block_table=eagle_attn_metadata.block_tables,
-                    sampling_metadata=sampling_metadata,
+                    # Partial prefill (rare case).
+                    # Get the next token id from the request state.
+                    req_id = self.input_batch.req_ids[i]
+                    req_state = self.requests[req_id]
+                    seq_len = (req_state.num_computed_tokens +
+                               scheduler_output.num_scheduled_tokens[req_id])
+
+                    next_token_id = req_state.get_token_id(seq_len)
+                next_token_ids.append(next_token_id)
+            next_token_ids = torch.tensor(next_token_ids,
+                                          dtype=torch.int32,
+                                          device=self.device)
+            eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
+            num_input_tokens = scheduler_output.total_num_scheduled_tokens
+            if spec_decode_metadata is None:
+                # input_ids can be None for multimodal models.
+                target_token_ids = self.input_ids[:num_scheduled_tokens]
+                target_positions = positions[:num_scheduled_tokens]
+                if self.use_aux_hidden_state_outputs:
+                    target_hidden_states = torch.cat(
+                        [h[:num_scheduled_tokens] for h in aux_hidden_states],
+                        dim=-1)
+                else:
+                    target_hidden_states = hidden_states[:num_scheduled_tokens]
+                target_slot_mapping = eagle_attn_metadata.slot_mapping
+                cu_num_tokens = eagle_attn_metadata.query_start_loc
+            else:
+                num_draft_tokens = spec_decode_metadata.num_draft_tokens
+                num_rejected_tokens = [
+                    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
+                    for i, n in enumerate(num_draft_tokens)
+                ]
+                num_rejected_tokens = torch.tensor(
+                    num_rejected_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
                 )
-                spec_token_ids = draft_token_ids.tolist()
+                num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
+                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
+                    eagle_attn_metadata.query_start_loc, num_rejected_tokens,
+                    num_tokens)
+                target_token_ids = self.input_ids[token_indices]
+                target_positions = positions[token_indices]
+                if self.use_aux_hidden_state_outputs:
+                    target_hidden_states = torch.cat(
+                        [h[token_indices] for h in aux_hidden_states], dim=-1)
+                else:
+                    target_hidden_states = hidden_states[token_indices]
+                target_slot_mapping = eagle_attn_metadata.slot_mapping[
+                    token_indices]
+
+            positions = self.positions[:num_input_tokens]
+            draft_token_ids = self.drafter.propose(
+                target_token_ids=target_token_ids,
+                target_positions=target_positions,
+                target_hidden_states=target_hidden_states,
+                target_slot_mapping=target_slot_mapping,
+                next_token_ids=next_token_ids,
+                cu_num_tokens=cu_num_tokens,
+                block_table=eagle_attn_metadata.block_tables,
+                sampling_metadata=sampling_metadata,
+            )
+            spec_token_ids = draft_token_ids.tolist()
         elif self.speculative_config.method == 'deepseek_mtp':
             assert isinstance(self.drafter, MtpProposer)
             spec_token_ids = self._generate_mtp_token_ids(
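
A note on the rejection bookkeeping in this hunk: a request that scheduled n > 0 draft tokens yields between 1 and n + 1 sampled tokens (n drafts plus one bonus token), so n + 1 - len(sampled) of its slots were rejected and must be trimmed via prepare_inputs before the drafter proposes again. A runnable restatement with made-up counts:

# Request 0: 3 drafts, 2 tokens kept -> 2 rejected; request 1: no drafts;
# request 2: 2 drafts, all 3 candidates kept -> 0 rejected.
num_draft_tokens = [3, 0, 2]
valid_sampled_token_ids = [[11, 12], [7], [21, 22, 23]]
num_rejected_tokens = [
    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
print(num_rejected_tokens)  # [2, 0, 0]
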
@@ -1797,7 +1787,7 @@ def load_model(self) -> None:
         self.model = get_model(vllm_config=self.vllm_config)
         if self.drafter:
             logger.info("Loading drafter model...")
-            if self.use_aux_hidden_state_outputs:
+            if self.use_eagle:
                 self.drafter.load_model(self.model)
                 self.model.set_aux_hidden_state_layers(
                     self.model.get_eagle3_aux_hidden_state_layers())
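
The loading guard flips from use_aux_hidden_state_outputs to use_eagle, so an EAGLE-1 drafter, which produces no auxiliary hidden states, now gets loaded as well. A hedged sketch of the resulting flow; the inner aux-layer guard is an assumption (the diff's context lines call set_aux_hidden_state_layers unguarded inside the branch), and runner stands in for the model-runner instance:

import logging

logger = logging.getLogger(__name__)

def load_drafter(runner) -> None:
    # Sketch, not vllm-ascend code: load the drafter for any EAGLE variant,
    # wiring eagle3 aux hidden-state layers only when they are produced.
    if runner.drafter:
        logger.info("Loading drafter model...")
        if runner.use_eagle:
            runner.drafter.load_model(runner.model)
            if runner.use_aux_hidden_state_outputs:
                runner.model.set_aux_hidden_state_layers(
                    runner.model.get_eagle3_aux_hidden_state_layers())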
