@@ -169,8 +169,8 @@ class FlashInferMetadata:
     # [0, 3, 6, 8]
     # The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
     paged_kv_indptr_cpu: torch.Tensor
-    # The page indices of the paged kv cache (CPU for plan)
-    paged_kv_indices_cpu: torch.Tensor
+    # The page indices of the paged kv cache (on device for plan)
+    paged_kv_indices: torch.Tensor
     # The number of entries in the last page of each request in
     # the paged kv cache, shape: [batch_size] (CPU for plan)
     paged_kv_last_page_len_cpu: torch.Tensor
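A minimal sketch (all values and PAGE_SIZE are made up for illustration) of how these three fields describe the paged KV cache, following the [0, 3, 6, 8] indptr example in the comments above; after this change only paged_kv_indices lives on the device, while the indptr and last-page lengths stay on CPU for plan():

import torch

PAGE_SIZE = 16                                           # assumed page size

# Request i uses pages paged_kv_indices[indptr[i]:indptr[i + 1]].
paged_kv_indptr_cpu = torch.tensor([0, 3, 6, 8])         # CPU, consumed by plan()
paged_kv_indices = torch.tensor([5, 1, 9, 2, 7, 4, 0, 3])  # a device tensor in the real metadata
paged_kv_last_page_len_cpu = torch.tensor([16, 4, 11])   # fill level of each request's last page

# Tokens held by request 1: two full pages plus a partially filled last page.
pages_req1 = paged_kv_indices[paged_kv_indptr_cpu[1]:paged_kv_indptr_cpu[2]]
num_tokens_req1 = (len(pages_req1) - 1) * PAGE_SIZE + paged_kv_last_page_len_cpu[1]
# -> 2 * 16 + 4 = 36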
@@ -292,7 +292,7 @@ def _plan(self, num_prefills: int, num_decodes: int,
         # Ensure CPU tensors are not None
         assert attn_metadata.qo_indptr_cpu is not None
         assert attn_metadata.paged_kv_indptr_cpu is not None
-        assert attn_metadata.paged_kv_indices_cpu is not None
+        assert attn_metadata.paged_kv_indices is not None
         assert attn_metadata.paged_kv_last_page_len_cpu is not None
 
         if attn_metadata.use_cascade:
@@ -308,7 +308,7 @@ def _plan(self, num_prefills: int, num_decodes: int,
                 ],
                 [
                     attn_metadata.shared_kv_page_indices_cpu,
-                    attn_metadata.paged_kv_indices_cpu
+                    attn_metadata.paged_kv_indices
                 ],
                 [
                     attn_metadata.shared_kv_last_page_len_cpu,
@@ -347,7 +347,7 @@ def _plan(self, num_prefills: int, num_decodes: int,
             attn_metadata.prefill_wrapper.plan(
                 qo_indptr_cpu,
                 attn_metadata.paged_kv_indptr_cpu[prefill_start:],
-                attn_metadata.paged_kv_indices_cpu,
+                attn_metadata.paged_kv_indices,
                 attn_metadata.paged_kv_last_page_len_cpu[prefill_start:],
                 attn_metadata.num_qo_heads,
                 attn_metadata.num_kv_heads,
@@ -370,7 +370,7 @@ def _plan(self, num_prefills: int, num_decodes: int,
                     attn_metadata.num_kv_heads, attn_metadata.head_dim):
                 attn_metadata.decode_wrapper.plan(
                     attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
-                    attn_metadata.paged_kv_indices_cpu,
+                    attn_metadata.paged_kv_indices,
                     attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
                     attn_metadata.num_qo_heads,
                     attn_metadata.num_kv_heads,
@@ -438,7 +438,7 @@ def build(self,
                                  dtype=torch.int32,
                                  device='cpu').unsqueeze(0)
                     < block_table_bounds_cpu.unsqueeze(1))
-        paged_kv_indices_cpu = block_table_tensor.cpu()[mask_cpu]
+        paged_kv_indices = block_table_tensor[mask_cpu]
 
         # paged_kv_indptr_cpu: cumulative sum of block_table_bounds_cpu
         paged_kv_indptr_cpu = torch.cat([
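A minimal sketch of the index construction in the build() hunk above, using a made-up padded block table; block_table_tensor, block_table_bounds_cpu, and mask_cpu are the names from the diff, and the only behavioural change here is that the boolean-mask indexing now runs on the block table's own device instead of a .cpu() copy:

import torch

block_table_tensor = torch.tensor([[10, 11, 12, 0],
                                   [20, 21, 0, 0],
                                   [30, 31, 32, 33]])   # padded block table, [batch, max_blocks]
block_table_bounds_cpu = torch.tensor([3, 2, 4])        # valid blocks per request

# Row-wise mask: column positions below each request's block count.
mask_cpu = (torch.arange(block_table_tensor.shape[1],
                         dtype=torch.int32,
                         device='cpu').unsqueeze(0)
            < block_table_bounds_cpu.unsqueeze(1))

# Flattened per-request page indices (device tensor in the real build()).
paged_kv_indices = block_table_tensor[mask_cpu]
# tensor([10, 11, 12, 20, 21, 30, 31, 32, 33])

# CSR-style offsets into paged_kv_indices: leading 0 plus cumulative block counts.
paged_kv_indptr_cpu = torch.cat([
    torch.zeros(1, dtype=block_table_bounds_cpu.dtype),
    block_table_bounds_cpu.cumsum(dim=0)
])
# tensor([0, 3, 5, 9])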
@@ -462,7 +462,7 @@ def build(self,
             qo_indptr=qo_indptr,
             qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
             paged_kv_indptr_cpu=paged_kv_indptr_cpu,
-            paged_kv_indices_cpu=paged_kv_indices_cpu,
+            paged_kv_indices=paged_kv_indices,
             paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
             num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
                 self.vllm_config.parallel_config),