 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
+from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
@@ -1381,18 +1382,9 @@ def execute_model(
 
         if not self.speculative_config or not self.speculative_config.use_eagle(
         ):
-            if max_gen_len == 1:
-                # No spec decode tokens.
-                valid_sampled_token_ids = sampled_token_ids.tolist()
-            else:
-                # Includes spec decode tokens.
-                valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                    sampled_token_ids,
-                    self.input_batch.vocab_size,
-                )
-                # Mask out the sampled tokens that should not be sampled.
-                for i in discard_sampled_tokens_req_indices:
-                    valid_sampled_token_ids[i].clear()
+            valid_sampled_token_ids = self.get_valid_sampled_token_ids(
+                max_gen_len, sampled_token_ids,
+                discard_sampled_tokens_req_indices)
 
         if not self.speculative_config:
             # Speculative decoding is not enabled.
@@ -1426,44 +1418,32 @@ def execute_model(
             assert isinstance(self.drafter, EagleProposer)
 
             valid_sampled_token_ids_gpu = sampled_token_ids[
-                self.remaining_req_indices[:self.remaining_req_count], :]
+                self.remaining_req_indices[:self.remaining_req_count]]
 
             if max_gen_len == 1:
                 valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
                                              dtype=torch.bool)
             else:
-                # Includes speculative decode tokens - apply rejection mask
                 valid_mask = ((valid_sampled_token_ids_gpu != -1) &
                               (valid_sampled_token_ids_gpu
                                < self.input_batch.vocab_size))
 
             valid_sampled_count = valid_mask.sum(dim=1)
 
-            batch, seq_length = valid_sampled_token_ids_gpu.shape
-            device = valid_sampled_token_ids_gpu.device
-
-            # Compute positions (row-wise) of valid tokens
-            indices = torch.arange(seq_length,
-                                   device=device).expand(batch, seq_length)
-            masked_indices = torch.where(valid_mask, indices,
-                                         torch.full_like(indices, -1))
+            batch = valid_sampled_token_ids_gpu.shape[0]
 
             # Get the rightmost valid index per row
-            last_valid_indices = masked_indices.max(dim=1).values
-
-            # Get next_token_ids for common case
-            row_indices = torch.arange(batch, device=device)
-            has_valid_token = last_valid_indices != -1
+            last_valid_indices = valid_sampled_count - 1
 
             # Fill with -1 first (or PLACEHOLDER_ID)
             # tokens selected for every row (valid or not)
-            selected_tokens = valid_sampled_token_ids_gpu[row_indices,
+            selected_tokens = valid_sampled_token_ids_gpu[:batch,
                                                           last_valid_indices]
 
-            # one-liner: keep backup unless row is valid
             next_token_ids_gpu = torch.where(
-                has_valid_token, selected_tokens,
+                last_valid_indices != -1, selected_tokens,
                 self.backup_next_token_ids[:batch])
+
             # At this moment, we assume all eagle layers belong to the same KV
             # cache group, thus using the same attention metadata.
             eagle_attn_metadata = attn_metadata[
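
A quick aside on the simplification above: the removed code located the rightmost accepted token per row with an index grid plus a masked max, while the new code relies on the rejection sampler's left-packed output, where every accepted token precedes the first rejected slot, so `valid_sampled_count - 1` already is the rightmost valid column. A minimal standalone sketch of that selection with toy values (`vocab_size` and `backup` are stand-ins, and explicit row indices are used here for clarity):

```python
import torch

vocab_size = 100
# Rows are requests; -1 marks rejected/padded draft positions (left-packed).
sampled = torch.tensor([[ 5,  9, -1],
                        [ 7, -1, -1],
                        [-1, -1, -1]])
backup = torch.tensor([1, 2, 3])  # fallback next token per request

valid_mask = (sampled != -1) & (sampled < vocab_size)
valid_count = valid_mask.sum(dim=1)   # accepted tokens per row
last_valid = valid_count - 1          # rightmost valid column, -1 if none

rows = torch.arange(sampled.shape[0])
selected = sampled[rows, last_valid.clamp(min=0)]
next_tokens = torch.where(last_valid != -1, selected, backup)
print(next_tokens)  # tensor([9, 7, 3])
```

This avoids materializing the `(batch, seq_length)` index tensor and the masked max that the removed lines computed.
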
@@ -1475,8 +1455,6 @@ def execute_model(
             else:
                 block_table = None
 
-            num_rejected_tokens_np = np.zeros(len(self.input_batch.req_ids))
-
             if spec_decode_metadata is None:
                 # input_ids can be None for multimodal models.
                 target_token_ids = self.input_ids[:num_scheduled_tokens]
@@ -1489,7 +1467,6 @@ def execute_model(
                 target_hidden_states = hidden_states[:num_scheduled_tokens]
                 target_slot_mapping = eagle_attn_metadata.slot_mapping
                 cu_num_tokens = eagle_attn_metadata.query_start_loc
-                num_tokens = num_scheduled_tokens
             else:
                 num_draft_tokens_gpu = torch.cat([
                     spec_decode_metadata.cu_num_draft_tokens[:1],
@@ -1516,30 +1493,52 @@ def execute_model(
                 target_slot_mapping = eagle_attn_metadata.slot_mapping[
                     token_indices]
 
-        if max_gen_len == 1:
-            # No spec decode tokens.
-            valid_sampled_token_ids = sampled_token_ids.tolist()
-        else:
-            # Includes spec decode tokens.
-            valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids,
-                self.input_batch.vocab_size,
-            )
-            # Mask out the sampled tokens that should not be sampled.
-            for i in discard_sampled_tokens_req_indices:
-                valid_sampled_token_ids[i].clear()
+            # Moved from EagleProposer.propose() to here
+            if self.drafter.method == "eagle3":
+                assert isinstance(self.drafter.model, Eagle3LlamaForCausalLM)
+                target_hidden_states = self.drafter.model.combine_hidden_states(
+                    target_hidden_states)
+                assert target_hidden_states.shape[
+                    -1] == self.drafter.hidden_size
+
+            # Shift the input ids by one token.
+            # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
+            self.drafter.input_ids[:num_scheduled_tokens -
+                                   1] = target_token_ids[:
+                                                         num_scheduled_tokens][
+                                                             1:]
+
+            # Replace the last token with the next token.
+            # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
+            last_token_indices = cu_num_tokens[1:] - 1
+            self.drafter.input_ids[last_token_indices] = next_token_ids_gpu
+
+            # FA requires seq_len to have dtype int32.
+            seq_lens = (target_positions[last_token_indices] + 1).int()
+
+            # copy inputs to buffer for cudagraph
+            self.drafter.positions[:num_scheduled_tokens] = \
+                target_positions[:num_scheduled_tokens]
+            self.drafter.hidden_states[:num_scheduled_tokens] = \
+                target_hidden_states[:num_scheduled_tokens]
 
-        if self.speculative_config.use_eagle(
-        ) and spec_decode_metadata is not None:
-            # TODO(woosuk): Refactor this.
-            num_draft_tokens = spec_decode_metadata.num_draft_tokens
+        if self.speculative_config and self.speculative_config.use_eagle():
+            valid_sampled_token_ids = self.get_valid_sampled_token_ids(
+                max_gen_len, sampled_token_ids,
+                discard_sampled_tokens_req_indices)
 
+            if spec_decode_metadata is not None:
+                num_draft_tokens = spec_decode_metadata.num_draft_tokens
                 num_rejected_tokens_np = [
                     n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                     for i, n in enumerate(num_draft_tokens)
                 ]
+            else:
+                num_rejected_tokens_np = np.zeros(len(
+                    self.input_batch.req_ids))
 
-            num_tokens = num_scheduled_tokens - sum(num_rejected_tokens_np)
+            num_tokens = num_scheduled_tokens - int(
+                sum(num_rejected_tokens_np))
 
             max_seq_len = int(
                 (self.seq_lens_np[:num_reqs] - num_rejected_tokens_np).max())
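
For intuition, here is the shift-and-replace step added above, run standalone on the toy sequences from the comments (local tensors stand in for the drafter's preallocated buffers):

```python
import torch

# Three packed requests: A = [a1], B = [b1, b2], C = [c1, c2, c3].
target_token_ids = torch.tensor([10, 20, 21, 30, 31, 32])
cu_num_tokens = torch.tensor([0, 1, 3, 6])   # query_start_loc boundaries
next_token_ids = torch.tensor([11, 22, 33])  # freshly sampled a2, b3, c4

n = target_token_ids.numel()
input_ids = torch.empty(n, dtype=torch.long)

# Shift left by one: [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, ?]
input_ids[:n - 1] = target_token_ids[1:]

# Each request's last slot gets its sampled next token:
# -> [a2, b2, b3, c2, c3, c4]
last_token_indices = cu_num_tokens[1:] - 1
input_ids[last_token_indices] = next_token_ids
print(input_ids)  # tensor([11, 21, 22, 31, 32, 33])
```

The drafter predicts one step ahead for every position, so each request's input is its target sequence advanced by one token, with the accepted/bonus token spliced into the request's final slot.
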
@@ -1550,17 +1549,19 @@ def execute_model(
             ).max()) if spec_decode_metadata else max_seq_len
 
             draft_token_ids = self.drafter.propose(
-                target_token_ids=target_token_ids[:num_tokens],
-                target_positions=target_positions[:num_tokens],
-                target_hidden_states=target_hidden_states[:num_tokens],
-                target_slot_mapping=target_slot_mapping[:num_tokens],
+                target_token_ids=target_token_ids,
+                target_positions=target_positions,
+                target_hidden_states=target_hidden_states,
+                target_slot_mapping=target_slot_mapping,
                 next_token_ids=next_token_ids_gpu,
                 cu_num_tokens=cu_num_tokens,
                 block_table=block_table,
                 sampling_metadata=sampling_metadata,
-                max_seq_len=max_seq_len,
+                num_tokens=num_tokens,
                 max_num_tokens=max_num_tokens,
-            )
+                seq_lens=seq_lens,
+                max_seq_len=max_seq_len,
+                last_token_indices=last_token_indices)
             spec_token_ids = draft_token_ids.tolist()
 
             # Clear KVConnector state after all KVs are generated.
@@ -1578,6 +1579,24 @@ def execute_model(
                 finished_recving=finished_recving,
             )
 
+    def get_valid_sampled_token_ids(
+            self, max_gen_len: int, sampled_token_ids: torch.Tensor,
+            discard_sampled_tokens_req_indices: np.ndarray) -> list[list[int]]:
+        if max_gen_len == 1:
+            # No spec decode tokens.
+            valid_sampled_token_ids = sampled_token_ids.tolist()
+        else:
+            # Includes spec decode tokens.
+            valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                sampled_token_ids,
+                self.input_batch.vocab_size,
+            )
+            # Mask out the sampled tokens that should not be sampled.
+            for i in discard_sampled_tokens_req_indices:
+                valid_sampled_token_ids[i].clear()
+
+        return valid_sampled_token_ids
+
     def kv_connector_no_forward(
             self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
         # KV send/recv even if no work to do.
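
End to end, the new helper's contract and the rejection count it feeds can be walked through with made-up values; plain lists simulate what `rejection_sampler.parse_output()` returns (per-request lists of accepted token ids):

```python
import numpy as np

# Accepted tokens per request after rejection sampling (spec decode on).
valid_sampled_token_ids = [[9, 4], [7], [5, 6, 8]]
discard_indices = np.array([1])  # requests whose output must be dropped
num_draft_tokens = [2, 2, 2]     # drafts proposed per request

for i in discard_indices:
    valid_sampled_token_ids[i].clear()  # -> [[9, 4], [], [5, 6, 8]]

# n drafts plus the bonus token, minus whatever survived.
num_rejected = [
    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
print(num_rejected)  # [1, 3, 0]
```

A fully discarded request therefore counts all n + 1 tokens as rejected, which is what lets `num_tokens = num_scheduled_tokens - int(sum(num_rejected_tokens_np))` shrink the drafting batch accordingly.
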