@@ -123,6 +123,7 @@ def __init__(
123
123
cache_config .cache_dtype ]
124
124
125
125
self .is_multimodal_model = model_config .is_multimodal_model
126
+ self .is_pooling_model = model_config .is_pooling_model
126
127
self .model_supports_multimodal_raw_input = model_config .model_supports_multimodal_raw_input
127
128
self .max_model_len = model_config .max_model_len
128
129
self .max_num_tokens = scheduler_config .max_num_batched_tokens
@@ -560,7 +561,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
560
561
self .input_batch .refresh_metadata ()
561
562
562
563
def _add_multimodal_inputs_to_model_args (self , model_kwargs : dict [str , Any ],
563
- scheduler_output : "SchedulerOutput" ):
564
+ scheduler_output : "SchedulerOutput" ,
565
+ num_reqs : int = - 1 ):
564
566
# Multi-modal data.
565
567
if scheduler_output :
566
568
multi_modal_kwargs_list = []
@@ -572,21 +574,20 @@ def _add_multimodal_inputs_to_model_args(self, model_kwargs: dict[str, Any],
572
574
multi_modal_kwargs = MultiModalKwargs .batch (multi_modal_kwargs_list )
573
575
else :
574
576
# The only case where SchedulerOutput is None is a dummy run, so build some dummy data instead.
575
- dummy_data = self .mm_registry .get_decoder_dummy_data (model_config = self .model_config , seq_len = 1 )
576
- multi_modal_kwargs = MultiModalKwargs .batch ([dummy_data .multi_modal_data ])
577
+ dummy_data = [self .mm_registry .get_decoder_dummy_data (model_config = self .model_config , seq_len = 1 ).multi_modal_data for i in range (num_reqs )]
578
+ # dummy_data = self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len =1)
579
+ # multi_modal_kwargs = MultiModalKwargs.batch([dummy_data.multi_modal_data])
580
+ multi_modal_kwargs = MultiModalKwargs .batch (dummy_data )
577
581
578
582
model_kwargs .update (multi_modal_kwargs )
579
583
580
- def _maybe_add_model_args (self , num_tokens : int ,
584
+ def _maybe_add_multimodal_kwargs (self ,
581
585
model_kwargs : dict [str ,Any ],
582
- scheduler_output : "SchedulerOutput" = None ):
583
-
584
- if self .supports_token_type_ids :
585
- model_kwargs ["token_type_ids" ] = \
586
- self .get_token_type_ids ()[:num_tokens ]
586
+ scheduler_output : "SchedulerOutput" = None ,
587
+ num_reqs : int = - 1 ):
587
588
588
589
if self .model_supports_multimodal_raw_input :
589
- self ._add_multimodal_inputs_to_model_args (model_kwargs , scheduler_output )
590
+ self ._add_multimodal_inputs_to_model_args (model_kwargs , scheduler_output , num_reqs )
590
591
591
592
def _maybe_compute_attn_prefix (
592
593
self ,
@@ -1364,15 +1365,15 @@ def execute_model(
1364
1365
mm_embeds = self ._gather_mm_embeddings (scheduler_output )
1365
1366
else :
1366
1367
mm_embeds = []
1367
-
1368
+
1369
+ model_kwargs : dict [str , Any ] = {}
1368
1370
if self .is_multimodal_model and get_pp_group ().is_first_rank :
1369
1371
# NOTE(woosuk): To unify token ids and soft tokens (vision
1370
1372
# embeddings), we always use embeddings (rather than token ids)
1371
1373
# as input to the multimodal model, even when the input is text.
1372
1374
input_ids = self .input_ids [:num_scheduled_tokens ]
1373
- self ._maybe_add_model_args (num_scheduled_tokens ,
1374
- model_kwargs , scheduler_output )
1375
-
1375
+ self ._maybe_add_multimodal_kwargs (model_kwargs = model_kwargs ,
1376
+ scheduler_output = scheduler_output )
1376
1377
if mm_embeds :
1377
1378
inputs_embeds = self .model .get_input_embeddings (
1378
1379
input_ids , mm_embeds )
@@ -1388,7 +1389,6 @@ def execute_model(
1388
1389
# multimodal models, it is not desirable for performance since
1389
1390
# then the embedding layer is not included in the CUDA graph.
1390
1391
input_ids = self .input_ids [:num_input_tokens ]
1391
- self ._maybe_add_model_args (num_input_tokens , model_kwargs , scheduler_output )
1392
1392
inputs_embeds = None
1393
1393
if self .uses_mrope :
1394
1394
positions = self .mrope_positions [:, :num_input_tokens ]
@@ -2053,8 +2053,9 @@ def _dummy_run(
2053
2053
num_scheduled_tokens ):
2054
2054
model = self .model
2055
2055
model_kwargs : dict [str , Any ] = {}
2056
- self ._maybe_add_model_args (num_tokens , model_kwargs )
2057
2056
if self .is_multimodal_model :
2057
+ self ._maybe_add_multimodal_kwargs (model_kwargs = model_kwargs ,
2058
+ num_reqs = num_reqs )
2058
2059
input_ids = None
2059
2060
inputs_embeds = self .inputs_embeds [:num_tokens ]
2060
2061
else :
0 commit comments