
Commit 6f67282

add docstrings
Signed-off-by: Leo Tian <leo.tian@centml.ai>
1 parent 0bceac9 commit 6f67282

File tree

5 files changed: +390 additions, -149 deletions


requirements/test.txt

Lines changed: 1 addition & 1 deletion
@@ -874,4 +874,4 @@ yarl==1.17.1
     # aiohttp
     # schemathesis
 zstandard==0.23.0
-    # via lm-eval
+    # via lm-eval

vllm/v1/spec_decode/eagle.py

Lines changed: 126 additions & 81 deletions
@@ -12,11 +12,13 @@
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.triton_utils import triton
 from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
                                                     FlashAttentionMetadata)
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel
+from vllm.v1.spec_decode.utils import (advance_state_kernel,
+                                       prepare_eagle_input_kernel)
 
 logger = init_logger(__name__)
 
@@ -75,6 +77,14 @@ def __init__(
             device=device,
             dtype=torch.int32)
 
+        # Used to store precomputed values from load_inputs() so they can be used in propose()
+        self.last_token_indices = torch.zeros(self.max_num_tokens,
+                                              dtype=torch.int32,
+                                              device=device)
+        self.seq_lens = torch.zeros(self.max_num_tokens,
+                                    dtype=torch.int32,
+                                    device=device)
+
     def propose(
         self,
         # [num_tokens]
@@ -92,40 +102,21 @@ def propose(
         # [batch_size, max_num_blocks_per_req]
         block_table: torch.Tensor,
         sampling_metadata: SamplingMetadata,
+        num_tokens: int,
+        max_num_tokens: int,
+        max_seq_len: int,
     ) -> torch.Tensor:
-        num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
-        last_token_indices = cu_num_tokens[1:] - 1
-
-        if self.method == "eagle3":
-            assert isinstance(self.model, Eagle3LlamaForCausalLM)
-            target_hidden_states = self.model.combine_hidden_states(
-                target_hidden_states)
-            assert target_hidden_states.shape[-1] == self.hidden_size
-
-        # Shift the input ids by one token.
-        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
-        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
-        # Replace the last token with the next token.
-        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
-        self.input_ids[last_token_indices] = next_token_ids
-
-        # FA requires seq_len to have dtype int32.
-        seq_lens = (target_positions[last_token_indices] + 1).int()
 
         if self.method in ["eagle", "eagle3"]:
-            # FIXME(woosuk): The below two ops cause synchronization. Optimize.
-            max_seq_len = seq_lens.max().item()
-            max_num_tokens = (cu_num_tokens[1:] -
-                              cu_num_tokens[:-1]).max().item()
             attn_metadata = FlashAttentionMetadata(
                 num_actual_tokens=num_tokens,
                 max_query_len=max_num_tokens,
                 query_start_loc=cu_num_tokens,
                 max_seq_len=max_seq_len,
-                seq_lens=seq_lens,
+                seq_lens=self.seq_lens,
                 block_table=block_table,
-                slot_mapping=target_slot_mapping,
+                slot_mapping=target_slot_mapping[:num_tokens],
                 # TODO(woosuk): Support cascade attention.
                 use_cascade=False,
                 common_prefix_len=0,
@@ -134,15 +125,12 @@ def propose(
                 suffix_kv_lens=None,
             )
         elif self.method == "deepseek_mtp":
-            query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
-            max_query_len = query_lens.max().item()
-
             common_attn_metadata = CommonAttentionMetadata(
                 query_start_loc=cu_num_tokens,
-                seq_lens=seq_lens,
+                seq_lens=self.seq_lens,
                 num_reqs=batch_size,
                 num_actual_tokens=num_tokens,
-                max_query_len=max_query_len,
+                max_query_len=self.max_num_tokens,
             )
 
         assert self.runner is not None
@@ -165,9 +153,6 @@ def propose(
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
         else:
            num_input_tokens = num_tokens
-        # copy inputs to buffer for cudagraph
-        self.positions[:num_tokens] = target_positions
-        self.hidden_states[:num_tokens] = target_hidden_states
 
         with set_forward_context(per_layer_attn_metadata,
                                  self.vllm_config,
@@ -181,7 +166,7 @@ def propose(
             last_hidden_states = ret_hidden_states
         else:
             last_hidden_states, hidden_states = ret_hidden_states
-        sample_hidden_states = last_hidden_states[last_token_indices]
+        sample_hidden_states = last_hidden_states[self.last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
         draft_token_ids = logits.argmax(dim=-1)
 
@@ -197,8 +182,8 @@ def propose(
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
 
-        positions = target_positions[last_token_indices]
-        hidden_states = hidden_states[last_token_indices]
+        positions = target_positions[self.last_token_indices]
+        hidden_states = hidden_states[self.last_token_indices]
         if self.use_cuda_graph and \
             batch_size <= self.cudagraph_batch_sizes[-1]:
             input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
@@ -208,52 +193,12 @@ def propose(
         attn_metadata.max_query_len = 1
         attn_metadata.query_start_loc = self.arange[:batch_size + 1]
         for _ in range(self.num_speculative_tokens - 1):
-            # Update the inputs.
-            # cast to int32 is crucial when eagle model is compiled.
-            # tensor.argmax() returns int64 by default.
-            input_ids = draft_token_ids_list[-1].int()
-            positions += 1
-
-            # NOTE(woosuk): We should handle the case where the draft model
-            # generates tokens beyond the max model length. Since it is complex
-            # to remove such requests from the batch, we keep them in the batch
-            # but adjust the position ids and slot mappings to avoid the
-            # out-of-range access during the model execution. The draft tokens
-            # generated with this adjustment should be ignored.
-            exceeds_max_model_len = positions >= self.max_model_len
-            # Mask out the position ids that exceed the max model length.
-            # Otherwise, we may get out-of-range error in RoPE.
-            clamped_positions = torch.where(exceeds_max_model_len, 0,
-                                            positions)
-
-            # Increment the sequence lengths.
-            attn_metadata.max_seq_len += 1
-            attn_metadata.seq_lens += 1
-            # Consider max model length.
-            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
-                                            self.max_model_len)
-            # For the requests that exceed the max model length, we set the
-            # sequence length to 1 to minimize their overheads in attention.
-            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
-
-            # Compute the slot mapping.
-            block_numbers = clamped_positions // self.block_size
-            block_ids = block_table.gather(dim=1,
-                                           index=block_numbers.view(-1, 1))
-            block_ids = block_ids.view(-1)
-            attn_metadata.slot_mapping = (block_ids * self.block_size +
-                                          clamped_positions % self.block_size)
-            # Mask out the slot mappings that exceed the max model length.
-            # Otherwise, the KV cache will be inadvertently updated with the
-            # padding tokens.
-            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
-                                                    PADDING_SLOT_ID)
 
-            # copy inputs to buffer for cudagraph
-            self.input_ids[:batch_size] = input_ids
-            self.positions[:batch_size] = clamped_positions
-            self.hidden_states[:batch_size] = hidden_states
+            self.advance_speculative_state(draft_token_ids_list[-1], positions,
+                                           hidden_states, attn_metadata,
+                                           batch_size)
 
+            # copy inputs to buffer for cudagraph
             # Run the model.
             with set_forward_context(per_layer_attn_metadata,
                                      self.vllm_config,
@@ -275,6 +220,58 @@ def propose(
         draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
         return draft_token_ids
 
+    def advance_speculative_state(self, draft_token_ids: torch.Tensor,
+                                  positions: torch.Tensor,
+                                  hidden_states: torch.Tensor,
+                                  attn_metadata: FlashAttentionMetadata,
+                                  batch_size: int):
+        """
+        Advances the speculative decoding state and metadata by one step
+
+        Parameters:
+        ----------
+        draft_token_ids (torch.Tensor): Token IDs generated by the draft model
+        positions (torch.Tensor): Position indices for the draft tokens
+        hidden_states (torch.Tensor): Corresponding hidden states for the tokens
+        attn_metadata (FlashAttentionMetadata): Metadata required for FlashAttention (e.g., sequence lengths, block table).
+        batch_size (int): Number of sequences to update.
+        """
+
+        # Calculate number of thread blocks
+        grid = lambda meta: (triton.cdiv(batch_size, meta['BLOCK_SIZE']), )
+        attn_metadata.slot_mapping = torch.empty_like(positions)
+        advance_state_kernel[grid](
+            # === Input tensors ===
+            draft_token_ids,
+            positions,
+
+            # === Model input buffers to be updated ===
+            self.input_ids[:batch_size],
+            self.positions[:batch_size],
+
+            # === Metadata tensors ===
+            attn_metadata.seq_lens,
+            attn_metadata.block_table,
+            attn_metadata.slot_mapping,
+
+            # === Scalar configuration ===
+            self.max_model_len,
+            self.block_size,
+            self.max_model_len // self.block_size,
+
+            # === Execution control ===
+            batch_size,
+            BLOCK_SIZE=1024,
+            PADDING_SLOT_ID=PADDING_SLOT_ID)
+
+        self.hidden_states[:batch_size] = hidden_states
+
+        # Increment the sequence lengths.
+        attn_metadata.max_seq_len += 1
+        # Consider max model length.
+        attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
+                                        self.max_model_len)
+
     @staticmethod
     def prepare_inputs(
         # [batch_size + 1]
@@ -301,7 +298,7 @@ def prepare_inputs(
         # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
         cu_num_tokens = torch.zeros_like(cu_target_query_lens)
         torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
-        token_indices = torch.empty(
+        token_indices = torch.zeros(
            num_tokens,
            dtype=torch.int32,
            device=cu_target_query_lens.device,
@@ -316,6 +313,54 @@ def prepare_inputs(
         )
         return cu_num_tokens, token_indices
 
+    def load_inputs(self, target_token_ids: torch.Tensor,
+                    target_positions: torch.Tensor,
+                    target_hidden_states: torch.Tensor,
+                    next_token_ids_gpu: torch.Tensor,
+                    cu_num_tokens: torch.Tensor, num_scheduled_tokens: int):
+        """
+        Loads token ids, positions, etc. into the eagle model
+
+        Logic moved from EagleProposer.propose() to here
+
+        Parameters:
+        ----------
+        target_token_ids (torch.Tensor): Draft-step token IDs
+        target_positions (torch.Tensor): Position indices for the tokens
+        target_hidden_states (torch.Tensor): Corresponding hidden states for the tokens
+        next_token_ids_gpu (torch.Tensor): Sampled next token IDs to overwrite final token
+        cu_num_tokens (torch.Tensor): Cumulative number of tokens from prepare_inputs()
+        num_scheduled_tokens (int): Total number of tokens scheduled
+        """
+
+        self.last_token_indices = cu_num_tokens[1:] - 1
+
+        # FA requires seq_len to have dtype int32.
+        self.seq_lens = (target_positions[self.last_token_indices] + 1).int()
+
+        if self.method == "eagle3":
+            assert isinstance(self.model, Eagle3LlamaForCausalLM)
+            target_hidden_states = self.model.combine_hidden_states(
+                target_hidden_states)
+            assert target_hidden_states.shape[-1] == self.hidden_size
+
+        # Shift the input ids by one token.
+        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
+        self.input_ids[:num_scheduled_tokens -
+                       1] = target_token_ids[:num_scheduled_tokens][1:]
+
+        # Replace the last token with the next token.
+        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
+        self.input_ids[self.last_token_indices] = next_token_ids_gpu
+
+        # copy inputs to buffer for cudagraph
+        self.positions[:
+                       num_scheduled_tokens] = target_positions[:
+                                                                num_scheduled_tokens]
+        self.hidden_states[:
+                           num_scheduled_tokens] = target_hidden_states[:
+                                                                        num_scheduled_tokens]
+
     def load_model(self, target_model: nn.Module) -> None:
         draft_model_config = \
             self.vllm_config.speculative_config.draft_model_config
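
Taken together, the eagle.py changes split what used to happen inside a single propose() call into prepare_inputs() / load_inputs() / propose(), with num_tokens, max_num_tokens, and max_seq_len now passed in as plain ints rather than derived from GPU tensors (removing the synchronizing .max().item() calls). Below is a minimal caller-side sketch of that flow; the runner changes that actually supply these values live in other files of this commit, so the wrapper name draft_step and the way its arguments are obtained are assumptions, not the committed API.

def draft_step(proposer, target_token_ids, target_positions,
               target_hidden_states, target_slot_mapping, next_token_ids,
               cu_num_tokens, block_table, sampling_metadata,
               num_scheduled_tokens, max_num_tokens, max_seq_len):
    # Stage input_ids / positions / hidden_states into the proposer's
    # persistent CUDA-graph buffers; this also records last_token_indices
    # and seq_lens so propose() can reuse them without recomputing.
    proposer.load_inputs(target_token_ids, target_positions,
                         target_hidden_states, next_token_ids,
                         cu_num_tokens, num_scheduled_tokens)

    # The token counts and max sequence length arrive as Python ints
    # computed on the host, so propose() performs no GPU->CPU sync here.
    return proposer.propose(
        target_token_ids=target_token_ids,
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        target_slot_mapping=target_slot_mapping,
        next_token_ids=next_token_ids,
        cu_num_tokens=cu_num_tokens,
        block_table=block_table,
        sampling_metadata=sampling_metadata,
        num_tokens=num_scheduled_tokens,
        max_num_tokens=max_num_tokens,
        max_seq_len=max_seq_len)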

vllm/v1/spec_decode/utils.py

Lines changed: 79 additions & 0 deletions
@@ -44,3 +44,82 @@ def prepare_eagle_input_kernel(
         index_start + offset,
         mask=offset < num_tokens,
     )
+
+
+@triton.jit
+def advance_state_kernel(
+    draft_token_ids_ptr,
+    positions_ptr,
+
+    # === Model input buffers to be updated ===
+    model_input_ids_ptr,
+    model_positions_ptr,
+
+    # === Metadata tensors ===
+    seq_lens_ptr,
+    block_table_ptr,
+    slot_mapping_ptr,
+
+    # === Scalar configuration ===
+    model_max_len: int,
+    model_block_size: int,
+    model_block_stride: int,
+
+    # === Execution control ===
+    n_elements: int,
+    BLOCK_SIZE: tl.constexpr,
+    PADDING_SLOT_ID: tl.constexpr,
+):
+    # Triton kernel to perform draft model state advancement.
+
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    draft_token_list_last = tl.load(draft_token_ids_ptr + offsets, mask=mask)
+    position = tl.load(positions_ptr + offsets, mask=mask)
+    seq_lens = tl.load(seq_lens_ptr + offsets, mask=mask)
+
+    # Update the inputs.
+    # cast to int32 is crucial when eagle model is compiled.
+    # tensor.argmax() returns int64 by default.
+    input_id = draft_token_list_last.cast(tl.int32)
+    position = position + 1
+
+    # NOTE(woosuk): We should handle the case where the draft model
+    # generates tokens beyond the max model length. Since it is complex
+    # to remove such requests from the batch, we keep them in the batch
+    # but adjust the position ids and slot mappings to avoid the
+    # out-of-range access during the model execution. The draft tokens
+    # generated with this adjustment should be ignored.
+    exceeds_max_model_len = position >= model_max_len
+    # Mask out the position ids that exceed the max model length.
+    # Otherwise, we may get out-of-range error in RoPE.
+    clamped_position = tl.where(exceeds_max_model_len, 0, position)
+
+    # For the requests that exceed the max model length, we set the
+    # sequence length to 1 to minimize their overheads in attention.
+    seq_lens += 1
+    seq_lens = tl.where(exceeds_max_model_len, 1, seq_lens)
+
+    block_numbers = clamped_position // model_block_size
+    block_offsets = clamped_position % model_block_size
+
+    block_ids = tl.load(block_table_ptr + model_block_stride * offsets +
+                        block_numbers,
+                        mask=mask)
+
+    # Compute slot mapping
+    slot_mapping = block_ids * model_block_size + block_offsets
+
+    # Mask out the slot mappings that exceed the max model length.
+    # Otherwise, the KV cache will be inadvertently updated with the
+    # padding tokens.
+    slot_mapping = tl.where(exceeds_max_model_len, PADDING_SLOT_ID,
+                            slot_mapping)
+
+    tl.store(model_input_ids_ptr + offsets, input_id, mask=mask)
+    tl.store(positions_ptr + offsets, position, mask=mask)
+    tl.store(model_positions_ptr + offsets, clamped_position, mask=mask)
+    tl.store(seq_lens_ptr + offsets, seq_lens, mask=mask)
+    tl.store(slot_mapping_ptr + offsets, slot_mapping, mask=mask)
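
For reference, the kernel fuses the same per-request update that the eager PyTorch code removed from EagleProposer.propose() used to perform. The sketch below re-expresses one advancement step with plain tensor ops, which can be handy for sanity-checking the kernel on small inputs; the function name is illustrative, and PADDING_SLOT_ID is assumed here to be a negative sentinel rather than vLLM's actual constant.

import torch

PADDING_SLOT_ID = -1  # assumed sentinel for this sketch


def advance_state_reference(draft_token_ids, positions, seq_lens,
                            block_table, max_model_len, block_size):
    # input_ids: last sampled draft tokens, cast to int32 for the model.
    input_ids = draft_token_ids.int()
    # positions: unclamped positions advanced by one (what positions_ptr
    # is updated to); clamped is what model_positions_ptr receives.
    positions = positions + 1

    # Requests that run past the max model length keep a dummy position 0,
    # a seq_len of 1, and a padding slot so the KV cache is not corrupted.
    exceeds = positions >= max_model_len
    clamped = torch.where(exceeds, torch.zeros_like(positions), positions)
    seq_lens = torch.where(exceeds, torch.ones_like(seq_lens), seq_lens + 1)

    # Slot mapping: physical block id times block size plus in-block offset.
    block_numbers = clamped // block_size
    block_ids = block_table.gather(dim=1,
                                   index=block_numbers.view(-1, 1)).view(-1)
    slot_mapping = block_ids * block_size + clamped % block_size
    slot_mapping = torch.where(exceeds,
                               torch.full_like(slot_mapping, PADDING_SLOT_ID),
                               slot_mapping)
    return input_ids, positions, clamped, seq_lens, slot_mapping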
