@@ -204,12 +204,10 @@ def advance_speculative_state(self, draft_token_ids: torch.Tensor,
204
204
# === Input tensors ===
205
205
draft_token_ids ,
206
206
positions ,
207
- hidden_states ,
208
207
209
208
# === Model input buffers to be updated ===
210
209
self .input_ids [:batch_size ],
211
210
self .positions [:batch_size ],
212
- self .hidden_states [:batch_size ],
213
211
214
212
# === Metadata tensors ===
215
213
attn_metadata .seq_lens ,
@@ -219,12 +217,15 @@ def advance_speculative_state(self, draft_token_ids: torch.Tensor,
219
217
# === Scalar configuration ===
220
218
self .max_model_len ,
221
219
self .block_size ,
220
+ self .max_model_len // self .block_size ,
222
221
223
222
# === Execution control ===
224
223
batch_size ,
225
224
BLOCK_SIZE = 1024 ,
226
225
PADDING_SLOT_ID = PADDING_SLOT_ID )
227
226
227
+ self .hidden_states [:batch_size ] = hidden_states
228
+
228
229
# Increment the sequence lengths.
229
230
attn_metadata .max_seq_len += 1
230
231
# Consider max model length.
@@ -419,12 +420,10 @@ def prepare_input_kernel(
419
420
def advance_state_kernel (
420
421
draft_token_ids_ptr ,
421
422
positions_ptr ,
422
- hidden_states_ptr ,
423
423
424
424
# === Model input buffers to be updated ===
425
425
model_input_ids_ptr ,
426
426
model_positions_ptr ,
427
- model_hidden_states_ptr ,
428
427
429
428
# === Metadata tensors ===
430
429
seq_lens_ptr ,
@@ -434,6 +433,7 @@ def advance_state_kernel(
434
433
# === Scalar configuration ===
435
434
model_max_len : int ,
436
435
model_block_size : int ,
436
+ model_block_stride : int ,
437
437
438
438
# === Execution control ===
439
439
n_elements : int ,
@@ -447,7 +447,6 @@ def advance_state_kernel(
447
447
draft_token_list_last = tl .load (draft_token_ids_ptr + offsets , mask = mask )
448
448
position = tl .load (positions_ptr + offsets , mask = mask )
449
449
seq_lens = tl .load (seq_lens_ptr + offsets , mask = mask )
450
- hidden_states = tl .load (hidden_states_ptr + offsets , mask = mask )
451
450
452
451
# Update the inputs.
453
452
# Casting to int32 is crucial when the EAGLE model is compiled.
@@ -474,8 +473,9 @@ def advance_state_kernel(
474
473
block_numbers = clamped_position // model_block_size
475
474
block_offsets = clamped_position % model_block_size
476
475
477
- # Gather from block_table[0, block_numbers]
478
- block_ids = tl .load (block_table_ptr + block_numbers , mask = mask )
476
+ block_ids = tl .load (block_table_ptr + model_block_stride * offsets +
477
+ block_numbers ,
478
+ mask = mask )
479
479
480
480
# Compute slot mapping
481
481
slot_mapping = block_ids * model_block_size + block_offsets
@@ -491,4 +491,3 @@ def advance_state_kernel(
491
491
tl .store (model_positions_ptr + offsets , clamped_position , mask = mask )
492
492
tl .store (seq_lens_ptr + offsets , seq_lens , mask = mask )
493
493
tl .store (slot_mapping_ptr + offsets , slot_mapping , mask = mask )
494
- tl .store (model_hidden_states_ptr + offsets , hidden_states , mask = mask )
0 commit comments