@@ -122,6 +122,7 @@ def __init__(
             cache_config.cache_dtype]
 
         self.is_multimodal_model = model_config.is_multimodal_model
+        self.is_pooling_model = model_config.is_pooling_model
         self.model_supports_multimodal_raw_input = model_config.model_supports_multimodal_raw_input
         self.max_model_len = model_config.max_model_len
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
@@ -550,10 +551,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         batch_reordered = self._may_reorder_batch(scheduler_output)
 
         if batch_changed or batch_reordered:
-            self.input_batch.refresh()
+            self.input_batch.refresh_sampling_metadata()
 
     def _add_multimodal_inputs_to_model_args(self, model_kwargs: dict[str, Any],
-                                             scheduler_output: "SchedulerOutput"):
+                                             scheduler_output: "SchedulerOutput",
+                                             num_reqs: int = -1):
         # Multi-modal data.
         if scheduler_output:
             multi_modal_kwargs_list = []
@@ -565,21 +567,21 @@ def _add_multimodal_inputs_to_model_args(self, model_kwargs: dict[str, Any],
             multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
         else:
             # The only case where SchedulerOutput is None is a dummy run, so get some dummy data.
-            dummy_data = self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len=1)
-            multi_modal_kwargs = MultiModalKwargs.batch([dummy_data.multi_modal_data])
+            dummy_data = [
+                self.mm_registry.get_decoder_dummy_data(
+                    model_config=self.model_config, seq_len=1).multi_modal_data
+                for _ in range(num_reqs)]
+            multi_modal_kwargs = MultiModalKwargs.batch(dummy_data)
 
         model_kwargs.update(multi_modal_kwargs)
 
-    def _maybe_add_model_args(self, num_tokens: int,
+    def _maybe_add_multimodal_kwargs(self,
                               model_kwargs: dict[str, Any],
-                              scheduler_output: "SchedulerOutput" = None):
-
-        if self.supports_token_type_ids:
-            model_kwargs["token_type_ids"] = \
-                self.get_token_type_ids()[:num_tokens]
+                              scheduler_output: "SchedulerOutput" = None,
+                              num_reqs: int = -1):
 
         if self.model_supports_multimodal_raw_input:
-            self._add_multimodal_inputs_to_model_args(model_kwargs, scheduler_output)
+            self._add_multimodal_inputs_to_model_args(model_kwargs, scheduler_output, num_reqs)
 
     def _maybe_compute_attn_prefix(
         self,
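Review note: the new `num_reqs` parameter exists so that a dummy run can batch one dummy multimodal item per request instead of a single item. Below is a minimal runnable sketch of that shape; the registry and `MultiModalKwargs` are stubbed out, so everything here except the names taken from the hunk above is hypothetical:

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class _DummyDecoderData:
    """Hypothetical stand-in for the registry's dummy-data object."""
    multi_modal_data: dict[str, Any]


def _get_decoder_dummy_data(seq_len: int = 1) -> _DummyDecoderData:
    # Stub for self.mm_registry.get_decoder_dummy_data(...).
    return _DummyDecoderData(multi_modal_data={"pixel_values": [0.0] * seq_len})


def build_dummy_mm_batch(num_reqs: int) -> list[dict[str, Any]]:
    # Same shape as the new code path: one dummy item per request,
    # then a single batch call over the whole list.
    return [_get_decoder_dummy_data(seq_len=1).multi_modal_data
            for _ in range(num_reqs)]


batch = build_dummy_mm_batch(4)  # real code: MultiModalKwargs.batch(batch)
print(len(batch))  # -> 4, one entry per dummy request
```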
@@ -1344,15 +1345,15 @@ def execute_model(
             mm_embeds = self._gather_mm_embeddings(scheduler_output)
         else:
             mm_embeds = []
-
+
+        model_kwargs: dict[str, Any] = {}
         if self.is_multimodal_model and get_pp_group().is_first_rank:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
             input_ids = self.input_ids[:num_scheduled_tokens]
-            self._maybe_add_model_args(num_scheduled_tokens,
-                                       model_kwargs, scheduler_output)
-
+            self._maybe_add_multimodal_kwargs(model_kwargs=model_kwargs,
+                                              scheduler_output=scheduler_output)
             if mm_embeds:
                 inputs_embeds = self.model.get_input_embeddings(
                     input_ids, mm_embeds)
@@ -1368,7 +1369,6 @@ def execute_model(
             # multimodal models, it is not desirable for performance since
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids[:num_input_tokens]
-            self._maybe_add_model_args(num_input_tokens, model_kwargs, scheduler_output)
             inputs_embeds = None
         if self.uses_mrope:
             positions = self.mrope_positions[:, :num_input_tokens]
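Review note: taken together, these two `execute_model` hunks create `model_kwargs` once before the branch and fill it only on the multimodal first-rank path; the text-only path no longer calls the old `_maybe_add_model_args`, and its `token_type_ids` handling is gone entirely. A condensed, runnable toy of that control flow follows; only the branch structure mirrors the diff, everything else is a stub:

```python
from typing import Any


def _maybe_add_multimodal_kwargs(model_kwargs: dict[str, Any],
                                 scheduler_output: Any = None,
                                 num_reqs: int = -1) -> None:
    # Stub for the renamed helper: mutates model_kwargs in place.
    model_kwargs["mm_inputs"] = scheduler_output or ["dummy"] * num_reqs


def execute_model_sketch(is_multimodal: bool, is_first_rank: bool,
                         scheduler_output: Any) -> dict[str, Any]:
    model_kwargs: dict[str, Any] = {}  # hoisted above the branch
    if is_multimodal and is_first_rank:
        _maybe_add_multimodal_kwargs(model_kwargs=model_kwargs,
                                     scheduler_output=scheduler_output)
    # else: the text-only path leaves model_kwargs empty; the old
    # _maybe_add_model_args call was dropped from this branch.
    return model_kwargs


print(execute_model_sketch(True, True, {"step": 3}))   # populated
print(execute_model_sketch(False, True, {"step": 3}))  # stays empty
```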
@@ -1994,8 +1994,9 @@ def _dummy_run(
                 num_scheduled_tokens):
             model = self.model
             model_kwargs: dict[str, Any] = {}
-            self._maybe_add_model_args(num_tokens, model_kwargs)
             if self.is_multimodal_model:
+                self._maybe_add_multimodal_kwargs(model_kwargs=model_kwargs,
+                                                  num_reqs=num_reqs)
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
             else:
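Review note: `_dummy_run` has no `SchedulerOutput`, so it reaches the dummy path by passing `num_reqs` alone, while `execute_model` passes `scheduler_output` and leaves `num_reqs` at its default. A small sketch of the dispatch between the two call conventions; the helper body here is a hypothetical stand-in, only the parameter shapes come from the diff:

```python
from typing import Any, Optional


def add_multimodal_inputs(model_kwargs: dict[str, Any],
                          scheduler_output: Optional[dict] = None,
                          num_reqs: int = -1) -> None:
    # Mirrors the branch in _add_multimodal_inputs_to_model_args:
    # a real SchedulerOutput wins; otherwise fall back to dummy data.
    if scheduler_output:
        mm = scheduler_output["mm_kwargs"]  # stub for the real batching
    else:
        mm = [{"pixel_values": [0.0]} for _ in range(num_reqs)]
    model_kwargs.update({"multi_modal": mm})


# execute_model-style call: real scheduler output, default num_reqs.
kw: dict[str, Any] = {}
add_multimodal_inputs(kw, scheduler_output={"mm_kwargs": ["img0", "img1"]})
print(kw)

# _dummy_run-style call: no scheduler output, explicit num_reqs.
kw = {}
add_multimodal_inputs(kw, num_reqs=2)
print(kw)
```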