@@ -124,6 +124,7 @@ def __init__(
124
124
cache_config .cache_dtype ]
125
125
126
126
self .is_multimodal_model = model_config .is_multimodal_model
127
+ self .is_pooling_model = model_config .is_pooling_model
127
128
self .model_supports_multimodal_raw_input = model_config .model_supports_multimodal_raw_input
128
129
self .max_model_len = model_config .max_model_len
129
130
self .max_num_tokens = scheduler_config .max_num_batched_tokens
@@ -557,7 +558,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
557
558
self .input_batch .refresh_metadata ()
558
559
559
560
def _add_multimodal_inputs_to_model_args (self , model_kwargs : dict [str , Any ],
560
- scheduler_output : "SchedulerOutput" ):
561
+ scheduler_output : "SchedulerOutput" ,
562
+ num_reqs : int = - 1 ):
561
563
# Multi-modal data.
562
564
if scheduler_output :
563
565
multi_modal_kwargs_list = []
@@ -569,21 +571,20 @@ def _add_multimodal_inputs_to_model_args(self, model_kwargs: dict[str, Any],
569
571
multi_modal_kwargs = MultiModalKwargs .batch (multi_modal_kwargs_list )
570
572
else :
571
573
# The only case where scheduler_output is None is a dummy run, so generate some dummy data instead.
572
- dummy_data = self .mm_registry .get_decoder_dummy_data (model_config = self .model_config , seq_len = 1 )
573
- multi_modal_kwargs = MultiModalKwargs .batch ([dummy_data .multi_modal_data ])
574
+ dummy_data = [self .mm_registry .get_decoder_dummy_data (model_config = self .model_config , seq_len = 1 ).multi_modal_data for i in range (num_reqs )]
575
+ # dummy_data = self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len =1)
576
+ # multi_modal_kwargs = MultiModalKwargs.batch([dummy_data.multi_modal_data])
577
+ multi_modal_kwargs = MultiModalKwargs .batch (dummy_data )
574
578
575
579
model_kwargs .update (multi_modal_kwargs )
576
580
577
- def _maybe_add_model_args (self , num_tokens : int ,
581
+ def _maybe_add_multimodal_kwargs (self ,
578
582
model_kwargs : dict [str ,Any ],
579
- scheduler_output : "SchedulerOutput" = None ):
580
-
581
- if self .supports_token_type_ids :
582
- model_kwargs ["token_type_ids" ] = \
583
- self .get_token_type_ids ()[:num_tokens ]
583
+ scheduler_output : "SchedulerOutput" = None ,
584
+ num_reqs : int = - 1 ):
584
585
585
586
if self .model_supports_multimodal_raw_input :
586
- self ._add_multimodal_inputs_to_model_args (model_kwargs , scheduler_output )
587
+ self ._add_multimodal_inputs_to_model_args (model_kwargs , scheduler_output , num_reqs )
587
588
588
589
def _maybe_compute_attn_prefix (
589
590
self ,
@@ -1364,15 +1365,15 @@ def execute_model(
1364
1365
mm_embeds = self ._gather_mm_embeddings (scheduler_output )
1365
1366
else :
1366
1367
mm_embeds = []
1367
-
1368
+
1369
+ model_kwargs : dict [str , Any ] = {}
1368
1370
if self .is_multimodal_model and get_pp_group ().is_first_rank :
1369
1371
# NOTE(woosuk): To unify token ids and soft tokens (vision
1370
1372
# embeddings), we always use embeddings (rather than token ids)
1371
1373
# as input to the multimodal model, even when the input is text.
1372
1374
input_ids = self .input_ids [:num_scheduled_tokens ]
1373
- self ._maybe_add_model_args (num_scheduled_tokens ,
1374
- model_kwargs , scheduler_output )
1375
-
1375
+ self ._maybe_add_multimodal_kwargs (model_kwargs = model_kwargs ,
1376
+ scheduler_output = scheduler_output )
1376
1377
if mm_embeds :
1377
1378
inputs_embeds = self .model .get_input_embeddings (
1378
1379
input_ids , mm_embeds )
@@ -1388,7 +1389,6 @@ def execute_model(
1388
1389
# multimodal models, it is not desirable for performance since
1389
1390
# then the embedding layer is not included in the CUDA graph.
1390
1391
input_ids = self .input_ids [:num_input_tokens ]
1391
- self ._maybe_add_model_args (num_input_tokens , model_kwargs , scheduler_output )
1392
1392
inputs_embeds = None
1393
1393
if self .uses_mrope :
1394
1394
positions = self .mrope_positions [:, :num_input_tokens ]
@@ -2076,8 +2076,9 @@ def _dummy_run(
2076
2076
num_scheduled_tokens ):
2077
2077
model = self .model
2078
2078
model_kwargs : dict [str , Any ] = {}
2079
- self ._maybe_add_model_args (num_tokens , model_kwargs )
2080
2079
if self .is_multimodal_model :
2080
+ self ._maybe_add_multimodal_kwargs (model_kwargs = model_kwargs ,
2081
+ num_reqs = num_reqs )
2081
2082
input_ids = None
2082
2083
inputs_embeds = self .inputs_embeds [:num_tokens ]
2083
2084
else :
0 commit comments