
Commit 2822ec5

add comments and refactors

Signed-off-by: Leo Tian <leo.tian@centml.ai>

1 parent 96e95c7 · commit 2822ec5

File tree

2 files changed: +69, -43 lines


vllm/v1/spec_decode/eagle.py

Lines changed: 48 additions & 7 deletions
@@ -11,6 +11,7 @@
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
+from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
                                                     FlashAttentionMetadata)
 from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -74,6 +75,13 @@ def __init__(
                                    device=device,
                                    dtype=torch.int32)

+        self.last_token_indices = torch.zeros(self.max_num_tokens,
+                                              dtype=torch.int32,
+                                              device=device)
+        self.seq_lens = torch.zeros(self.max_num_tokens,
+                                    dtype=torch.int32,
+                                    device=device)
+
     def propose(
         self,
         # [num_tokens]
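Aside: the two new tensors above are sized to max_num_tokens once at init so their storage can be reused on every step. A minimal sketch of that preallocate-and-fill pattern, with hypothetical names and sizes (illustration only, not taken from this diff):

import torch

# Allocate once, sized for the worst case; each step writes only the prefix
# it actually needs instead of allocating a fresh tensor.
device = "cuda" if torch.cuda.is_available() else "cpu"
max_num_tokens = 8

seq_lens = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)

def load_step(step_seq_lens: torch.Tensor) -> None:
    # Write into a slice of the preallocated buffer; no per-step allocation.
    n = step_seq_lens.shape[0]
    seq_lens[:n] = step_seq_lens.to(device=device, dtype=torch.int32)

load_step(torch.tensor([3, 5, 2]))
print(seq_lens)  # tensor([3, 5, 2, 0, 0, 0, 0, 0], dtype=torch.int32)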
@@ -93,9 +101,7 @@ def propose(
         sampling_metadata: SamplingMetadata,
         num_tokens: int,
         max_num_tokens: int,
-        seq_lens: torch.Tensor,
         max_seq_len: int,
-        last_token_indices: torch.Tensor,
     ) -> torch.Tensor:
         batch_size = next_token_ids.shape[0]

@@ -105,7 +111,7 @@ def propose(
                 max_query_len=max_num_tokens,
                 query_start_loc=cu_num_tokens,
                 max_seq_len=max_seq_len,
-                seq_lens=seq_lens,
+                seq_lens=self.seq_lens,
                 block_table=block_table,
                 slot_mapping=target_slot_mapping[:num_tokens],
                 # TODO(woosuk): Support cascade attention.
@@ -117,7 +123,7 @@ def propose(
             )
         elif self.method == "deepseek_mtp":
             common_attn_metadata = CommonAttentionMetadata(
-                query_start_loc=cu_num_tokens, seq_lens=seq_lens)
+                query_start_loc=cu_num_tokens, seq_lens=self.seq_lens)

             assert self.runner is not None

@@ -155,7 +161,7 @@ def propose(
             last_hidden_states = ret_hidden_states
         else:
             last_hidden_states, hidden_states = ret_hidden_states
-        sample_hidden_states = last_hidden_states[last_token_indices]
+        sample_hidden_states = last_hidden_states[self.last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
         draft_token_ids = logits.argmax(dim=-1)

@@ -171,8 +177,8 @@ def propose(
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]

-        positions = target_positions[last_token_indices]
-        hidden_states = hidden_states[last_token_indices]
+        positions = target_positions[self.last_token_indices]
+        hidden_states = hidden_states[self.last_token_indices]
         if self.use_cuda_graph and \
                 batch_size <= self.cudagraph_batch_sizes[-1]:
             input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
@@ -290,6 +296,41 @@ def prepare_inputs(
         )
         return cu_num_tokens, token_indices

+    def load_inputs(self, target_token_ids: torch.Tensor,
+                    target_positions: torch.Tensor,
+                    target_hidden_states: torch.Tensor,
+                    next_token_ids_gpu: torch.Tensor,
+                    cu_num_tokens: torch.Tensor, num_scheduled_tokens: int):
+        # Loads token ids, positions, etc. into the eagle model.
+        # Moved from EagleProposer.propose() to here.
+        self.last_token_indices = cu_num_tokens[1:] - 1
+
+        # FA requires seq_len to have dtype int32.
+        self.seq_lens = (target_positions[self.last_token_indices] + 1).int()
+
+        if self.method == "eagle3":
+            assert isinstance(self.model, Eagle3LlamaForCausalLM)
+            target_hidden_states = self.model.combine_hidden_states(
+                target_hidden_states)
+            assert target_hidden_states.shape[-1] == self.hidden_size
+
+        # Shift the input ids by one token.
+        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
+        self.input_ids[:num_scheduled_tokens -
+                       1] = target_token_ids[:num_scheduled_tokens][1:]
+
+        # Replace the last token with the next token.
+        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
+        self.input_ids[self.last_token_indices] = next_token_ids_gpu
+
+        # Copy inputs to buffer for cudagraph.
+        self.positions[:num_scheduled_tokens] = \
+            target_positions[:num_scheduled_tokens]
+        self.hidden_states[:num_scheduled_tokens] = \
+            target_hidden_states[:num_scheduled_tokens]
+
     def load_model(self, target_model: nn.Module) -> None:
         draft_model_config = \
             self.vllm_config.speculative_config.draft_model_config
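The shift-by-one construction in load_inputs() is easiest to see on concrete numbers. A toy, self-contained sketch with made-up token ids (the real code writes into the drafter's preallocated GPU buffers):

import torch

# Three requests: a has 1 scheduled token, b has 2, c has 3.
target_token_ids = torch.tensor([11, 21, 22, 31, 32, 33])  # a1 b1 b2 c1 c2 c3
next_token_ids = torch.tensor([12, 23, 34])                # a2 b3 c4 (sampled)
cu_num_tokens = torch.tensor([0, 1, 3, 6])                 # request boundaries
num_scheduled_tokens = 6

input_ids = torch.zeros(8, dtype=torch.long)   # preallocated buffer
last_token_indices = cu_num_tokens[1:] - 1     # tensor([0, 2, 5])

# Shift left by one: [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, ?]
input_ids[:num_scheduled_tokens - 1] = target_token_ids[1:num_scheduled_tokens]
# Overwrite each request's last slot with its sampled next token:
# -> [a2, b2, b3, c2, c3, c4]
input_ids[last_token_indices] = next_token_ids
print(input_ids[:num_scheduled_tokens])  # tensor([12, 22, 23, 32, 33, 34])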

vllm/v1/worker/gpu_model_runner.py

Lines changed: 21 additions & 36 deletions
@@ -30,7 +30,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
-from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.multimodal.utils import group_mm_inputs_by_modality
@@ -626,6 +625,7 @@ def _prepare_inputs(
         self.query_start_loc_np[0] = 0
         self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens

+        # Prepare seq_len and num_token for eagle metadata.
         self.seq_lens_np[:num_reqs] = (
             self.input_batch.num_computed_tokens_cpu[:num_reqs] +
             num_scheduled_tokens)
@@ -635,9 +635,12 @@ def _prepare_inputs(
         ]
         num_tokens_np = np.array(num_tokens, dtype=np.int32)

+        # Record the indices of requests that should not be sampled,
+        # so that we can clear their sampled tokens before returning.
         self.discard_req_np[:num_reqs] = \
             self.seq_lens_np[:num_reqs] < num_tokens_np

+        # Also record the indices of requests that should be sampled.
         self.remaining_req_count = np.count_nonzero(
             self.discard_req_np[:num_reqs] == 0)
         self.remaining_req_indices_np[:self.remaining_req_count] = np.nonzero(
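For reference, a small NumPy sketch of the discard/remaining bookkeeping above, with made-up values. The interpretation that a request is skipped while its processed sequence is still shorter than its scheduled tokens (e.g. a partial/chunked prefill with no token to sample yet) is an assumption, not stated in this diff:

import numpy as np

seq_lens = np.array([7, 4, 9], dtype=np.int32)    # computed + scheduled
num_tokens = np.array([7, 6, 9], dtype=np.int32)  # tokens the request holds

discard_req = seq_lens < num_tokens                        # [False, True, False]
remaining_req_count = np.count_nonzero(discard_req == 0)   # 2
remaining_req_indices = np.nonzero(discard_req == 0)[0]    # array([0, 2])
print(remaining_req_count, remaining_req_indices)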
@@ -647,13 +650,14 @@ def _prepare_inputs(
             self.remaining_req_indices_cpu[:self.remaining_req_count],
             non_blocking=True)

+        # Precompute get_token_id for when there is no valid next token.
         self.backup_next_token_ids_np[:num_reqs] = np.array([
             self.requests[self.input_batch.req_ids[i]].get_token_id(
                 self.seq_lens_np[i]) for i in range(num_reqs)
         ])

         self.backup_next_token_ids[:num_reqs].copy_(
-            self.backup_next_token_ids_cpu[:num_reqs])
+            self.backup_next_token_ids_cpu[:num_reqs], non_blocking=True)

         # Copy the tensors to the GPU.
         self.input_ids[:total_num_scheduled_tokens].copy_(
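One note on the non_blocking=True change: a host-to-device copy only truly overlaps with other work when the source CPU tensor is in pinned (page-locked) memory, which vLLM's persistent CPU-side buffers typically are. An illustrative sketch, not taken from this diff:

import torch

if torch.cuda.is_available():
    # Pinned source tensor allows the copy to run asynchronously on the stream.
    src_cpu = torch.arange(1024, dtype=torch.int32).pin_memory()
    dst_gpu = torch.empty_like(src_cpu, device="cuda")
    dst_gpu.copy_(src_cpu, non_blocking=True)
    torch.cuda.synchronize()  # ensure the copy has finished before reading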
@@ -1418,9 +1422,11 @@ def execute_model(
         elif self.speculative_config.use_eagle():
             assert isinstance(self.drafter, EagleProposer)

+            # Get all sampled tokens from valid requests.
             valid_sampled_token_ids_gpu = sampled_token_ids[
                 self.remaining_req_indices[:self.remaining_req_count]]

+            # Generate a mask for all valid tokens within those requests.
             if max_gen_len == 1:
                 valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
                                              dtype=torch.bool)
@@ -1429,19 +1435,21 @@ def execute_model(
                 (valid_sampled_token_ids_gpu
                  < self.input_batch.vocab_size))

+            # Count the valid tokens in each request.
             valid_sampled_count = valid_mask.sum(dim=1)

             batch = valid_sampled_token_ids_gpu.shape[0]

             # Get the rightmost valid index per row.
             last_valid_indices = valid_sampled_count - 1

-            # Fill with -1 first (or PLACEHOLDER_ID)
-            # tokens selected for every row (valid or not)
+            # Get the last valid token from each row
+            # (undefined when a row has no valid token).
             selected_tokens = torch.gather(
                 valid_sampled_token_ids_gpu, 1,
                 last_valid_indices.unsqueeze(1)).squeeze(1)

+            # Use the last token if valid, the pre-computed backup if not.
             next_token_ids_gpu = torch.where(
                 last_valid_indices != -1, selected_tokens,
                 self.backup_next_token_ids[:batch])
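A toy version of the selection logic above, with hypothetical values. Since torch.gather expects in-range, non-negative indices, this sketch clamps the -1 rows to 0 and lets torch.where discard the gathered value for them:

import torch

valid_sampled_token_ids = torch.tensor([[5, 9, 0],
                                        [7, 0, 0],
                                        [0, 0, 0]])
valid_mask = torch.tensor([[True, True, False],
                           [True, False, False],
                           [False, False, False]])
backup_next_token_ids = torch.tensor([100, 101, 102])

valid_sampled_count = valid_mask.sum(dim=1)   # tensor([2, 1, 0])
last_valid_indices = valid_sampled_count - 1  # tensor([1, 0, -1])

# Clamp -1 to 0 so gather receives a valid index; rows with no valid token
# are overridden by the backup id in torch.where below.
selected = torch.gather(valid_sampled_token_ids, 1,
                        last_valid_indices.clamp(min=0).unsqueeze(1)).squeeze(1)
next_token_ids = torch.where(last_valid_indices != -1, selected,
                             backup_next_token_ids)
print(next_token_ids)  # tensor([  9,   7, 102])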
@@ -1470,8 +1478,9 @@ def execute_model(
                 target_slot_mapping = eagle_attn_metadata.slot_mapping
                 cu_num_tokens = eagle_attn_metadata.query_start_loc
             else:
+                # Recompute num_draft_tokens from the cumulative sum.
                 num_draft_tokens_gpu = torch.cat([
-                    spec_decode_metadata.cu_num_draft_tokens[:1],
+                    spec_decode_metadata.cu_num_draft_tokens[0:1],
                     spec_decode_metadata.cu_num_draft_tokens[1:] -
                     spec_decode_metadata.cu_num_draft_tokens[:-1]
                 ])
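The cumulative-sum recovery is a one-liner worth seeing on concrete numbers (made-up values):

import torch

# cu = [2, 5, 6] means the three requests drafted 2, 3 and 1 tokens.
cu_num_draft_tokens = torch.tensor([2, 5, 6])
num_draft_tokens = torch.cat([
    cu_num_draft_tokens[0:1],
    cu_num_draft_tokens[1:] - cu_num_draft_tokens[:-1],
])
print(num_draft_tokens)  # tensor([2, 3, 1])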
@@ -1495,34 +1504,10 @@ def execute_model(
                 target_slot_mapping = eagle_attn_metadata.slot_mapping[
                     token_indices]

-            # Moved from EagleProposer.propose() to here
-            if self.drafter.method == "eagle3":
-                assert isinstance(self.drafter.model, Eagle3LlamaForCausalLM)
-                target_hidden_states = self.drafter.model.combine_hidden_states(
-                    target_hidden_states)
-                assert target_hidden_states.shape[
-                    -1] == self.drafter.hidden_size
-
-            # Shift the input ids by one token.
-            # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
-            self.drafter.input_ids[:num_scheduled_tokens -
-                                   1] = target_token_ids[:num_scheduled_tokens][1:]
-
-            # Replace the last token with the next token.
-            # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
-            last_token_indices = cu_num_tokens[1:] - 1
-            self.drafter.input_ids[last_token_indices] = next_token_ids_gpu
-
-            # FA requires seq_len to have dtype int32.
-            seq_lens = (target_positions[last_token_indices] + 1).int()
-
-            # copy inputs to buffer for cudagraph
-            self.drafter.positions[:num_scheduled_tokens] = \
-                target_positions[:num_scheduled_tokens]
-            self.drafter.hidden_states[:num_scheduled_tokens] = \
-                target_hidden_states[:num_scheduled_tokens]
+            # Load token ids, positions, etc. into the eagle model.
+            self.drafter.load_inputs(target_token_ids, target_positions,
+                                     target_hidden_states, next_token_ids_gpu,
+                                     cu_num_tokens, num_scheduled_tokens)

         if self.speculative_config and self.speculative_config.use_eagle():
             valid_sampled_token_ids = self.get_valid_sampled_token_ids(
@@ -1561,9 +1546,7 @@ def execute_model(
                 sampling_metadata=sampling_metadata,
                 num_tokens=num_tokens,
                 max_num_tokens=max_num_tokens,
-                seq_lens=seq_lens,
-                max_seq_len=max_seq_len,
-                last_token_indices=last_token_indices)
+                max_seq_len=max_seq_len)
             spec_token_ids = draft_token_ids.tolist()

         # Clear KVConnector state after all KVs are generated.
@@ -1584,6 +1567,8 @@ def execute_model(
     def get_valid_sampled_token_ids(
             self, max_gen_len: int, sampled_token_ids: torch.Tensor,
             discard_sampled_tokens_req_indices: np.ndarray) -> list[list[int]]:
+        # Returns the valid sampled tokens as a list of lists, based on the
+        # max gen length and the discard indices.
         if max_gen_len == 1:
             # No spec decode tokens.
             valid_sampled_token_ids = sampled_token_ids.tolist()
