@@ -122,7 +122,7 @@ def __init__(
             cache_config.cache_dtype]
 
         self.is_multimodal_model = model_config.is_multimodal_model
-        self.is_pooling_model = model_config.pooler_config is not None
+        self.model_supports_multimodal_raw_input = model_config.model_supports_multimodal_raw_input
         self.max_model_len = model_config.max_model_len
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs
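Note: the runner now caches `model_supports_multimodal_raw_input` the same way it caches `is_multimodal_model`. A minimal sketch of the shape of such a config flag (hypothetical stand-in, not vLLM's actual `ModelConfig`):

    from dataclasses import dataclass

    @dataclass
    class ModelConfigSketch:
        # Hypothetical stand-in for vLLM's ModelConfig, illustration only.
        is_multimodal_model: bool = False
        # True when the model wants raw multimodal inputs (e.g. pixel
        # tensors) passed straight into forward() rather than having the
        # runner pre-compute encoder embeddings for it.
        model_supports_multimodal_raw_input: bool = False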
@@ -320,6 +320,11 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
         Returns:
             True if the batch was reordered, False otherwise.
         """
+
+        # Nothing to be reordered when the model is attention-free.
+        if self.model_config.is_attention_free:
+            return False
+
         batch_reordered = self.attn_metadata_builders[0].reorder_batch(
             self.input_batch, scheduler_output)
 
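The early return is not just a fast path: an attention-free model may register no attention metadata builders, so falling through to `self.attn_metadata_builders[0]` could fail. A standalone sketch of the same guard (hypothetical helper, not the vLLM API):

    from typing import Protocol

    class ReorderBuilder(Protocol):
        def reorder_batch(self) -> bool: ...

    def may_reorder_batch(is_attention_free: bool,
                          builders: list[ReorderBuilder]) -> bool:
        # Attention-free models have no KV-cache layout worth optimizing,
        # and `builders` may be empty, so bail out before indexing it.
        if is_attention_free:
            return False
        return builders[0].reorder_batch()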
@@ -545,7 +550,23 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         batch_reordered = self._may_reorder_batch(scheduler_output)
 
         if batch_changed or batch_reordered:
-            self.input_batch.refresh_sampling_metadata()
+            self.input_batch.refresh()
+
+    def _maybe_add_model_args(self, num_tokens: int,
+                              model_kwargs: dict[str, Any],
+                              scheduler_output: "SchedulerOutput" = None):
+        pass
+
+    def _maybe_compute_attn_prefix(
+        self,
+        scheduler_output: "SchedulerOutput",
+    ) -> list[int]:
+        return [0] * len(self.kv_cache_config.kv_cache_groups)
+
+    def _maybe_prepare_additional_inputs(self,
+                                         scheduler_output: "SchedulerOutput",
+                                         token_indices: torch.Tensor):
+        pass
 
     def _get_cumsum_and_arange(
         self,
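The three `_maybe_*` methods are deliberate no-op hooks on the base runner, so the shared code paths can call them unconditionally while model-specific runners override them. A hedged sketch of what an override might look like for a model that consumes raw multimodal input (class and field names are illustrative, not from the diff):

    from typing import Any, Optional

    import torch

    class RawInputRunnerSketch:
        """Hypothetical subclass of the GPU model runner."""

        def _maybe_add_model_args(self, num_tokens: int,
                                  model_kwargs: dict[str, Any],
                                  scheduler_output: Optional[object] = None):
            # Attach raw per-request tensors so they reach the model's
            # forward() via **model_kwargs alongside input_ids/positions.
            if scheduler_output is not None:
                model_kwargs["pixel_values"] = self._gather_pixel_values(
                    num_tokens)

        def _gather_pixel_values(self, num_tokens: int) -> torch.Tensor:
            # Stub: a real runner would slice cached raw inputs for the
            # currently scheduled requests.
            return torch.zeros(num_tokens, 3, 224, 224)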
@@ -1012,13 +1033,14 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
             curr_group_outputs = self.model.get_multimodal_embeddings(
                 **batched_mm_inputs)
 
-            sanity_check_mm_encoder_outputs(
-                curr_group_outputs,
-                expected_num_items=len(grouped_mm_inputs),
-            )
+            if curr_group_outputs:
+                sanity_check_mm_encoder_outputs(
+                    curr_group_outputs,
+                    expected_num_items=len(grouped_mm_inputs),
+                )
 
-            for output in curr_group_outputs:
-                encoder_outputs.append(output)
+                for output in curr_group_outputs:
+                    encoder_outputs.append(output)
 
         # Cache the encoder outputs.
         for (req_id, input_id, pos_info), output in zip(
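The new `if curr_group_outputs:` guard lets `get_multimodal_embeddings` legitimately return an empty result, which is what you would expect when the model consumes raw multimodal input itself and produces no separate encoder embeddings. Reduced to plain Python, the control flow is (illustrative, not the vLLM types):

    def collect_encoder_outputs(group_outputs: list, expected: int) -> list:
        collected = []
        if group_outputs:  # raw-input models may return nothing here
            # Stand-in for sanity_check_mm_encoder_outputs.
            assert len(group_outputs) == expected
            collected.extend(group_outputs)
        return collected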
@@ -1304,6 +1326,9 @@ def execute_model(
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
             input_ids = self.input_ids[:num_scheduled_tokens]
+            self._maybe_add_model_args(num_scheduled_tokens,
+                                       model_kwargs, scheduler_output)
+
             if mm_embeds:
                 inputs_embeds = self.model.get_input_embeddings(
                     input_ids, mm_embeds)
@@ -1319,6 +1344,7 @@ def execute_model(
             # multimodal models, it is not desirable for performance since
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids[:num_input_tokens]
+            self._maybe_add_model_args(num_input_tokens, model_kwargs, scheduler_output)
             inputs_embeds = None
         if self.uses_mrope:
             positions = self.mrope_positions[:, :num_input_tokens]
@@ -1352,6 +1378,10 @@ def execute_model(
                 positions=positions,
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
+                **MultiModalKwargs.as_kwargs(
+                    model_kwargs,
+                    device=self.device,
+                )
             )
 
         self.maybe_wait_for_kv_save()
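`MultiModalKwargs.as_kwargs` (from `vllm.multimodal`) prepares the batched `model_kwargs` for the forward call; broadly, it places the batched tensors on the target device before they are splatted into `self.model(...)`. A rough functional equivalent for a plain dict of tensors (illustration only, not the actual implementation):

    import torch

    def as_device_kwargs(kwargs: dict[str, torch.Tensor],
                         device: torch.device) -> dict[str, torch.Tensor]:
        # Ship every batched tensor to the execution device so the model's
        # forward() only ever sees device-local inputs.
        return {name: value.to(device) for name, value in kwargs.items()}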
@@ -1939,6 +1969,8 @@ def _dummy_run(
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             model = self.model
+            model_kwargs: dict[str, Any] = {}
+            self._maybe_add_model_args(num_tokens, model_kwargs)
             if self.is_multimodal_model:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -1973,7 +2005,11 @@ def _dummy_run(
                 positions=positions,
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
+                **MultiModalKwargs.as_kwargs(
+                    model_kwargs,
+                    device=self.device)
             )
+
             if self.use_aux_hidden_state_outputs:
                 hidden_states, _ = outputs
             else:
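`_dummy_run` mirrors the live forward call: it builds `model_kwargs` through the same `_maybe_add_model_args` hook (here with no `scheduler_output`, relying on the `None` default) so warmup and CUDA-graph capture see the same call signature as real execution. A sketch of the pattern (hypothetical `add_model_args` callable):

    from typing import Any, Callable

    def build_dummy_kwargs(add_model_args: Callable[..., None],
                           num_tokens: int) -> dict[str, Any]:
        # Same hook as the live path, but with no scheduler output, so the
        # dummy forward signature matches the real one.
        model_kwargs: dict[str, Any] = {}
        add_model_args(num_tokens, model_kwargs)
        return model_kwargs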