
Commit 7f0d422

more refactors
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
1 parent 86338cd commit 7f0d422

File tree

vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/rocm_aiter_mla.py

3 files changed: +43 -42 lines changed

vllm/v1/attention/backends/mla/common.py

Lines changed: 18 additions & 19 deletions
@@ -201,6 +201,7 @@
 from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
@@ -211,7 +212,6 @@
     AttentionMetadataBuilder, CommonAttentionMetadata,
     reoder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable

 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -225,7 +225,6 @@
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 logger = init_logger(__name__)

@@ -346,22 +345,23 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """

     def __init__(self,
-                 runner: "GPUModelRunner",
                  kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable,
+                 vllm_config: VllmConfig,
+                 device: torch.device,
                  metadata_cls: Optional[type[M]] = None):
         self.metadata_cls = metadata_cls \
             if metadata_cls is not None else MLACommonMetadata
-        self.runner = runner
-        scheduler_config = runner.scheduler_config
-        model_config = runner.model_config
-        cache_config = runner.cache_config
+        self.kv_cache_spec = kv_cache_spec
+        self.device = device
+        scheduler_config = vllm_config.scheduler_config
+        self.model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        parallel_config = vllm_config.parallel_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
-        self.num_heads = model_config.get_num_attention_heads(
-            runner.parallel_config)
-        self.mla_dims = get_mla_dims(model_config)
+        self.num_heads = self.model_config.get_num_attention_heads(
+            parallel_config)
+        self.mla_dims = get_mla_dims(self.model_config)
         self.aot_schedule = current_platform.is_cuda()
-        self.kv_cache_spec = kv_cache_spec

         # Dont try to access the runner on AMD
         if self.aot_schedule:
@@ -372,7 +372,7 @@ def __init__(self,
                 # Max sure there is enough for 8 full length request or at least
                 # 4 pages of cache per request
                 max(
-                    8 * model_config.max_model_len, 4 *
+                    8 * self.model_config.max_model_len, 4 *
                     scheduler_config.max_num_seqs * cache_config.block_size),
                 # For long-context models try not to over-allocate limiting
                 # kv-cache space, limiting it to 64k tokens,
@@ -387,11 +387,10 @@ def __init__(self,
                 scheduler_config.max_num_seqs * cache_config.block_size
             self.chunked_prefill_workspace = torch.empty(
                 (self.chunked_prefill_workspace_size,
-                 model_config.get_head_size()),
-                dtype=model_config.dtype,
-                device=runner.device,
+                 self.model_config.get_head_size()),
+                dtype=self.model_config.dtype,
+                device=device,
             )
-        self.block_table = block_table

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
@@ -432,7 +431,7 @@ def build(self,
         # Note(simon): be careful about the CPU <> GPU memory movement in this
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
-        device = self.runner.device
+        device = self.device
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping

@@ -538,7 +537,7 @@ def build(self,
             num_actual_tokens=num_tokens,
             query_start_loc=query_start_loc,
             slot_mapping=slot_mapping,
-            head_dim=self.runner.model_config.get_head_size(),
+            head_dim=self.model_config.get_head_size(),
             # MLACommonMetadata Chunk prefill specific
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,
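
For illustration, a minimal sketch of how a caller might construct the refactored builder. The call site is not part of this commit, and the concrete device below is an assumption; the point is only that the builder now takes (kv_cache_spec, vllm_config, device) instead of a GPUModelRunner and a BlockTable.

import torch

from vllm.config import VllmConfig
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.kv_cache_interface import AttentionSpec


def make_mla_metadata_builder(kv_cache_spec: AttentionSpec,
                              vllm_config: VllmConfig) -> MLACommonMetadataBuilder:
    # Hypothetical helper: the builder reads model/scheduler/cache/parallel
    # config from vllm_config and allocates its chunked-prefill workspace on
    # the device it is given, so no runner handle is needed.
    device = torch.device("cuda")  # assumed; real callers pass their own device
    return MLACommonMetadataBuilder(kv_cache_spec, vllm_config, device)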

vllm/v1/attention/backends/mla/flashmla.py

Lines changed: 8 additions & 7 deletions
@@ -11,14 +11,14 @@
 from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          get_mla_metadata,
                                          is_flashmla_supported)
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonDecodeMetadata,
                                                    MLACommonImpl,
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable

 logger = init_logger(__name__)

@@ -56,12 +56,13 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True  # Decode-only

-    def __init__(self, runner, kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable):
-        super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata)
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device):
+        super().__init__(kv_cache_spec, vllm_config, device, FlashMLAMetadata)

-        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
-            self.runner.parallel_config)
+        self.compilation_config = vllm_config.compilation_config
+        self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
+            vllm_config.parallel_config)

         self.cg_buf_tile_scheduler_metadata = None
         self.cg_buf_num_splits = None
@@ -75,7 +76,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
             1,  # MQA for the decode path
         )

-        if self.runner.full_cuda_graph:
+        if self.compilation_config.full_cuda_graph:
             # First time around (CUDAGraph capture), allocate the static buffer
             if self.cg_buf_tile_scheduler_metadata is None:
                 self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
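
The same pattern repeats for each MLA backend builder: forward the new triple to the common base class and keep a handle on compilation_config instead of reading flags off the runner. A sketch with a hypothetical subclass (MyMLAMetadata and MyMLAMetadataBuilder are illustrative names, not part of vLLM):

import torch

from vllm.config import VllmConfig
from vllm.v1.attention.backends.mla.common import (MLACommonMetadata,
                                                   MLACommonMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec


class MyMLAMetadata(MLACommonMetadata):
    pass


class MyMLAMetadataBuilder(MLACommonMetadataBuilder[MyMLAMetadata]):

    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
                 device: torch.device):
        # The base class stores kv_cache_spec, device and the config handles.
        super().__init__(kv_cache_spec, vllm_config, device, MyMLAMetadata)
        # Cache compilation_config so the decode path can check
        # full_cuda_graph without a runner reference.
        self.compilation_config = vllm_config.compilation_config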

vllm/v1/attention/backends/mla/rocm_aiter_mla.py

Lines changed: 17 additions & 16 deletions
@@ -8,6 +8,7 @@

 import vllm.envs as envs
 from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd
+from vllm.config import VllmConfig
 # yapf conflicts with isort for this docstring
 # yapf: disable
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -16,7 +17,6 @@
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable

 # yapf: enable

@@ -65,24 +65,25 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
 class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
     full_cudagraph_supported: ClassVar[bool] = True  # decode only

-    def __init__(self, runner, kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable):
-        super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata)
+    def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
+                 device: torch.device):
+        super().__init__(kv_cache_spec, vllm_config, device, AiterMLAMetadata)
         assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
             "only supports block size 1."

+        self.compilation_config = vllm_config.compilation_config
+
         # Preparing persistent buffers
-        if self.runner.full_cuda_graph:
-            device = self.runner.device
-            max_num_reqs = self.runner.max_num_reqs
+        if vllm_config.compilation_config.full_cuda_graph:
+            max_num_reqs = vllm_config.scheduler_config.max_num_seqs
             self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
                                                dtype=torch.int32,
                                                device=device)
-            self.paged_kv_indices = torch.zeros(
-                block_table.get_device_tensor().numel(
-                ),  # max num pages possible
-                dtype=torch.int32,
-                device=device)
+            # We'll assume a reasonable max number of pages
+            max_pages = max_num_reqs * 1024  # Rough estimate
+            self.paged_kv_indices = torch.zeros(max_pages,
+                                                dtype=torch.int32,
+                                                device=device)
             self.paged_kv_last_page_len = torch.zeros(max_num_reqs,
                                                       dtype=torch.int32,
                                                       device=device)
@@ -96,7 +97,8 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
                       seq_lens: torch.Tensor) -> AiterMLADecodeMetadata:
         page_size = self.kv_cache_spec.block_size
         block_table_bounds = (seq_lens + page_size - 1) // page_size
-        device = self.runner.device
+        device = self.device
+        num_reqs = seq_lens.size(0)

         mask = (torch.arange(block_table_tensor.size(1),
                              dtype=block_table_tensor.dtype,
@@ -113,8 +115,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
             block_table_bounds.cumsum(dim=0, dtype=torch.int32)
         ])

-        if self.runner.full_cuda_graph:
-            num_reqs = self._num_decodes
+        if self.compilation_config.full_cuda_graph:

             num_actual_pages = paged_kv_indices.size(0)

@@ -137,7 +138,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor,

         else:
             qo_indptr = torch.arange(0,
-                                     self._num_decodes + 1,
+                                     num_reqs + 1,
                                      step=1,
                                      dtype=torch.int32,
                                      device=device)
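
One consequence worth noting: the AITER builder previously sized paged_kv_indices from the block table's device tensor and now preallocates against a rough per-request page estimate. A worked example of that arithmetic, using an assumed max_num_seqs of 256 (not a value taken from this commit):

max_num_reqs = 256               # assumed vllm_config.scheduler_config.max_num_seqs
max_pages = max_num_reqs * 1024  # the "rough estimate" factor introduced here
buffer_bytes = max_pages * 4     # int32 page indices
print(max_pages, buffer_bytes)   # 262144 pages -> 1048576 bytes (~1 MiB)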
