 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
-#
+
 
 import gc
 import os
@@ -1367,73 +1367,70 @@ def _get_spec_token_ids(
             assert isinstance(self.drafter, NgramProposer)
             spec_token_ids = self._generate_draft_token_ids(
                 valid_sampled_token_ids, sampling_metadata)
-        elif self.speculative_config.method == "eagle":
-            raise NotImplementedError("Eagle Is Not Supported Yet.")
-        elif self.speculative_config.method == "eagle3":
+        elif self.use_eagle:
             assert isinstance(self.drafter, EagleProposer)
-            if self.speculative_config.use_eagle():
-                next_token_ids: list[int] = []
-                for i, token_ids in enumerate(valid_sampled_token_ids):
-                    if token_ids:
-                        # Common case.
-                        next_token_id = token_ids[-1]
-                    else:
-                        # Partial prefill (rare case).
-                        # Get the next token id from the request state.
-                        req_id = self.input_batch.req_ids[i]
-                        req_state = self.requests[req_id]
-                        seq_len = (
-                            req_state.num_computed_tokens +
-                            scheduler_output.num_scheduled_tokens[req_id])
-
-                        next_token_id = req_state.get_token_id(seq_len)
-                    next_token_ids.append(next_token_id)
-                next_token_ids = torch.tensor(next_token_ids,
-                                              dtype=torch.int32,
-                                              device=self.device)
-                eagle_attn_metadata = attn_metadata[
-                    self.drafter.attn_layer_name]
-                num_input_tokens = scheduler_output.total_num_scheduled_tokens
-                if spec_decode_metadata is None:
-                    # input_ids can be None for multimodal models.
-                    target_token_ids = self.input_ids[:num_scheduled_tokens]
-                    target_positions = positions[:num_scheduled_tokens]
-                    if self.use_aux_hidden_state_outputs:
-                        target_hidden_states = torch.cat([
-                            h[:num_scheduled_tokens] for h in aux_hidden_states
-                        ],
-                                                         dim=-1)
-                    else:
-                        target_hidden_states = hidden_states[:
-                                                             num_scheduled_tokens]
-                    target_slot_mapping = eagle_attn_metadata.slot_mapping
-                    cu_num_tokens = eagle_attn_metadata.query_start_loc
+            next_token_ids: list[int] = []
+            for i, token_ids in enumerate(valid_sampled_token_ids):
+                if token_ids:
+                    # Common case.
+                    next_token_id = token_ids[-1]
                 else:
-                    num_draft_tokens = spec_decode_metadata.num_draft_tokens
-                    num_rejected_tokens = [
-                        n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
-                        for i, n in enumerate(num_draft_tokens)
-                    ]
-                    num_rejected_tokens = torch.tensor(
-                        num_rejected_tokens,
-                        dtype=torch.int32,
-                        device=self.device,
-                    )
-                    num_tokens = num_scheduled_tokens - sum(
-                        num_rejected_tokens)
-                    cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                        eagle_attn_metadata.query_start_loc,
-                        num_rejected_tokens, num_tokens)
-                    target_token_ids = self.input_ids[token_indices]
-                    target_positions = positions[token_indices]
-                    if self.use_aux_hidden_state_outputs:
-                        target_hidden_states = torch.cat(
-                            [h[token_indices] for h in aux_hidden_states],
-                            dim=-1)
-                    else:
-                        target_hidden_states = hidden_states[token_indices]
-                    target_slot_mapping = eagle_attn_metadata.slot_mapping[
-                        token_indices]
+                    # Partial prefill (rare case).
+                    # Get the next token id from the request state.
+                    req_id = self.input_batch.req_ids[i]
+                    req_state = self.requests[req_id]
+                    seq_len = (
+                        req_state.num_computed_tokens +
+                        scheduler_output.num_scheduled_tokens[req_id])
+
+                    next_token_id = req_state.get_token_id(seq_len)
+                next_token_ids.append(next_token_id)
+            next_token_ids = torch.tensor(next_token_ids,
+                                          dtype=torch.int32,
+                                          device=self.device)
+            eagle_attn_metadata = attn_metadata[
+                self.drafter.attn_layer_name]
+            num_input_tokens = scheduler_output.total_num_scheduled_tokens
+            if spec_decode_metadata is None:
+                # input_ids can be None for multimodal models.
+                target_token_ids = self.input_ids[:num_scheduled_tokens]
+                target_positions = positions[:num_scheduled_tokens]
+                if self.use_aux_hidden_state_outputs:
+                    target_hidden_states = torch.cat([
+                        h[:num_scheduled_tokens] for h in aux_hidden_states
+                    ],
+                                                     dim=-1)
+                else:
+                    target_hidden_states = hidden_states[:
+                                                         num_scheduled_tokens]
+                target_slot_mapping = eagle_attn_metadata.slot_mapping
+                cu_num_tokens = eagle_attn_metadata.query_start_loc
+            else:
+                num_draft_tokens = spec_decode_metadata.num_draft_tokens
+                num_rejected_tokens = [
+                    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
+                    for i, n in enumerate(num_draft_tokens)
+                ]
+                num_rejected_tokens = torch.tensor(
+                    num_rejected_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+                num_tokens = num_scheduled_tokens - sum(
+                    num_rejected_tokens)
+                cu_num_tokens, token_indices = self.drafter.prepare_inputs(
+                    eagle_attn_metadata.query_start_loc,
+                    num_rejected_tokens, num_tokens)
+                target_token_ids = self.input_ids[token_indices]
+                target_positions = positions[token_indices]
+                if self.use_aux_hidden_state_outputs:
+                    target_hidden_states = torch.cat(
+                        [h[token_indices] for h in aux_hidden_states],
+                        dim=-1)
+                else:
+                    target_hidden_states = hidden_states[token_indices]
+                target_slot_mapping = eagle_attn_metadata.slot_mapping[
+                    token_indices]
 
             positions = self.positions[:num_input_tokens]
             draft_token_ids = self.drafter.propose(
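The hunk above folds the separate `"eagle"` and `"eagle3"` branches (the former only raised `NotImplementedError`) into a single `self.use_eagle` check and drops the now-redundant inner `self.speculative_config.use_eagle()` guard, de-indenting the drafting logic one level. The bookkeeping it keeps unchanged is the rejected-token count: a request with `n > 0` draft tokens was verified on `n + 1` positions (the drafts plus one bonus token), so `n + 1 - len(valid_sampled_token_ids[i])` of those positions were rejected. A minimal runnable sketch of that arithmetic, with invented batch data (the helper name is illustrative, not from this file):

```python
# Sketch of the num_rejected_tokens computation used in the diff above.
def count_rejected(num_draft_tokens: list[int],
                   valid_sampled_token_ids: list[list[int]]) -> list[int]:
    # A request with n > 0 drafts was verified on n + 1 positions
    # (n drafts + 1 bonus token); whatever was not accepted is rejected.
    return [
        n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
        for i, n in enumerate(num_draft_tokens)
    ]

# Request 0: 3 drafts, only 2 tokens accepted -> 2 rejected.
# Request 1: no speculation -> 0 rejected.
# Request 2: 2 drafts, both accepted plus the bonus token -> 0 rejected.
assert count_rejected([3, 0, 2], [[11, 12], [7], [5, 6, 9]]) == [2, 0, 0]
```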
@@ -1832,7 +1829,7 @@ def load_model(self) -> None:
         self.model = get_model(vllm_config=self.vllm_config)
         if hasattr(self, "drafter"):
             logger.info("Loading drafter model...")
-            if self.use_aux_hidden_state_outputs:
+            if self.use_eagle:
                 self.drafter.load_model(self.model)
             else:
                 self.drafter.load_model()
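In `load_model`, the drafter-loading branch now keys off the same `use_eagle` flag rather than `use_aux_hidden_state_outputs` (an EAGLE-3-specific detail), so every EAGLE-family drafter is handed the target model. A plausible reason is weight sharing: an EAGLE head typically reuses pieces of the target model such as its input embeddings. The sketch below shows that shape with illustrative stand-in classes, not vllm-ascend's actual `EagleProposer` API:

```python
from typing import Optional

import torch.nn as nn


class SketchDrafter:
    # Stand-in for a drafter with the two load paths gated above.
    def load_model(self, target_model: Optional[nn.Module] = None) -> None:
        if target_model is not None:
            # EAGLE-style path: reuse the target model's embeddings
            # instead of loading a separate copy.
            self.embed_tokens = target_model.get_input_embeddings()
        else:
            # Standalone path (e.g. an ngram drafter needs no weights).
            self.embed_tokens = None


class SketchTarget(nn.Module):
    # Stand-in for the verified target model.
    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Embedding(16, 8)

    def get_input_embeddings(self) -> nn.Module:
        return self.embed


drafter = SketchDrafter()
drafter.load_model(SketchTarget())  # use_eagle path: share target weights
assert drafter.embed_tokens is not None
```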