                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable
-from vllm.v1.worker.cpu_model_runner import CPUModelRunner
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
 try:

@@ -309,21 +309,23 @@ def get_seq_len_block_table_args(
         raise AttributeError(f"Invalid attention type {str(attn_type)}")
 
 
-class TorchSDPAMetadataBuilderV1:
+class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
+
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device) -> None:
+        self.kv_cache_spec = kv_cache_spec
+        self.vllm_config = vllm_config
+        self.scheduler_config = vllm_config.scheduler_config
 
-    def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable) -> None:
-        self.runner = runner
-        self.block_table = block_table
         # For reorder
-        self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
-                                                       dtype=np.int64)
-        self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs,
-                                                       dtype=np.int64)
+        self.reorder_prompt_req_index_list = np.empty(
+            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
+        self.reorder_decode_req_index_list = np.empty(
+            vllm_config.scheduler_config.max_num_seqs, dtype=np.int64)
         self.num_prompt_req: int = 0
 
         self.seq_start_loc_cpu = torch.zeros(
-            runner.max_num_reqs + 1,
+            vllm_config.scheduler_config.max_num_seqs + 1,
             dtype=torch.int32,
             device="cpu",
         )
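
A minimal usage sketch of the reworked constructor (illustrative only; the names `kv_cache_spec` and `vllm_config` are assumed to be supplied by the surrounding worker code, as for the other v1 attention metadata builders):

    # Hypothetical instantiation: the builder no longer receives a CPUModelRunner
    # or a BlockTable; its reorder buffers are sized from
    # scheduler_config.max_num_seqs instead of runner.max_num_reqs.
    builder = TorchSDPAMetadataBuilderV1(
        kv_cache_spec=kv_cache_spec,   # AttentionSpec for this KV-cache group (assumed in scope)
        vllm_config=vllm_config,       # VllmConfig; only scheduler_config is read here
        device=torch.device("cpu"),
    )
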
@@ -373,50 +375,52 @@ def reorder_batch(self, input_batch: InputBatch,
 
         return True
 
-    def build(self, common_prefix_len: int,
-              common_attn_metadata: CommonAttentionMetadata):
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False) -> TorchSDPAMetadata:
         num_reqs = common_attn_metadata.num_reqs
-        num_actual_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
 
-        runner = self.runner
-        block_table = self.block_table
-        seq_lens_np = runner.seq_lens_np[:num_reqs]
+        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+        seq_lens_np = seq_lens_cpu.numpy()
         num_prompt_req = self.num_prompt_req
         max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item(
         ) if num_prompt_req > 0 else 0
         max_decode_seq_len = seq_lens_np[num_prompt_req:num_reqs].max().item(
         ) if num_prompt_req < num_reqs else 0
         self.seq_start_loc_np[0] = 0
         np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1])
-        num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item()
-        num_decode_tokens = runner.query_start_loc_np[num_reqs].item(
-        ) - num_prefill_tokens
-        slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long()
-        block_table_tensor = block_table.get_device_tensor()
+
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
+        num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item())
+        num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() -
+                                num_prefill_tokens)
+
+        slot_mapping = common_attn_metadata.slot_mapping.long()
+        block_table_tensor = common_attn_metadata.block_table_tensor
+
         attn_metadata = TorchSDPAMetadata(
             num_prefills=num_prompt_req,
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
             # to ensure inference when chunked_prefill is disabled
-            seq_lens=runner.seq_lens_cpu[:num_reqs].tolist(),
-            seq_lens_tensor=runner.
-            seq_lens_cpu[num_prompt_req:num_reqs],  # decode
+            seq_lens=seq_lens_cpu.tolist(),
+            seq_lens_tensor=seq_lens_cpu[num_prompt_req:num_reqs],  # decode
             max_decode_seq_len=max_decode_seq_len,  # decode
             block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
-            chunked_prefill=self.runner.scheduler_config.
-            chunked_prefill_enabled,
+            chunked_prefill=self.scheduler_config.chunked_prefill_enabled,
             max_query_len=max_query_len,
             max_kv_len=max_prefill_seq_len,
-            prefill_query_start_loc=runner.
-            query_start_loc_cpu[:num_prompt_req + 1],  # prefill
+            prefill_query_start_loc=query_start_loc_cpu[:num_prompt_req +
+                                                         1],  # prefill
             kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req +
                                                 1],  # prefill
             prefill_block_tables=block_table_tensor[:
                                                     num_prompt_req],  # prefill
-            query_start_loc=runner.query_start_loc_cpu[:num_reqs +
-                                                        1],  # for logits index
+            query_start_loc=query_start_loc_cpu[:num_reqs +
+                                                 1],  # for logits index
             multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=False,
         )
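
A sketch of calling the new `build()` signature (illustrative only; it assumes `common_attn_metadata` is a `CommonAttentionMetadata` carrying the fields the diff reads: `seq_lens_cpu`, `query_start_loc_cpu`, `slot_mapping`, and `block_table_tensor`):

    # Hypothetical call: all per-batch state now comes from the shared metadata
    # object rather than from the runner; fast_build is the new optional flag.
    attn_metadata = builder.build(
        common_prefix_len=0,
        common_attn_metadata=common_attn_metadata,
        fast_build=False,
    )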