Commit 4160c1e

more refactors
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 6f30b98 commit 4160c1e

File tree

vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py

3 files changed: +43 -42 lines changed

vllm/v1/attention/backends/mla/common.py

Lines changed: 18 additions & 19 deletions
@@ -201,6 +201,7 @@
 from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
@@ -214,7 +215,6 @@
     reoder_batch_to_split_decodes_and_prefills,
     split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable

 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -234,7 +234,6 @@
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 logger = init_logger(__name__)

@@ -377,22 +376,23 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """

     def __init__(self,
-                 runner: "GPUModelRunner",
                  kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable,
+                 vllm_config: VllmConfig,
+                 device: torch.device,
                  metadata_cls: Optional[type[M]] = None):
         self.metadata_cls = metadata_cls \
             if metadata_cls is not None else MLACommonMetadata
-        self.runner = runner
-        scheduler_config = runner.scheduler_config
-        model_config = runner.model_config
-        cache_config = runner.cache_config
+        self.kv_cache_spec = kv_cache_spec
+        self.device = device
+        scheduler_config = vllm_config.scheduler_config
+        self.model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        parallel_config = vllm_config.parallel_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
-        self.num_heads = model_config.get_num_attention_heads(
-            runner.parallel_config)
-        self.mla_dims = get_mla_dims(model_config)
+        self.num_heads = self.model_config.get_num_attention_heads(
+            parallel_config)
+        self.mla_dims = get_mla_dims(self.model_config)
         self.aot_schedule = current_platform.is_cuda()
-        self.kv_cache_spec = kv_cache_spec

         # Dont try to access the runner on AMD
         if self.aot_schedule:
@@ -403,7 +403,7 @@ def __init__(self,
                 # Max sure there is enough for 8 full length request or at least
                 # 4 pages of cache per request
                 max(
-                    8 * model_config.max_model_len, 4 *
+                    8 * self.model_config.max_model_len, 4 *
                     scheduler_config.max_num_seqs * cache_config.block_size),
                 # For long-context models try not to over-allocate limiting
                 # kv-cache space, limiting it to 64k tokens,
@@ -418,11 +418,10 @@ def __init__(self,
             scheduler_config.max_num_seqs * cache_config.block_size
         self.chunked_prefill_workspace = torch.empty(
             (self.chunked_prefill_workspace_size,
-             model_config.get_head_size()),
-            dtype=model_config.dtype,
-            device=runner.device,
+             self.model_config.get_head_size()),
+            dtype=self.model_config.dtype,
+            device=device,
         )
-        self.block_table = block_table

         self._use_fi_prefill = use_flashinfer_prefill()
         self.prefill_metadata_cls = FlashInferPrefillMetadata \
@@ -558,7 +557,7 @@ def build(self,
         # Note(simon): be careful about the CPU <> GPU memory movement in this
        # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
-        device = self.runner.device
+        device = self.device
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping

@@ -664,7 +663,7 @@ def build(self,
             num_actual_tokens=num_tokens,
             query_start_loc=query_start_loc,
             slot_mapping=slot_mapping,
-            head_dim=self.runner.model_config.get_head_size(),
+            head_dim=self.model_config.get_head_size(),
             # MLACommonMetadata Chunk prefill specific
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,
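
The net effect of these hunks is that MLACommonMetadataBuilder no longer needs a GPUModelRunner handle: everything it reads comes from a VllmConfig plus an explicitly passed device. A minimal sketch of what construction looks like from a caller's side under the new signature (the helper function and the device choice below are assumptions for illustration, not part of this commit):

import torch

from vllm.config import VllmConfig
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.kv_cache_interface import AttentionSpec


def make_mla_metadata_builder(
        kv_cache_spec: AttentionSpec,
        vllm_config: VllmConfig) -> MLACommonMetadataBuilder:
    # Hypothetical call site: the runner-side wiring is outside this diff.
    # The builder now takes the full config object and an explicit device
    # instead of a runner handle; metadata_cls defaults to MLACommonMetadata.
    device = torch.device("cuda:0")  # assumption: single-GPU placement
    return MLACommonMetadataBuilder(kv_cache_spec, vllm_config, device)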

vllm/v1/attention/backends/mla/flashmla.py

Lines changed: 8 additions & 7 deletions
@@ -11,14 +11,14 @@
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonDecodeMetadata,
                                                    MLACommonImpl,
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable

 logger = init_logger(__name__)

@@ -56,12 +56,13 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True  # Decode-only

-    def __init__(self, runner, kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable):
-        super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device):
+        super().__init__(kv_cache_spec, vllm_config, device, FlashMLAMetadata)

-        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
-            self.runner.parallel_config)
+        self.compilation_config = vllm_config.compilation_config
+        self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
+            vllm_config.parallel_config)

         self.cg_buf_tile_scheduler_metadata = None
         self.cg_buf_num_splits = None
@@ -75,7 +76,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
             1,  # MQA for the decode path
         )

-        if self.runner.full_cuda_graph:
+        if self.compilation_config.full_cuda_graph:
            # First time around (CUDAGraph capture), allocate the static buffer
            if self.cg_buf_tile_scheduler_metadata is None:
                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
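
In the last hunk, full_cuda_graph is now read from the compilation config rather than from the runner, and the cg_buf_* attributes implement the usual allocate-once pattern for full CUDA graph capture: the first build (during capture) keeps the computed tensor as a static buffer, and later builds must write into that same storage so the captured graph sees fresh values. A generic sketch of that idea, assuming plain PyTorch rather than vLLM's exact replay path:

from typing import Optional

import torch


class StaticBufferCache:
    """Allocate-once buffer reuse, as needed under CUDA graph replay."""

    def __init__(self) -> None:
        self.buf: Optional[torch.Tensor] = None

    def update(self, new_values: torch.Tensor) -> torch.Tensor:
        if self.buf is None:
            # Capture time: keep this tensor as the graph's static buffer.
            self.buf = new_values
        else:
            # Replay time: copy in place so the graph reads the new data.
            self.buf[:new_values.size(0)].copy_(new_values)
        return self.buf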

vllm/v1/attention/backends/mla/rocm_aiter_mla.py

Lines changed: 17 additions & 16 deletions
@@ -8,6 +8,7 @@

 import vllm.envs as envs
 from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
+from vllm.config import VllmConfig
 # yapf conflicts with isort for this docstring
 # yapf: disable
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -16,7 +17,6 @@
                                                     MLACommonMetadata,
                                                     MLACommonMetadataBuilder)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable

 # yapf: enable

@@ -65,24 +65,25 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True  # decode only

-    def __init__(self, runner, kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable):
-        super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata)
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device):
+        super().__init__(kv_cache_spec, vllm_config, device, AiterMLAMetadata)
         assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
             "only supports block size 1."

+        self.compilation_config = vllm_config.compilation_config
+
         # Preparing persistent buffers
-        if self.runner.full_cuda_graph:
-            device = self.runner.device
-            max_num_reqs = self.runner.max_num_reqs
+        if vllm_config.compilation_config.full_cuda_graph:
+            max_num_reqs = vllm_config.scheduler_config.max_num_seqs
             self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
                                                dtype=torch.int32,
                                                device=device)
-            self.paged_kv_indices = torch.zeros(
-                block_table.get_device_tensor().numel(
-                ),  # max num pages possible
-                dtype=torch.int32,
-                device=device)
+            # We'll assume a reasonable max number of pages
+            max_pages = max_num_reqs * 1024  # Rough estimate
+            self.paged_kv_indices = torch.zeros(max_pages,
+                                                dtype=torch.int32,
+                                                device=device)
             self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
                                                       dtype=torch.int32,
                                                       device=device)
@@ -96,7 +97,8 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
                       seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
         page_size = self.kv_cache_spec.block_size
         block_table_bounds = (seq_lens + page_size - 1) // page_size
-        device = self.runner.device
+        device = self.device
+        num_reqs = seq_lens.size(0)

         mask = (torch.arange(block_table_tensor.size(1),
                              dtype=block_table_tensor.dtype,
@@ -113,8 +115,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
             block_table_bounds.cumsum(dim=0, dtype=torch.int32)
         ])

-        if self.runner.full_cuda_graph:
-            num_reqs = self._num_decodes
+        if self.compilation_config.full_cuda_graph:

             num_actual_pages = paged_kv_indices.size(0)

@@ -137,7 +138,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor,

         else:
             qo_indptr = torch.arange(0,
-                                     self._num_decodes + 1,
+                                     num_reqs + 1,
                                      step=1,
                                      dtype=torch.int32,
                                      device=device)
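
One behavioral note on the constructor hunk above: the paged_kv_indices capacity is no longer derived from the block table's device tensor but from a rough per-request estimate. A small sketch of the sizing math under the new scheme (the helper function and the example max_num_seqs value are assumptions, not from this commit):

from vllm.config import VllmConfig


def persistent_buffer_sizes(vllm_config: VllmConfig) -> tuple[int, int]:
    # Mirrors the sizing used in AiterMLAMetadataBuilder.__init__ above.
    max_num_reqs = vllm_config.scheduler_config.max_num_seqs
    max_pages = max_num_reqs * 1024  # rough per-request page estimate
    return max_num_reqs + 1, max_pages  # indptr slots, page-index slots


# With max_num_seqs = 256, for example, this reserves 257 indptr slots and
# 256 * 1024 = 262,144 page-index slots.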
