
Commit 042e190

move sync point further

Leo Tian committed
Signed-off-by: Leo Tian <leo.tian@centml.ai>
1 parent 75c05ea commit 042e190
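Note on the commit message: judging from the diff below, the "sync point" is the GPU-to-CPU transfer of the sampled token ids (the .tolist() / rejection_sampler.parse_output call, now wrapped in get_valid_sampled_token_ids). The change defers that blocking transfer until after the EAGLE drafter's inputs have been prepared on the GPU, so the host stalls later in the step. A rough standalone sketch of the idea; the function names and the clamp stand-in for "GPU-side prep" are invented for illustration and are not vLLM APIs:

import torch

def step_sync_early(sampled_gpu: torch.Tensor):
    # Old ordering: .tolist() forces a device->host copy and blocks the host
    # before any drafter-input preparation has been queued on the GPU.
    valid_ids = sampled_gpu.tolist()
    drafter_inputs = sampled_gpu.clamp(min=0)  # GPU work only starts after the stall
    return valid_ids, drafter_inputs

def step_sync_late(sampled_gpu: torch.Tensor):
    # New ordering: queue the GPU-side preparation first, then sync.
    drafter_inputs = sampled_gpu.clamp(min=0)  # queued on the GPU asynchronously
    valid_ids = sampled_gpu.tolist()           # host blocks here, further down the step
    return valid_ids, drafter_inputs

if __name__ == "__main__":
    sampled = torch.tensor([[3, -1], [5, 7]])
    print(step_sync_late(sampled))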

File tree

3 files changed: +84 additions, -82 deletions

vllm/v1/spec_decode/eagle.py

Lines changed: 6 additions & 25 deletions

@@ -10,7 +10,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
-from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
                                                    FlashAttentionMetadata)
 from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -90,29 +89,14 @@ def propose(
         cu_num_tokens: torch.Tensor,
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
-        max_seq_len: int,
-        max_num_tokens: int,
         sampling_metadata: SamplingMetadata,
+        num_tokens: int,
+        max_num_tokens: int,
+        seq_lens: torch.Tensor,
+        max_seq_len: int,
+        last_token_indices: torch.Tensor,
     ) -> torch.Tensor:
-        num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
-        last_token_indices = cu_num_tokens[1:] - 1
-
-        if self.method == "eagle3":
-            assert isinstance(self.model, Eagle3LlamaForCausalLM)
-            target_hidden_states = self.model.combine_hidden_states(
-                target_hidden_states)
-            assert target_hidden_states.shape[-1] == self.hidden_size
-
-        # Shift the input ids by one token.
-        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
-        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
-        # Replace the last token with the next token.
-        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
-        self.input_ids[last_token_indices] = next_token_ids
-
-        # FA requires seq_len to have dtype int32.
-        seq_lens = (target_positions[last_token_indices] + 1).int()

         if self.method in ["eagle", "eagle3"]:
             attn_metadata = FlashAttentionMetadata(
@@ -122,7 +106,7 @@ def propose(
                 max_seq_len=max_seq_len,
                 seq_lens=seq_lens,
                 block_table=block_table,
-                slot_mapping=target_slot_mapping,
+                slot_mapping=target_slot_mapping[:num_tokens],
                 # TODO(woosuk): Support cascade attention.
                 use_cascade=False,
                 common_prefix_len=0,
@@ -157,9 +141,6 @@ def propose(
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
         else:
             num_input_tokens = num_tokens
-        # copy inputs to buffer for cudagraph
-        self.positions[:num_tokens] = target_positions
-        self.hidden_states[:num_tokens] = target_hidden_states

         with set_forward_context(per_layer_attn_metadata,
                                  self.vllm_config,
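For reference, the shift-and-replace construction of the drafter's input ids that this file drops from propose() (and that gpu_model_runner.py below now performs before calling it) can be reproduced in isolation. A minimal sketch with toy tensors, not vLLM's actual buffers:

import torch

# Three requests with 1, 2 and 3 target tokens: [a1 | b1, b2 | c1, c2, c3]
target_token_ids = torch.tensor([101, 201, 202, 301, 302, 303])
next_token_ids = torch.tensor([102, 203, 304])      # a2, b3, c4 sampled per request
cu_num_tokens = torch.tensor([0, 1, 3, 6])          # prefix sums of per-request lengths

num_tokens = target_token_ids.shape[0]
input_ids = torch.empty_like(target_token_ids)

# Shift by one token: [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, <stale>]
input_ids[:num_tokens - 1] = target_token_ids[1:]

# Replace each request's last slot with its sampled next token:
# [b1, b2, c1, c2, c3, <stale>] -> [a2, b2, b3, c2, c3, c4]
last_token_indices = cu_num_tokens[1:] - 1
input_ids[last_token_indices] = next_token_ids

print(input_ids.tolist())  # [102, 202, 203, 302, 303, 304]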

vllm/v1/worker/gpu_input_batch.py

Lines changed: 3 additions & 1 deletion

@@ -48,8 +48,10 @@ def num_tokens(self) -> int:
     def get_token_id(self, idx: int) -> int:
         if idx < self.num_prompt_tokens:
             return self.prompt_token_ids[idx]
-        else:
+        elif idx - self.num_prompt_tokens < len(self.output_token_ids):
             return self.output_token_ids[idx - self.num_prompt_tokens]
+        else:
+            return -1  # Invalid token id


 class InputBatch:
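The new elif makes the lookup total: an index past both the prompt and the tokens generated so far now yields -1 instead of raising IndexError. A standalone sketch of the same branching, with plain lists standing in for the request state (names here are illustrative, not the InputBatch API):

def get_token_id(prompt_token_ids: list[int], output_token_ids: list[int],
                 idx: int) -> int:
    # Prompt tokens come first, then the tokens generated so far.
    if idx < len(prompt_token_ids):
        return prompt_token_ids[idx]
    elif idx - len(prompt_token_ids) < len(output_token_ids):
        return output_token_ids[idx - len(prompt_token_ids)]
    else:
        return -1  # invalid / not-yet-generated position

print(get_token_id([5, 6, 7], [8, 9], 4))  # 9  (second generated token)
print(get_token_id([5, 6, 7], [8, 9], 6))  # -1 (past the known tokens)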

vllm/v1/worker/gpu_model_runner.py

Lines changed: 75 additions & 56 deletions

@@ -29,6 +29,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
+from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
@@ -1381,18 +1382,9 @@ def execute_model(

         if not self.speculative_config or not self.speculative_config.use_eagle(
         ):
-            if max_gen_len == 1:
-                # No spec decode tokens.
-                valid_sampled_token_ids = sampled_token_ids.tolist()
-            else:
-                # Includes spec decode tokens.
-                valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                    sampled_token_ids,
-                    self.input_batch.vocab_size,
-                )
-            # Mask out the sampled tokens that should not be sampled.
-            for i in discard_sampled_tokens_req_indices:
-                valid_sampled_token_ids[i].clear()
+            valid_sampled_token_ids = self.get_valid_sampled_token_ids(
+                max_gen_len, sampled_token_ids,
+                discard_sampled_tokens_req_indices)

         if not self.speculative_config:
             # Speculative decoding is not enabled.
@@ -1426,44 +1418,32 @@ def execute_model(
             assert isinstance(self.drafter, EagleProposer)

             valid_sampled_token_ids_gpu = sampled_token_ids[
-                self.remaining_req_indices[:self.remaining_req_count], :]
+                self.remaining_req_indices[:self.remaining_req_count]]

             if max_gen_len == 1:
                 valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
                                              dtype=torch.bool)
             else:
-                # Includes speculative decode tokens — apply rejection mask
                 valid_mask = ((valid_sampled_token_ids_gpu != -1) &
                               (valid_sampled_token_ids_gpu
                                < self.input_batch.vocab_size))

             valid_sampled_count = valid_mask.sum(dim=1)

-            batch, seq_length = valid_sampled_token_ids_gpu.shape
-            device = valid_sampled_token_ids_gpu.device
-
-            # Compute positions (row-wise) of valid tokens
-            indices = torch.arange(seq_length,
-                                   device=device).expand(batch, seq_length)
-            masked_indices = torch.where(valid_mask, indices,
-                                         torch.full_like(indices, -1))
+            batch = valid_sampled_token_ids_gpu.shape[0]

             # Get the rightmost valid index per row
-            last_valid_indices = masked_indices.max(dim=1).values
-
-            # Get next_token_ids for common case
-            row_indices = torch.arange(batch, device=device)
-            has_valid_token = last_valid_indices != -1
+            last_valid_indices = valid_sampled_count - 1

             # Fill with -1 first (or PLACEHOLDER_ID)
             # tokens selected for every row (valid or not)
-            selected_tokens = valid_sampled_token_ids_gpu[row_indices,
+            selected_tokens = valid_sampled_token_ids_gpu[:batch,
                                                           last_valid_indices]

-            # one-liner: keep backup unless row is valid
             next_token_ids_gpu = torch.where(
-                has_valid_token, selected_tokens,
+                last_valid_indices != -1, selected_tokens,
                 self.backup_next_token_ids[:batch])
+
             # At this moment, we assume all eagle layers belong to the same KV
             # cache group, thus using the same attention metadata.
             eagle_attn_metadata = attn_metadata[
@@ -1475,8 +1455,6 @@ def execute_model(
             else:
                 block_table = None

-            num_rejected_tokens_np = np.zeros(len(self.input_batch.req_ids))
-
             if spec_decode_metadata is None:
                 # input_ids can be None for multimodal models.
                 target_token_ids = self.input_ids[:num_scheduled_tokens]
@@ -1489,7 +1467,6 @@ def execute_model(
                 target_hidden_states = hidden_states[:num_scheduled_tokens]
                 target_slot_mapping = eagle_attn_metadata.slot_mapping
                 cu_num_tokens = eagle_attn_metadata.query_start_loc
-                num_tokens = num_scheduled_tokens
             else:
                 num_draft_tokens_gpu = torch.cat([
                     spec_decode_metadata.cu_num_draft_tokens[:1],
@@ -1516,30 +1493,52 @@ def execute_model(
                 target_slot_mapping = eagle_attn_metadata.slot_mapping[
                     token_indices]

-            if max_gen_len == 1:
-                # No spec decode tokens.
-                valid_sampled_token_ids = sampled_token_ids.tolist()
-            else:
-                # Includes spec decode tokens.
-                valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                    sampled_token_ids,
-                    self.input_batch.vocab_size,
-                )
-            # Mask out the sampled tokens that should not be sampled.
-            for i in discard_sampled_tokens_req_indices:
-                valid_sampled_token_ids[i].clear()
+            # Moved from EagleProposer.propose() to here
+            if self.drafter.method == "eagle3":
+                assert isinstance(self.drafter.model, Eagle3LlamaForCausalLM)
+                target_hidden_states = self.drafter.model.combine_hidden_states(
+                    target_hidden_states)
+                assert target_hidden_states.shape[
+                    -1] == self.drafter.hidden_size
+
+            # Shift the input ids by one token.
+            # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
+            self.drafter.input_ids[:num_scheduled_tokens -
+                                   1] = target_token_ids[:
+                                                         num_scheduled_tokens][
+                                                             1:]
+
+            # Replace the last token with the next token.
+            # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
+            last_token_indices = cu_num_tokens[1:] - 1
+            self.drafter.input_ids[last_token_indices] = next_token_ids_gpu
+
+            # FA requires seq_len to have dtype int32.
+            seq_lens = (target_positions[last_token_indices] + 1).int()
+
+            # copy inputs to buffer for cudagraph
+            self.drafter.positions[:num_scheduled_tokens] = \
+                target_positions[:num_scheduled_tokens]
+            self.drafter.hidden_states[:num_scheduled_tokens] = \
+                target_hidden_states[:num_scheduled_tokens]

-        if self.speculative_config.use_eagle(
-        ) and spec_decode_metadata is not None:
-            # TODO(woosuk): Refactor this.
-            num_draft_tokens = spec_decode_metadata.num_draft_tokens
+        if self.speculative_config and self.speculative_config.use_eagle():
+            valid_sampled_token_ids = self.get_valid_sampled_token_ids(
+                max_gen_len, sampled_token_ids,
+                discard_sampled_tokens_req_indices)

+            if spec_decode_metadata is not None:
+                num_draft_tokens = spec_decode_metadata.num_draft_tokens
                 num_rejected_tokens_np = [
                     n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                     for i, n in enumerate(num_draft_tokens)
                 ]
+            else:
+                num_rejected_tokens_np = np.zeros(len(
+                    self.input_batch.req_ids))

-            num_tokens = num_scheduled_tokens - sum(num_rejected_tokens_np)
+            num_tokens = num_scheduled_tokens - int(
+                sum(num_rejected_tokens_np))

             max_seq_len = int(
                 (self.seq_lens_np[:num_reqs] - num_rejected_tokens_np).max())
@@ -1550,17 +1549,19 @@ def execute_model(
             ).max()) if spec_decode_metadata else max_seq_len

             draft_token_ids = self.drafter.propose(
-                target_token_ids=target_token_ids[:num_tokens],
-                target_positions=target_positions[:num_tokens],
-                target_hidden_states=target_hidden_states[:num_tokens],
-                target_slot_mapping=target_slot_mapping[:num_tokens],
+                target_token_ids=target_token_ids,
+                target_positions=target_positions,
+                target_hidden_states=target_hidden_states,
+                target_slot_mapping=target_slot_mapping,
                 next_token_ids=next_token_ids_gpu,
                 cu_num_tokens=cu_num_tokens,
                 block_table=block_table,
                 sampling_metadata=sampling_metadata,
-                max_seq_len=max_seq_len,
+                num_tokens=num_tokens,
                 max_num_tokens=max_num_tokens,
-            )
+                seq_lens=seq_lens,
+                max_seq_len=max_seq_len,
+                last_token_indices=last_token_indices)
             spec_token_ids = draft_token_ids.tolist()

             # Clear KVConnector state after all KVs are generated.
@@ -1578,6 +1579,24 @@ def execute_model(
                 finished_recving=finished_recving,
             )

+    def get_valid_sampled_token_ids(
+            self, max_gen_len: int, sampled_token_ids: torch.Tensor,
+            discard_sampled_tokens_req_indices: np.ndarray) -> list[list[int]]:
+        if max_gen_len == 1:
+            # No spec decode tokens.
+            valid_sampled_token_ids = sampled_token_ids.tolist()
+        else:
+            # Includes spec decode tokens.
+            valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                sampled_token_ids,
+                self.input_batch.vocab_size,
+            )
+        # Mask out the sampled tokens that should not be sampled.
+        for i in discard_sampled_tokens_req_indices:
+            valid_sampled_token_ids[i].clear()
+
+        return valid_sampled_token_ids
+
     def kv_connector_no_forward(
             self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
         # KV send/recv even if no work to do.
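The replacement of the masked rightmost-index search with valid_sampled_count - 1 assumes that, per row, valid tokens form a contiguous prefix and rejected slots sit at the tail (as -1 or out-of-vocab placeholders). A small standalone check of that equivalence on made-up tensors, not the runner's real buffers:

import torch

PLACEHOLDER_ID = -1
vocab_size = 1000

# Two requests, up to three sampled tokens each; rejected slots hold -1.
sampled = torch.tensor([[17, 42, -1],
                        [ 7, -1, -1]])

valid_mask = (sampled != PLACEHOLDER_ID) & (sampled < vocab_size)
valid_sampled_count = valid_mask.sum(dim=1)

# Old approach: rightmost valid position per row via a masked max.
batch, seq_len = sampled.shape
indices = torch.arange(seq_len).expand(batch, seq_len)
masked = torch.where(valid_mask, indices, torch.full_like(indices, -1))
old_last = masked.max(dim=1).values

# New approach: with a valid prefix, the last valid index is simply count - 1.
new_last = valid_sampled_count - 1

assert torch.equal(old_last, new_last)
print(old_last.tolist())  # [1, 0]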
