 from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
...
     AttentionMetadataBuilder, CommonAttentionMetadata,
     reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.v1.worker.block_table import BlockTable

 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
...
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
-    from vllm.v1.worker.gpu_model_runner import GPUModelRunner

 logger = init_logger(__name__)
@@ -346,22 +345,23 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """

     def __init__(self,
-                 runner: "GPUModelRunner",
                  kv_cache_spec: AttentionSpec,
-                 block_table: BlockTable,
+                 vllm_config: VllmConfig,
+                 device: torch.device,
                  metadata_cls: Optional[type[M]] = None):
         self.metadata_cls = metadata_cls \
             if metadata_cls is not None else MLACommonMetadata
-        self.runner = runner
-        scheduler_config = runner.scheduler_config
-        model_config = runner.model_config
-        cache_config = runner.cache_config
+        self.kv_cache_spec = kv_cache_spec
+        self.device = device
+        scheduler_config = vllm_config.scheduler_config
+        self.model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        parallel_config = vllm_config.parallel_config
         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
-        self.num_heads = model_config.get_num_attention_heads(
-            runner.parallel_config)
-        self.mla_dims = get_mla_dims(model_config)
+        self.num_heads = self.model_config.get_num_attention_heads(
+            parallel_config)
+        self.mla_dims = get_mla_dims(self.model_config)
         self.aot_schedule = current_platform.is_cuda()
-        self.kv_cache_spec = kv_cache_spec

         # Dont try to access the runner on AMD
         if self.aot_schedule:
@@ -372,7 +372,7 @@ def __init__(self,
             # Max sure there is enough for 8 full length request or at least
             # 4 pages of cache per request
             max(
-                8 * model_config.max_model_len, 4 *
+                8 * self.model_config.max_model_len, 4 *
                 scheduler_config.max_num_seqs * cache_config.block_size),
             # For long-context models try not to over-allocate limiting
             # kv-cache space, limiting it to 64k tokens,
@@ -387,11 +387,10 @@ def __init__(self,
             scheduler_config.max_num_seqs * cache_config.block_size
         self.chunked_prefill_workspace = torch.empty(
             (self.chunked_prefill_workspace_size,
-             model_config.get_head_size()),
-            dtype=model_config.dtype,
-            device=runner.device,
+             self.model_config.get_head_size()),
+            dtype=self.model_config.dtype,
+            device=device,
         )
-        self.block_table = block_table

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
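The two hunks above only swap the workspace's config and device references over to the new self.model_config and device fields; the sizing heuristic itself is unchanged. As a rough illustration of that heuristic — using hypothetical values for max_model_len, max_num_seqs, and block_size, and with the 64k-token cap taken from the comment rather than from code visible in this excerpt — the token count works out as follows:

# Illustrative sketch only, not part of this diff. The values below stand in
# for model_config.max_model_len, scheduler_config.max_num_seqs, and
# cache_config.block_size.
max_model_len = 4096
max_num_seqs = 256
block_size = 16

workspace_tokens = min(
    # enough for 8 full-length requests, or 4 cache pages per request
    max(8 * max_model_len, 4 * max_num_seqs * block_size),
    64 * 1024,  # long-context cap described in the comment above (assumption)
)
print(workspace_tokens)  # max(32768, 16384) = 32768, below the 64k cap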
@@ -432,7 +431,7 @@ def build(self,
         # Note(simon): be careful about the CPU <> GPU memory movement in this
         # function. We should avoid GPU -> CPU sync as much as possible because
         # it blocks on all previous kernels.
-        device = self.runner.device
+        device = self.device
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping

@@ -538,7 +537,7 @@ def build(self,
            num_actual_tokens=num_tokens,
            query_start_loc=query_start_loc,
            slot_mapping=slot_mapping,
-            head_dim=self.runner.model_config.get_head_size(),
+            head_dim=self.model_config.get_head_size(),
            # MLACommonMetadata Chunk prefill specific
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
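Taken together, the diff removes the builder's dependency on GPUModelRunner and BlockTable: the builder now receives an AttentionSpec, the engine-wide VllmConfig, and a torch.device directly. A minimal usage sketch, assuming kv_cache_spec and vllm_config are already available from the caller (the construction below is illustrative, not taken from this PR):

import torch

# Hypothetical construction after the refactor; kv_cache_spec and vllm_config
# are assumed to be provided by the model runner that owns this builder.
builder = MLACommonMetadataBuilder(
    kv_cache_spec=kv_cache_spec,    # AttentionSpec for the MLA KV cache
    vllm_config=vllm_config,        # scheduler/model/cache/parallel configs
    device=torch.device("cuda:0"),  # device for the chunked-prefill workspace
)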