Commit 9a01785
refactor cpu_attn
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent e796669 commit 9a01785


vllm/v1/attention/backends/cpu_attn.py

Lines changed: 37 additions & 33 deletions
@@ -12,12 +12,12 @@
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable
-from vllm.v1.worker.cpu_model_runner import CPUModelRunner
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
 try:
@@ -309,21 +309,23 @@ def get_seq_len_block_table_args(
         raise AttributeError(f"Invalid attention type {str(attn_type)}")
 
 
-class TorchSDPAMetadataBuilderV1:
+class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
+
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device) -> None:
+        self.kv_cache_spec = kv_cache_spec
+        self.vllm_config = vllm_config
+        self.scheduler_config = vllm_config.scheduler_config
 
-    def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable) -> None:
-        self.runner = runner
-        self.block_table = block_table
         # For reorder
-        self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
-                                                       dtype=np.int64)
-        self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
-                                                       dtype=np.int64)
+        self.reorder_prompt_req_index_list = np.empty(
+            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
+        self.reorder_decode_req_index_list = np.empty(
+            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
         self.num_prompt_req: int = 0
 
         self.seq_start_loc_cpu = torch.zeros(
-            runner.max_num_reqs + 1,
+            vllm_config.scheduler_config.max_num_seqs + 1,
             dtype=torch.int32,
             device="cpu",
         )
@@ -373,50 +375,52 @@ def reorder_batch(self, input_batch: InputBatch,
 
         return True
 
-    def build(self, common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata):
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False) -> TorchSDPAMetadata:
         num_reqs = common_attn_metadata.num_reqs
-        num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
 
-        runner = self.runner
-        block_table = self.block_table
-        seq_lens_np = runner.seq_lens_np[:num_reqs]
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        seq_lens_np = seq_lens_cpu.numpy()
         num_prompt_req = self.num_prompt_req
         max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
         ) if num_prompt_req > 0 else 0
         max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
         ) if num_prompt_req < num_reqs else 0
         self.seq_start_loc_np[0] = 0
         np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
-        num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
-        num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
-        ) - num_prefill_tokens
-        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
-        block_table_tensor = block_table.get_device_tensor()
+
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
+        num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item())
+        num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
+                                num_prefill_tokens)
+
+        slot_mapping = common_attn_metadata.slot_mapping.long()
+        block_table_tensor = common_attn_metadata.block_table_tensor
+
         attn_metadata = TorchSDPAMetadata(
             num_prefills=num_prompt_req,
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
             # to ensure inference when chunked_prefill is disabled
-            seq_lens=runner.seq_lens_cpu[:num_reqs].tolist(),
-            seq_lens_tensor=runner.
-            seq_lens_cpu[num_prompt_req:num_reqs],  # decode
+            seq_lens=seq_lens_cpu.tolist(),
+            seq_lens_tensor=seq_lens_cpu[num_prompt_req:num_reqs],  # decode
             max_decode_seq_len=max_decode_seq_len,  # decode
             block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
-            chunked_prefill=self.runner.scheduler_config.
-            chunked_prefill_enabled,
+            chunked_prefill=self.scheduler_config.chunked_prefill_enabled,
             max_query_len=max_query_len,
             max_kv_len=max_prefill_seq_len,
-            prefill_query_start_loc=runner.
-            query_start_loc_cpu[:num_prompt_req + 1],  # prefill
+            prefill_query_start_loc=query_start_loc_cpu[:num_prompt_req +
+                                                        1],  # prefill
             kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
                                                 1],  # prefill
             prefill_block_tables=block_table_tensor[:
                                                     num_prompt_req],  # prefill
-            query_start_loc=runner.query_start_loc_cpu[:num_reqs +
-                                                       1],  # for logits index
+            query_start_loc=query_start_loc_cpu[:num_reqs +
                                                 1],  # for logits index
             multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=False,
         )
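
For orientation, a minimal usage sketch of the builder after this refactor, relying only on what the diff shows: the builder is constructed from an AttentionSpec, the VllmConfig and a device (it no longer holds a CPUModelRunner or BlockTable), and build() reads seq_lens_cpu, query_start_loc_cpu, slot_mapping and block_table_tensor from CommonAttentionMetadata. The helper name, the argument values, and the assumption that reorder_batch() takes the SchedulerOutput as its second argument are illustrative only, not part of the commit.

# Hedged sketch of driving the refactored TorchSDPAMetadataBuilderV1.
# build_cpu_sdpa_metadata is a hypothetical helper; the objects passed in are
# assumed to come from the CPU worker and the scheduler.
import torch

from vllm.config import VllmConfig
from vllm.v1.attention.backends.cpu_attn import (TorchSDPAMetadata,
                                                 TorchSDPAMetadataBuilderV1)
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.gpu_input_batch import InputBatch


def build_cpu_sdpa_metadata(
        kv_cache_spec: AttentionSpec,
        vllm_config: VllmConfig,
        input_batch: InputBatch,
        scheduler_output: SchedulerOutput,
        common_attn_metadata: CommonAttentionMetadata) -> TorchSDPAMetadata:
    # New constructor: spec + config + device, no runner or block table.
    builder = TorchSDPAMetadataBuilderV1(kv_cache_spec, vllm_config,
                                         torch.device("cpu"))
    # reorder_batch() groups prompt requests ahead of decode requests and sets
    # num_prompt_req, which build() relies on (second argument assumed here).
    builder.reorder_batch(input_batch, scheduler_output)
    # build() now pulls all per-step tensors from CommonAttentionMetadata.
    return builder.build(common_prefix_len=0,
                         common_attn_metadata=common_attn_metadata)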
