from vllm.inputs import INPUT_REGISTRY
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
+ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import get_model
- from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
+ from vllm.multimodal import MULTIMODAL_REGISTRY
+ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
+ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
+ from vllm.v1.worker.utils import (gather_mm_placeholders,
+                                   sanity_check_mm_encoder_outputs,
+                                   scatter_mm_placeholders)

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention import AttentionMaskBuilder
@@ -362,6 +368,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        # Remove finished requests from the cached states.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
+             self.encoder_cache.pop(req_id, None)
        # Remove the finished requests from the persistent batch.
        # NOTE(woosuk): There could be an edge case where finished_req_ids and
        # scheduled_req_ids overlap. This happens when a request is aborted and
@@ -374,6 +381,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
            if req_index is not None:
                removed_req_indices.append(req_index)

+         # Free the cached encoder outputs.
+         for req_id, input_id in scheduler_output.free_encoder_input_ids:
+             encoder_outputs = self.encoder_cache.get(req_id)
+             if encoder_outputs is not None:
+                 encoder_outputs.pop(input_id, None)
+                 if not encoder_outputs:
+                     self.encoder_cache.pop(req_id, None)
+
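For reference, self.encoder_cache maps a request id to a dict of per-input encoder outputs (req_id -> {input_id -> embeddings}); freeing the last input for a request drops the whole entry. A minimal standalone sketch of that behaviour, with toy data instead of real tensors (illustration only, not the runner code):

# Sketch of the two-level encoder cache and its freeing path.
encoder_cache: dict[str, dict[int, str]] = {
    "req-0": {0: "<image-0 embeds>", 1: "<image-1 embeds>"},
}

def free_encoder_input(req_id: str, input_id: int) -> None:
    outputs = encoder_cache.get(req_id)
    if outputs is not None:
        outputs.pop(input_id, None)
        if not outputs:  # last cached input gone -> drop the request entry
            encoder_cache.pop(req_id, None)

free_encoder_input("req-0", 0)   # req-0 still cached (input 1 remains)
free_encoder_input("req-0", 1)   # req-0 removed entirely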
        # Remove the unscheduled requests from the persistent batch.
        # NOTE(woosuk): The unscheduled requests are either preempted requests
        # or running requests that are not scheduled in this step. We remove
@@ -415,6 +430,43 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                lora_request=new_req_data.lora_request,
            )

+             # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+             if self.uses_mrope:
+                 image_grid_thw = []
+                 video_grid_thw = []
+                 second_per_grid_ts = []
+                 audio_feature_lengths = []
+                 use_audio_in_video = False
+                 for mm_input in self.requests[req_id].mm_inputs:
+                     if mm_input.get("image_grid_thw") is not None:
+                         image_grid_thw.extend(
+                             mm_input["image_grid_thw"].tolist())
+                     if mm_input.get("video_grid_thw") is not None:
+                         video_grid_thw.extend(
+                             mm_input["video_grid_thw"].tolist())
+                     if mm_input.get("second_per_grid_ts") is not None:
+                         second_per_grid_ts.extend(
+                             mm_input["second_per_grid_ts"])
+                     if mm_input.get("audio_feature_lengths") is not None:
+                         audio_feature_lengths.extend(
+                             mm_input["audio_feature_lengths"])
+                     if mm_input.get("use_audio_in_video") is True:
+                         use_audio_in_video = True
+
+                 hf_config = self.model_config.hf_config
+
+                 self.requests[req_id].mrope_positions, \
+                     self.requests[req_id].mrope_position_delta = \
+                     MRotaryEmbedding.get_input_positions_tensor(
+                         self.requests[req_id].prompt_token_ids,
+                         hf_config=hf_config,
+                         image_grid_thw=image_grid_thw,
+                         video_grid_thw=video_grid_thw,
+                         second_per_grid_ts=second_per_grid_ts,
+                         audio_feature_lengths=audio_feature_lengths,
+                         use_audio_in_video=use_audio_in_video,
+                     )
+
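For readers unfamiliar with M-RoPE: instead of one scalar position per token, each token carries three position components (temporal, height, width). Text tokens keep all three equal to the sequence index, while vision tokens keep a constant temporal index and walk height/width over the image grid. The following is a simplified toy sketch of that layout under those assumptions; it is not the actual MRotaryEmbedding.get_input_positions_tensor logic and ignores spatial merging, videos, and audio:

import torch

def toy_mrope_positions(num_text_before: int, grid_h: int, grid_w: int,
                        num_text_after: int) -> torch.Tensor:
    """Toy 3xN (t, h, w) position table for [text, one image, text]."""
    cols = []
    # Leading text: t == h == w == sequence index.
    for i in range(num_text_before):
        cols.append((i, i, i))
    # Image tokens: constant temporal index, h/w walk the grid.
    t0 = num_text_before
    for h in range(grid_h):
        for w in range(grid_w):
            cols.append((t0, t0 + h, t0 + w))
    # Trailing text resumes after the largest position used so far.
    nxt = max(max(c) for c in cols) + 1
    for i in range(num_text_after):
        cols.append((nxt + i, nxt + i, nxt + i))
    return torch.tensor(cols, dtype=torch.long).t()  # shape (3, N)

positions = toy_mrope_positions(3, 2, 2, 2)          # N = 3 + 4 + 2 = 9
# mrope_position_delta records (max position + 1) - prompt length, which lets
# decode steps keep producing consecutive positions without the grids.
delta = int(positions.max().item()) + 1 - positions.shape[1]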
            req_ids_to_add.append(req_id)

        # Update the states of the running/resumed requests.
@@ -535,6 +587,166 @@ def _make_attention_mask(self, seq_lens, query_lens, position,
        else:
            return None

+     def _calc_mrope_positions(self, scheduler_output: "SchedulerOutput"):
+         mrope_pos_ptr = 0
+         for index, req_id in enumerate(self.input_batch.req_ids):
+             req = self.requests[req_id]
+             assert req.mrope_positions is not None
+
+             num_computed_tokens = \
+                 self.input_batch.num_computed_tokens_cpu[index]
+             num_scheduled_tokens = \
+                 scheduler_output.num_scheduled_tokens[req_id]
+             num_prompt_tokens = len(req.prompt_token_ids)
+
+             if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+                 prompt_part_len = max(0,
+                                       num_prompt_tokens - num_computed_tokens)
+                 completion_part_len = max(
+                     0, num_scheduled_tokens - prompt_part_len)
+             else:
+                 prompt_part_len = num_scheduled_tokens
+                 completion_part_len = 0
+
+             assert num_scheduled_tokens == prompt_part_len + completion_part_len
+
+             if prompt_part_len > 0:
+                 # The prompt's mrope_positions are pre-computed.
+                 dst_start = mrope_pos_ptr
+                 dst_end = mrope_pos_ptr + prompt_part_len
+                 src_start = num_computed_tokens
+                 src_end = num_computed_tokens + prompt_part_len
+
+                 self.mrope_positions_cpu[:, dst_start:dst_end] = \
+                     req.mrope_positions[:, src_start:src_end]
+
+                 mrope_pos_ptr += prompt_part_len
+
+             if completion_part_len > 0:
+                 # Compute the completion's mrope_positions on the fly.
+                 dst_start = mrope_pos_ptr
+                 dst_end = mrope_pos_ptr + completion_part_len
+
+                 self.mrope_positions_cpu[:, dst_start:dst_end] = \
+                     MRotaryEmbedding.get_next_input_positions_tensor(
+                         req.mrope_position_delta,
+                         context_len=num_computed_tokens + prompt_part_len,
+                         seq_len=num_computed_tokens + prompt_part_len +
+                         completion_part_len,
+                     )
+
+                 mrope_pos_ptr += completion_part_len
+
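The prompt/completion split above covers the step in which chunked prefill crosses into decode. A small worked example of the arithmetic, mirroring the variable names in _calc_mrope_positions (plain Python, for illustration):

# A request with a 10-token prompt, 8 tokens already computed, and 5 tokens
# scheduled this step: the step finishes the prompt and starts decoding.
num_prompt_tokens = 10
num_computed_tokens = 8
num_scheduled_tokens = 5

if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
    prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)      # 2
    completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)   # 3
else:
    prompt_part_len = num_scheduled_tokens
    completion_part_len = 0

assert (prompt_part_len, completion_part_len) == (2, 3)
# The 2 prompt positions are copied from the precomputed table; the 3
# completion positions are generated on the fly from mrope_position_delta.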
+     def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
+         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+         if not scheduled_encoder_inputs:
+             return
+
+         # Batch the multi-modal inputs.
+         mm_inputs = list[MultiModalKwargs]()
+         req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
+         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
+             req_state = self.requests[req_id]
+
+             for mm_input_id in encoder_input_ids:
+                 mm_inputs.append(req_state.mm_inputs[mm_input_id])
+                 req_ids_pos.append(
+                     (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
+
+         # Batch mm inputs as much as we can: if a request in the batch has
+         # multiple modalities or a different modality than the previous one,
+         # we process it separately to preserve item order.
+         # FIXME(ywang96): This is a hacky way to deal with multiple modalities
+         # in the same batch while still being able to benefit from batching
+         # multimodal inputs. The proper solution should be reordering the
+         # encoder outputs.
+         grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
+
+         encoder_outputs = []
+         for grouped_mm_inputs in grouped_mm_inputs_list:
+             batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
+             batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
+                                                            device=self.device)
+
+             # Run the encoder.
+             # `curr_group_outputs` is either of the following:
+             # 1. A tensor of shape (num_items, feature_size, hidden_size)
+             #    in case feature_size is fixed across all multimodal items.
+             # 2. A list or tuple (length: num_items) of tensors, each of shape
+             #    (feature_size, hidden_size) in case the feature size is
+             #    dynamic depending on the input multimodal items.
+             curr_group_outputs = self.model.get_multimodal_embeddings(
+                 **batched_mm_inputs)
+
+             sanity_check_mm_encoder_outputs(
+                 curr_group_outputs,
+                 expected_num_items=len(grouped_mm_inputs),
+             )
+
+             for output in curr_group_outputs:
+                 encoder_outputs.append(output)
+
+         # Cache the encoder outputs.
+         for (req_id, input_id, pos_info), output in zip(
+                 req_ids_pos,
+                 encoder_outputs,
+         ):
+             if req_id not in self.encoder_cache:
+                 self.encoder_cache[req_id] = {}
+
+             self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
+                 output,
+                 is_embed=pos_info.is_embed,
+             )
+
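The grouping step above batches consecutive inputs of the same modality so that item order is preserved across the encoder runs. A standalone sketch of that idea using itertools.groupby, under the assumption that grouping is by consecutive modality (an illustration, not the group_mm_inputs_by_modality implementation):

from itertools import groupby

# Scheduled multimodal items in request order; only the modality matters here.
scheduled = [
    {"modality": "image", "id": 0},
    {"modality": "image", "id": 1},
    {"modality": "audio", "id": 2},
    {"modality": "image", "id": 3},
]

# Consecutive runs of the same modality can share one encoder batch; a change
# of modality starts a new group so outputs stay in the original order.
groups = [list(items)
          for _, items in groupby(scheduled, key=lambda x: x["modality"])]
assert [[d["id"] for d in g] for g in groups] == [[0, 1], [2], [3]]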
+     def _gather_mm_embeddings(
+         self,
+         scheduler_output: "SchedulerOutput",
+     ) -> list[torch.Tensor]:
+         mm_embeds: list[torch.Tensor] = []
+         for req_id in self.input_batch.req_ids:
+             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
+                 req_id]
+             req_state = self.requests[req_id]
+             num_computed_tokens = req_state.num_computed_tokens
+             mm_positions = req_state.mm_positions
+             for i, pos_info in enumerate(mm_positions):
+                 start_pos = pos_info.offset
+                 num_encoder_tokens = pos_info.length
+
+                 # The encoder output is needed if the two ranges overlap:
+                 # [num_computed_tokens,
+                 #  num_computed_tokens + num_scheduled_tokens) and
+                 # [start_pos, start_pos + num_encoder_tokens)
+                 if start_pos >= num_computed_tokens + num_scheduled_tokens:
+                     # The encoder output is not needed in this step.
+                     break
+                 if start_pos + num_encoder_tokens <= num_computed_tokens:
+                     # The encoder output is already processed and stored
+                     # in the decoder's KV cache.
+                     continue
+
+                 start_idx = max(num_computed_tokens - start_pos, 0)
+                 end_idx = min(
+                     num_computed_tokens - start_pos + num_scheduled_tokens,
+                     num_encoder_tokens)
+                 assert start_idx < end_idx
+                 assert req_id in self.encoder_cache
+                 assert i in self.encoder_cache[req_id]
+                 encoder_output = self.encoder_cache[req_id][i]
+
+                 if (is_embed := pos_info.is_embed) is not None:
+                     is_embed = is_embed[start_idx:end_idx]
+
+                 mm_embeds_item = gather_mm_placeholders(
+                     encoder_output[start_idx:end_idx],
+                     is_embed=is_embed,
+                 )
+                 mm_embeds.append(mm_embeds_item)
+         return mm_embeds
+
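The start/end indices above select only the slice of an encoder output that falls inside this step's token window. A worked example with concrete numbers (plain Python, for illustration):

# An image placeholder occupies prompt positions [12, 12 + 20); this step
# computes prompt tokens [16, 16 + 10).
start_pos, num_encoder_tokens = 12, 20
num_computed_tokens, num_scheduled_tokens = 16, 10

start_idx = max(num_computed_tokens - start_pos, 0)                       # 4
end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens,
              num_encoder_tokens)                                         # 14
assert (start_idx, end_idx) == (4, 14)
# Only encoder_output[4:14] is gathered now; [0:4] was consumed in earlier
# steps and [14:20] will be consumed in a later one.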
    def _process_reqs(
        self,
        scheduler_output: "SchedulerOutput",
@@ -594,6 +806,17 @@ def _process_reqs(
               arange,
               out=positions_np)

+         # Calculate M-RoPE positions.
+         # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+         if self.uses_mrope:
+             self._calc_mrope_positions(scheduler_output)
+
+         if self.uses_mrope:
+             # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+             self.mrope_positions[:, :total_num_scheduled_tokens].copy_(
+                 self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
+                 non_blocking=True)
+
        self.positions[:total_num_scheduled_tokens].copy_(
            self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True)
        positions = self.positions[:num_input_tokens]
@@ -706,6 +929,43 @@ def _process_reqs(
            input_ids = self.input_ids[:padded_batch_size]
            positions = self.positions[:padded_batch_size]

+         # Prepare M-RoPE positions and input embeddings for multimodal models.
+         num_input_tokens = total_num_scheduled_tokens
+         # _prepare_inputs may reorder the batch, so we must gather the
+         # multi-modal outputs after that to ensure the correct order.
+         if self.is_multimodal_model:
+             # Run the multimodal encoder if any.
+             self._execute_mm_encoder(scheduler_output)
+             mm_embeds = self._gather_mm_embeddings(scheduler_output)
+         else:
+             mm_embeds = []
+
+         if self.is_multimodal_model:
+             # NOTE(woosuk): To unify token ids and soft tokens (vision
+             # embeddings), we always use embeddings (rather than token ids)
+             # as input to the multimodal model, even when the input is text.
+             input_ids = self.input_ids[:num_input_tokens]
+             if mm_embeds:
+                 inputs_embeds = self.model.get_input_embeddings(
+                     input_ids, mm_embeds)
+             else:
+                 inputs_embeds = self.model.get_input_embeddings(input_ids)
+             # TODO(woosuk): Avoid the copy. Optimize.
+             self.inputs_embeds[:num_input_tokens].copy_(inputs_embeds)
+             inputs_embeds = self.inputs_embeds[:num_input_tokens]
+             input_ids = None
+         else:
+             # For text-only models, we use token ids as input.
+             # While it is possible to use embeddings as input just like the
+             # multimodal models, it is not desirable for performance since
+             # then the embedding layer is not included in the CUDA graph.
+             input_ids = self.input_ids[:num_input_tokens]
+             inputs_embeds = None
+         if self.uses_mrope:
+             positions = self.mrope_positions[:, :num_input_tokens]
+         else:
+             positions = self.positions[:num_input_tokens]
+
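Conceptually, get_input_embeddings(input_ids, mm_embeds) embeds the token ids and then overwrites the placeholder positions with the encoder embeddings, so the model always consumes a single inputs_embeds tensor. A rough sketch of that merge, assuming placeholders are marked by a boolean mask (simplified illustration; the real per-model hooks differ):

import torch

def merge_embeddings(token_embeds: torch.Tensor,
                     mm_embeds: torch.Tensor,
                     is_placeholder: torch.Tensor) -> torch.Tensor:
    """token_embeds: (seq_len, hidden); mm_embeds: (num_mm_tokens, hidden);
    is_placeholder: (seq_len,) bool mask with num_mm_tokens True entries."""
    merged = token_embeds.clone()
    merged[is_placeholder] = mm_embeds  # scatter vision features in place
    return merged

seq_len, hidden = 8, 4
token_embeds = torch.zeros(seq_len, hidden)
mm_embeds = torch.ones(3, hidden)
mask = torch.tensor([0, 0, 1, 1, 1, 0, 0, 0], dtype=torch.bool)
out = merge_embeddings(token_embeds, mm_embeds, mask)
assert out[mask].eq(1).all() and out[~mask].eq(0).all()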
        # Run forward pass
        with set_forward_context(attn_metadata,
                                 self.vllm_config,
@@ -722,7 +982,7 @@ def _process_reqs(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
-                     inputs_embeds=None,
+                     inputs_embeds=inputs_embeds,
                    **model_kwargs,
                )
            else:
@@ -731,7 +991,7 @@ def _process_reqs(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=intermediate_tensors,
-                     inputs_embeds=None,
+                     inputs_embeds=inputs_embeds,
                    **model_kwargs,
                )

@@ -1214,8 +1474,11 @@ def _dummy_run(
        return hidden_states

    def profile_run(self) -> None:
-         # Profile with multimodal encoder & encoder cache.
-         self._profile_multimodal()
+         # FIXME: Profile with the multimodal encoder & encoder cache.
+         # The current _profile_multimodal() relies on the PyTorch SDPA backend,
+         # which does not support windowed/full attention to reduce memcpy
+         # operations and can therefore run out of memory, so it is disabled
+         # for now.
+         # self._profile_multimodal()

        # For profile, have maximum num_reqs and that collectively have
        # maximum num_tokens.