Commit c67eb72 (parent: b2070a4)

fix indexing

Signed-off-by: Leo Tian <leo.tian@centml.ai>

File tree: 1 file changed (+7, -8 lines)

vllm/v1/spec_decode/eagle.py

@@ -204,12 +204,10 @@ def advance_speculative_state(self, draft_token_ids: torch.Tensor,
             # === Input tensors ===
             draft_token_ids,
             positions,
-            hidden_states,
 
             # === Model input buffers to be updated ===
             self.input_ids[:batch_size],
             self.positions[:batch_size],
-            self.hidden_states[:batch_size],
 
             # === Metadata tensors ===
             attn_metadata.seq_lens,
@@ -219,12 +217,15 @@ def advance_speculative_state(self, draft_token_ids: torch.Tensor,
             # === Scalar configuration ===
             self.max_model_len,
             self.block_size,
+            self.max_model_len // self.block_size,
 
             # === Execution control ===
             batch_size,
             BLOCK_SIZE=1024,
             PADDING_SLOT_ID=PADDING_SLOT_ID)
 
+        self.hidden_states[:batch_size] = hidden_states
+
         # Increment the sequence lengths.
         attn_metadata.max_seq_len += 1
         # Consider max model length.
@@ -419,12 +420,10 @@ def prepare_input_kernel(
 def advance_state_kernel(
         draft_token_ids_ptr,
         positions_ptr,
-        hidden_states_ptr,
 
         # === Model input buffers to be updated ===
         model_input_ids_ptr,
         model_positions_ptr,
-        model_hidden_states_ptr,
 
         # === Metadata tensors ===
         seq_lens_ptr,
@@ -434,6 +433,7 @@ def advance_state_kernel(
         # === Scalar configuration ===
         model_max_len: int,
         model_block_size: int,
+        model_block_stride: int,
 
         # === Execution control ===
         n_elements: int,
@@ -447,7 +447,6 @@ def advance_state_kernel(
     draft_token_list_last = tl.load(draft_token_ids_ptr + offsets, mask=mask)
     position = tl.load(positions_ptr + offsets, mask=mask)
    seq_lens = tl.load(seq_lens_ptr + offsets, mask=mask)
-    hidden_states = tl.load(hidden_states_ptr + offsets, mask=mask)
 
     # Update the inputs.
     # cast to int32 is crucial when eagle model is compiled.
@@ -474,8 +473,9 @@ def advance_state_kernel(
     block_numbers = clamped_position // model_block_size
     block_offsets = clamped_position % model_block_size
 
-    # Gather from block_table[0, block_numbers]
-    block_ids = tl.load(block_table_ptr + block_numbers, mask=mask)
+    block_ids = tl.load(block_table_ptr + model_block_stride * offsets +
+                        block_numbers,
+                        mask=mask)
 
     # Compute slot mapping
     slot_mapping = block_ids * model_block_size + block_offsets
@@ -491,4 +491,3 @@ def advance_state_kernel(
     tl.store(model_positions_ptr + offsets, clamped_position, mask=mask)
     tl.store(seq_lens_ptr + offsets, seq_lens, mask=mask)
     tl.store(slot_mapping_ptr + offsets, slot_mapping, mask=mask)
-    tl.store(model_hidden_states_ptr + offsets, hidden_states, mask=mask)
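Context for the fix: the removed comment shows the old kernel gathered from block_table[0, block_numbers], i.e. row 0 of the block table for every request. The block table is conceptually 2-D with one row per request, so the corrected load adds a per-request row offset, model_block_stride * offsets, where the stride is max_model_len // block_size entries per row. A minimal sketch of the corrected address arithmetic in plain Python, with hypothetical toy values (names mirror the kernel arguments; none of the numbers are from the repo):

max_model_len = 32
block_size = 8
model_block_stride = max_model_len // block_size  # 4 block-table entries per request row

# Flattened row-major block table for two requests.
block_table = [10, 11, 12, 13,   # request 0
               20, 21, 22, 23]   # request 1

positions = [9, 17]  # clamped position of the next token for each request

for i, pos in enumerate(positions):
    block_number = pos // block_size   # which block within the request
    block_offset = pos % block_size    # offset inside that block
    # Old: block_table[block_number] always read row 0.
    # Fixed: add the per-request row offset, as in the kernel.
    block_id = block_table[model_block_stride * i + block_number]
    slot_mapping = block_id * block_size + block_offset
    print(f"req {i}: block_id={block_id}, slot={slot_mapping}")

Without the stride term, request 1 above would read block_id 12 from request 0's row instead of 22 from its own.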

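The hidden-state copy moves out of the kernel for a related reason: one reading of the removed lines is that tl.store(model_hidden_states_ptr + offsets, ...) wrote a single scalar per request index, while the hidden-state buffer holds a full vector per token, so the per-element kernel could not copy whole rows. The commit performs the copy as a plain slice assignment on the host after the kernel call. A short PyTorch sketch under assumed toy shapes (hidden_size and max_num_tokens are hypothetical):

import torch

max_num_tokens, hidden_size = 16, 8
batch_size = 4

hidden_states = torch.randn(batch_size, hidden_size)  # new per-token vectors
buffer = torch.zeros(max_num_tokens, hidden_size)     # persistent model input buffer

# A store indexed only by `offsets` touches one scalar per request,
# i.e. at most batch_size elements of the flattened buffer.
# Copying whole rows on the host handles the full hidden dimension:
buffer[:batch_size] = hidden_states

The slice assignment mirrors self.hidden_states[:batch_size] = hidden_states from the diff.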