Skip to content

Commit 19b2d52

Browse files
dont transfer block table
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent d79aed6 commit 19b2d52

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

vllm/v1/attention/backends/flashinfer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ class FlashInferMetadata:
169169
# [0, 3, 6, 8]
170170
# The indptr of the paged kv cache, shape: [batch_size + 1] (CPU for plan)
171171
paged_kv_indptr_cpu: torch.Tensor
172-
# The page indices of the paged kv cache (CPU for plan)
173-
paged_kv_indices_cpu: torch.Tensor
172+
# The page indices of the paged kv cache (on device for plan)
173+
paged_kv_indices: torch.Tensor
174174
# The number of entries in the last page of each request in
175175
# the paged kv cache, shape: [batch_size] (CPU for plan)
176176
paged_kv_last_page_len_cpu: torch.Tensor
@@ -292,7 +292,7 @@ def _plan(self, num_prefills: int, num_decodes: int,
292292
# Ensure CPU tensors are not None
293293
assert attn_metadata.qo_indptr_cpu is not None
294294
assert attn_metadata.paged_kv_indptr_cpu is not None
295-
assert attn_metadata.paged_kv_indices_cpu is not None
295+
assert attn_metadata.paged_kv_indices is not None
296296
assert attn_metadata.paged_kv_last_page_len_cpu is not None
297297

298298
if attn_metadata.use_cascade:
@@ -308,7 +308,7 @@ def _plan(self, num_prefills: int, num_decodes: int,
308308
],
309309
[
310310
attn_metadata.shared_kv_page_indices_cpu,
311-
attn_metadata.paged_kv_indices_cpu
311+
attn_metadata.paged_kv_indices
312312
],
313313
[
314314
attn_metadata.shared_kv_last_page_len_cpu,
@@ -347,7 +347,7 @@ def _plan(self, num_prefills: int, num_decodes: int,
347347
attn_metadata.prefill_wrapper.plan(
348348
qo_indptr_cpu,
349349
attn_metadata.paged_kv_indptr_cpu[prefill_start:],
350-
attn_metadata.paged_kv_indices_cpu,
350+
attn_metadata.paged_kv_indices,
351351
attn_metadata.paged_kv_last_page_len_cpu[prefill_start:],
352352
attn_metadata.num_qo_heads,
353353
attn_metadata.num_kv_heads,
@@ -370,7 +370,7 @@ def _plan(self, num_prefills: int, num_decodes: int,
370370
attn_metadata.num_kv_heads, attn_metadata.head_dim):
371371
attn_metadata.decode_wrapper.plan(
372372
attn_metadata.paged_kv_indptr_cpu[:num_decodes + 1],
373-
attn_metadata.paged_kv_indices_cpu,
373+
attn_metadata.paged_kv_indices,
374374
attn_metadata.paged_kv_last_page_len_cpu[:num_decodes],
375375
attn_metadata.num_qo_heads,
376376
attn_metadata.num_kv_heads,
@@ -438,7 +438,7 @@ def build(self,
438438
dtype=torch.int32,
439439
device='cpu').unsqueeze(0)
440440
< block_table_bounds_cpu.unsqueeze(1))
441-
paged_kv_indices_cpu = block_table_tensor.cpu()[mask_cpu]
441+
paged_kv_indices = block_table_tensor[mask_cpu]
442442

443443
# paged_kv_indptr_cpu: cumulative sum of block_table_bounds_cpu
444444
paged_kv_indptr_cpu = torch.cat([
@@ -462,7 +462,7 @@ def build(self,
462462
qo_indptr=qo_indptr,
463463
qo_indptr_cpu=common_attn_metadata.query_start_loc_cpu,
464464
paged_kv_indptr_cpu=paged_kv_indptr_cpu,
465-
paged_kv_indices_cpu=paged_kv_indices_cpu,
465+
paged_kv_indices=paged_kv_indices,
466466
paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
467467
num_qo_heads=self.vllm_config.model_config.get_num_attention_heads(
468468
self.vllm_config.parallel_config),

0 commit comments

Comments
 (0)