30
30
from vllm .logger import init_logger
31
31
from vllm .model_executor .layers .rotary_embedding import MRotaryEmbedding
32
32
from vllm .model_executor .model_loader import TensorizerLoader , get_model_loader
33
- from vllm .model_executor .models .llama_eagle3 import Eagle3LlamaForCausalLM
34
33
from vllm .multimodal import MULTIMODAL_REGISTRY
35
34
from vllm .multimodal .inputs import MultiModalKwargs , PlaceholderRange
36
35
from vllm .multimodal .utils import group_mm_inputs_by_modality
@@ -626,6 +625,7 @@ def _prepare_inputs(
626
625
self .query_start_loc_np [0 ] = 0
627
626
self .query_start_loc_np [1 :num_reqs + 1 ] = cu_num_tokens
628
627
628
+ # Prepare seq_lens and num_tokens for eagle metadata
629
629
self .seq_lens_np [:num_reqs ] = (
630
630
self .input_batch .num_computed_tokens_cpu [:num_reqs ] +
631
631
num_scheduled_tokens )
@@ -635,9 +635,12 @@ def _prepare_inputs(
635
635
]
636
636
num_tokens_np = np .array (num_tokens , dtype = np .int32 )
637
637
638
+ # Record a per-request mask of requests that should not be sampled,
639
+ # so that we can clear the sampled tokens before returning
638
640
self .discard_req_np [:num_reqs ] = \
639
641
self .seq_lens_np [:num_reqs ] < num_tokens_np
640
642
643
+ # Also record indices of requests that should be sampled
641
644
self .remaining_req_count = np .count_nonzero (
642
645
self .discard_req_np [:num_reqs ] == 0 )
643
646
self .remaining_req_indices_np [:self .remaining_req_count ] = np .nonzero (
@@ -647,13 +650,14 @@ def _prepare_inputs(
647
650
self .remaining_req_indices_cpu [:self .remaining_req_count ],
648
651
non_blocking = True )
649
652
653
+ # Precompute backup next-token ids (via get_token_id) for when there is no valid next token
650
654
self .backup_next_token_ids_np [:num_reqs ] = np .array ([
651
655
self .requests [self .input_batch .req_ids [i ]].get_token_id (
652
656
self .seq_lens_np [i ]) for i in range (num_reqs )
653
657
])
654
658
655
659
self .backup_next_token_ids [:num_reqs ].copy_ (
656
- self .backup_next_token_ids_cpu [:num_reqs ])
660
+ self .backup_next_token_ids_cpu [:num_reqs ], non_blocking = True )
657
661
658
662
# Copy the tensors to the GPU.
659
663
self .input_ids [:total_num_scheduled_tokens ].copy_ (
@@ -1418,9 +1422,11 @@ def execute_model(
1418
1422
elif self .speculative_config .use_eagle ():
1419
1423
assert isinstance (self .drafter , EagleProposer )
1420
1424
1425
+ # Get all sampled tokens from valid requests
1421
1426
valid_sampled_token_ids_gpu = sampled_token_ids [
1422
1427
self .remaining_req_indices [:self .remaining_req_count ]]
1423
1428
1429
+ # Generate a mask for all valid tokens within those requests
1424
1430
if max_gen_len == 1 :
1425
1431
valid_mask = torch .ones_like (valid_sampled_token_ids_gpu ,
1426
1432
dtype = torch .bool )
@@ -1429,19 +1435,21 @@ def execute_model(
1429
1435
(valid_sampled_token_ids_gpu
1430
1436
< self .input_batch .vocab_size ))
1431
1437
1438
+ # Count valid tokens in each request
1432
1439
valid_sampled_count = valid_mask .sum (dim = 1 )
1433
1440
1434
1441
batch = valid_sampled_token_ids_gpu .shape [0 ]
1435
1442
1436
1443
# Get the rightmost valid index per row
1437
1444
last_valid_indices = valid_sampled_count - 1
1438
1445
1439
- # Fill with -1 first (or PLACEHOLDER_ID)
1440
- # tokens selected for every row (valid or not )
1446
+ # Get the last valid token from each row
1447
+ # (the result is undefined for rows that have no valid token)
1441
1448
selected_tokens = torch .gather (
1442
1449
valid_sampled_token_ids_gpu , 1 ,
1443
1450
last_valid_indices .unsqueeze (1 )).squeeze (1 )
1444
1451
1452
+ # Use last token if valid, pre-computed backup if not
1445
1453
next_token_ids_gpu = torch .where (
1446
1454
last_valid_indices != - 1 , selected_tokens ,
1447
1455
self .backup_next_token_ids [:batch ])
@@ -1470,8 +1478,9 @@ def execute_model(
1470
1478
target_slot_mapping = eagle_attn_metadata .slot_mapping
1471
1479
cu_num_tokens = eagle_attn_metadata .query_start_loc
1472
1480
else :
1481
+ # Recompute num_draft_tokens from cumsum
1473
1482
num_draft_tokens_gpu = torch .cat ([
1474
- spec_decode_metadata .cu_num_draft_tokens [:1 ],
1483
+ spec_decode_metadata .cu_num_draft_tokens [0 :1 ],
1475
1484
spec_decode_metadata .cu_num_draft_tokens [1 :] -
1476
1485
spec_decode_metadata .cu_num_draft_tokens [:- 1 ]
1477
1486
])
@@ -1495,34 +1504,10 @@ def execute_model(
1495
1504
target_slot_mapping = eagle_attn_metadata .slot_mapping [
1496
1505
token_indices ]
1497
1506
1498
- # Moved from EagleProposer.propose() to here
1499
- if self .drafter .method == "eagle3" :
1500
- assert isinstance (self .drafter .model , Eagle3LlamaForCausalLM )
1501
- target_hidden_states = self .drafter .model .combine_hidden_states (
1502
- target_hidden_states )
1503
- assert target_hidden_states .shape [
1504
- - 1 ] == self .drafter .hidden_size
1505
-
1506
- # Shift the input ids by one token.
1507
- # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
1508
- self .drafter .input_ids [:num_scheduled_tokens -
1509
- 1 ] = target_token_ids [:
1510
- num_scheduled_tokens ][
1511
- 1 :]
1512
-
1513
- # Replace the last token with the next token.
1514
- # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
1515
- last_token_indices = cu_num_tokens [1 :] - 1
1516
- self .drafter .input_ids [last_token_indices ] = next_token_ids_gpu
1517
-
1518
- # FA requires seq_len to have dtype int32.
1519
- seq_lens = (target_positions [last_token_indices ] + 1 ).int ()
1520
-
1521
- # copy inputs to buffer for cudagraph
1522
- self .drafter .positions [:num_scheduled_tokens ] = \
1523
- target_positions [:num_scheduled_tokens ]
1524
- self .drafter .hidden_states [:num_scheduled_tokens ] = \
1525
- target_hidden_states [:num_scheduled_tokens ]
1507
+ # load token ids, positions, etc. into the eagle model
1508
+ self .drafter .load_inputs (target_token_ids , target_positions ,
1509
+ target_hidden_states , next_token_ids_gpu ,
1510
+ cu_num_tokens , num_scheduled_tokens )
1526
1511
1527
1512
if self .speculative_config and self .speculative_config .use_eagle ():
1528
1513
valid_sampled_token_ids = self .get_valid_sampled_token_ids (
@@ -1561,9 +1546,7 @@ def execute_model(
1561
1546
sampling_metadata = sampling_metadata ,
1562
1547
num_tokens = num_tokens ,
1563
1548
max_num_tokens = max_num_tokens ,
1564
- seq_lens = seq_lens ,
1565
- max_seq_len = max_seq_len ,
1566
- last_token_indices = last_token_indices )
1549
+ max_seq_len = max_seq_len )
1567
1550
spec_token_ids = draft_token_ids .tolist ()
1568
1551
1569
1552
# Clear KVConnector state after all KVs are generated.
@@ -1584,6 +1567,8 @@ def execute_model(
1584
1567
def get_valid_sampled_token_ids (
1585
1568
self , max_gen_len : int , sampled_token_ids : torch .Tensor ,
1586
1569
discard_sampled_tokens_req_indices : np .ndarray ) -> list [list [int ]]:
1570
+ # Returns valid sampled tokens in a list of lists based on
1571
+ # max gen length and discard indices
1587
1572
if max_gen_len == 1 :
1588
1573
# No spec decode tokens.
1589
1574
valid_sampled_token_ids = sampled_token_ids .tolist ()
0 commit comments