@@ -201,6 +201,7 @@
 from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
@@ -214,7 +215,6 @@
     reoder_batch_to_split_decodes_and_prefills,
     split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable

 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -234,7 +234,6 @@
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 logger = init_logger(__name__)

@@ -377,22 +376,23 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """

     def __init__(self,
-                 runner: "GPUModelRunner",
                  kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable,
+                 vllm_config: VllmConfig,
+                 device: torch.device,
                  metadata_cls: Optional[type[M]] = None):
         self.metadata_cls = metadata_cls \
             if metadata_cls is not None else MLACommonMetadata
-        self.runner = runner
-        scheduler_config = runner.scheduler_config
-        model_config = runner.model_config
-        cache_config = runner.cache_config
+        self.kv_cache_spec = kv_cache_spec
+        self.device = device
+        scheduler_config = vllm_config.scheduler_config
+        self.model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        parallel_config = vllm_config.parallel_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
-        self.num_heads = model_config.get_num_attention_heads(
-            runner.parallel_config)
-        self.mla_dims = get_mla_dims(model_config)
+        self.num_heads = self.model_config.get_num_attention_heads(
+            parallel_config)
+        self.mla_dims = get_mla_dims(self.model_config)
         self.aot_schedule = current_platform.is_cuda()
-        self.kv_cache_spec = kv_cache_spec

         # Dont try to access the runner on AMD
         if self.aot_schedule:
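
Taken together, this hunk replaces the builder's dependency on the whole GPUModelRunner (and a separately passed BlockTable) with just a VllmConfig and the target device. A minimal sketch of what a call site could look like under the new signature; the wrapper function and the import path are assumptions for illustration, not part of this diff:

# Hypothetical call site, for illustration only; the diff itself does not
# show where the builder is constructed. Import path assumed from the file
# being modified (vllm/v1/attention/backends/mla/common.py).
import torch

from vllm.config import VllmConfig
from vllm.v1.attention.backends.mla.common import MLACommonMetadataBuilder
from vllm.v1.kv_cache_interface import AttentionSpec


def make_mla_metadata_builder(kv_cache_spec: AttentionSpec,
                              vllm_config: VllmConfig,
                              device: torch.device) -> MLACommonMetadataBuilder:
    # Previously this required the full GPUModelRunner plus a BlockTable;
    # now only the config object and the device are needed.
    return MLACommonMetadataBuilder(kv_cache_spec=kv_cache_spec,
                                    vllm_config=vllm_config,
                                    device=device)
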
@@ -403,7 +403,7 @@ def __init__(self,
                 # Max sure there is enough for 8 full length request or at least
                 # 4 pages of cache per request
                 max(
-                    8 * model_config.max_model_len, 4 *
+                    8 * self.model_config.max_model_len, 4 *
                     scheduler_config.max_num_seqs * cache_config.block_size),
                 # For long-context models try not to over-allocate limiting
                 # kv-cache space, limiting it to 64k tokens,
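
The comments in this hunk describe the sizing heuristic for the chunked-prefill workspace: take whichever is larger, room for eight full-length requests or four cache pages per scheduled request. A small worked example of that max(...) term, using assumed configuration values rather than anything taken from this PR:

# Worked illustration of the sizing heuristic above; the numbers are
# assumptions chosen purely for the arithmetic, not vLLM defaults.
max_model_len = 8192   # assumed model context length
max_num_seqs = 256     # assumed scheduler batch limit
block_size = 16        # assumed KV-cache block size

# "enough for 8 full length requests or at least 4 pages of cache per request"
workspace_tokens = max(8 * max_model_len,               # 8 * 8192 = 65536
                       4 * max_num_seqs * block_size)   # 4 * 256 * 16 = 16384
print(workspace_tokens)  # 65536: the eight-full-requests term dominates here
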
@@ -418,11 +418,10 @@ def __init__(self,
                 scheduler_config.max_num_seqs * cache_config.block_size
             self.chunked_prefill_workspace = torch.empty(
                 (self.chunked_prefill_workspace_size,
-                 model_config.get_head_size()),
-                dtype=model_config.dtype,
-                device=runner.device,
+                 self.model_config.get_head_size()),
+                dtype=self.model_config.dtype,
+                device=device,
             )
-            self.block_table = block_table

         self._use_fi_prefill = use_flashinfer_prefill()
         self.prefill_metadata_cls = FlashInferPrefillMetadata \
@@ -558,7 +557,7 @@ def build(self,
         # Note(simon): be careful about the CPU <> GPU memory movement in this
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
-        device = self.runner.device
+        device = self.device
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping

@@ -664,7 +663,7 @@ def build(self,
             num_actual_tokens=num_tokens,
             query_start_loc=query_start_loc,
             slot_mapping=slot_mapping,
-            head_dim=self.runner.model_config.get_head_size(),
+            head_dim=self.model_config.get_head_size(),
             # MLACommonMetadata Chunk prefill specific
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,