diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py
index 567c448a8c9..e36d01c249f 100644
--- a/examples/offline_inference/prithvi_geospatial_mae.py
+++ b/examples/offline_inference/prithvi_geospatial_mae.py
@@ -143,7 +143,8 @@ def __init__(self):
         self.model = LLM(
             model=os.path.join(os.path.dirname(__file__), "./model"),
             skip_tokenizer_init=True,
-            dtype="float32",
+            dtype="float16",
+            enforce_eager=True,
         )
 
     def run(self, input_data, location_coords):
diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py
new file mode 100644
index 00000000000..8179aa064ef
--- /dev/null
+++ b/tests/models/multimodal/pooling/test_prithvi_mae.py
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from ....conftest import VllmRunner
+
+
+def generate_test_mm_data():
+    mm_data = {
+        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
+        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
+    }
+    return mm_data
+
+
+def _run_test(
+    vllm_runner: type[VllmRunner],
+    model: str,
+) -> None:
+
+    mm_data = generate_test_mm_data()
+    prompt = {
+        # This model takes no text input
+        "prompt_token_ids": [1],
+        "multi_modal_data": mm_data
+    }
+    with vllm_runner(model,
+                     task="embed",
+                     dtype=torch.float16,
+                     enforce_eager=True,
+                     skip_tokenizer_init=True) as vllm_model:
+        vllm_model.encode(prompt)
+
+
+MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+def test_models_image(
+    hf_runner,
+    vllm_runner,
+    image_assets,
+    model: str,
+) -> None:
+    _run_test(
+        vllm_runner,
+        model,
+    )
diff --git a/vllm/config.py b/vllm/config.py
index 226a1014fa7..f799e971431 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -612,6 +612,8 @@ def __post_init__(self) -> None:
         self.served_model_name = get_served_model_name(self.model,
                                                        self.served_model_name)
         self.multimodal_config = self._init_multimodal_config()
+        self.model_supports_multimodal_raw_input = (
+            self._init_model_supports_multimodal_raw_input())
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
@@ -715,6 +717,9 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
 
         return None
 
+    def _init_model_supports_multimodal_raw_input(self):
+        return self.registry.supports_multimodal_raw_input(self.architectures)
+
     def _get_encoder_config(self):
         return get_sentence_transformer_tokenizer_config(
             self.model, self.revision)
@@ -1120,10 +1125,10 @@ def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
         return self.get_hf_config_sliding_window()
 
     def get_vocab_size(self) -> int:
-        return self.hf_text_config.vocab_size
+        return getattr(self.hf_text_config, "vocab_size", 0)
 
     def get_hidden_size(self) -> int:
-        return self.hf_text_config.hidden_size
+        return getattr(self.hf_text_config, "hidden_size", 0)
 
     @property
     def is_deepseek_mla(self) -> bool:
@@ -1417,6 +1422,10 @@ def uses_mrope(self) -> bool:
     @property
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None
+
+    @property
+    def is_pooling_model(self) -> bool:
+        return self.registry.is_pooling_model(self.architectures)
 
     @property
     def is_cross_encoder(self) -> bool:
diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py
index a018bd5d09d..d6b96f2893b 100644
--- a/vllm/model_executor/models/interfaces.py
+++ b/vllm/model_executor/models/interfaces.py
@@ -130,6 +130,48 @@ def supports_multimodal(
     return isinstance(model, SupportsMultiModal)
 
 
+@runtime_checkable
+class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
+    """The interface required for multi-modal models that consume
+    multi-modal inputs in their raw form rather than as embeddings."""
+
+    supports_multimodal_raw_input: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports multi-modal inputs and processes
+    them in their raw form and not as embeddings.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+
+@runtime_checkable
+class _SupportsMultiModalWithRawInput(Protocol):
+    supports_multimodal_raw_input: ClassVar[Literal[True]]
+
+
+@overload
+def supports_multimodal_raw_input(
+        model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
+    ...
+
+
+@overload
+def supports_multimodal_raw_input(
+        model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
+    ...
+
+
+def supports_multimodal_raw_input(
+    model: Union[type[object], object]
+) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
+           TypeIs[SupportsMultiModalWithRawInput]]:
+    if isinstance(model, type):
+        return isinstance(model, _SupportsMultiModalWithRawInput)
+
+    return isinstance(model, SupportsMultiModalWithRawInput)
+
+
 @runtime_checkable
 class SupportsLoRA(Protocol):
     """The interface required for all models that support LoRA."""
diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py
index a36f24bc80e..cc34721cf41 100644
--- a/vllm/model_executor/models/prithvi_geospatial_mae.py
+++ b/vllm/model_executor/models/prithvi_geospatial_mae.py
@@ -25,14 +25,15 @@
 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import (IsAttentionFree,
-                                                   SupportsMultiModal,
-                                                   SupportsV0Only)
+from vllm.model_executor.models.interfaces import (
+    IsAttentionFree, SupportsMultiModalWithRawInput)
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs)
+                                    MultiModalFieldElem, MultiModalInputs,
+                                    MultiModalKwargs, MultiModalKwargsItem,
+                                    MultiModalSharedField, PlaceholderRange)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptUpdate)
@@ -62,8 +63,9 @@ def get_dummy_mm_data(
         # The size of pixel_values might change in the cases where we resize
         # the input but never exceeds the dimensions below.
         return {
-            "pixel_values": torch.full((1, 6, 512, 512), 1.0),
-            "location_coords": torch.full((1, 2), 1.0),
+            "pixel_values": torch.full((6, 512, 512), 1.0,
+                                       dtype=torch.float16),
+            "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
         }
 
@@ -75,8 +77,10 @@ def _get_mm_fields_config(
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            location_coords=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.shared(batch_size=1,
+                                                      modality="image"),
+            location_coords=MultiModalFieldConfig.shared(batch_size=1,
+                                                         modality="image"),
         )
 
     def _get_prompt_updates(
@@ -99,14 +103,25 @@ def apply(
         for k, v in mm_data.items():
             mm_kwargs[k] = v
 
+        mm_place_holders = {"image": [PlaceholderRange(offset=0, length=0)]}
+
+        multimodal_kwargs_items = [
+            MultiModalKwargsItem.from_elems([
+                MultiModalFieldElem(modality="image",
+                                    key=key,
+                                    data=data,
+                                    field=MultiModalSharedField(1))
+                for key, data in mm_kwargs.items()
+            ])
+        ]
+
         return MultiModalInputs(
             type="multimodal",
             prompt=prompt,
             prompt_token_ids=[1],
-            mm_kwargs=MultiModalKwargs(mm_kwargs),
+            mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
             mm_hashes=None,
-            mm_placeholders={},
+            mm_placeholders=mm_place_holders,
         )
 
@@ -114,8 +129,8 @@ def apply(
     PrithviGeoSpatialMAEMultiModalProcessor,
     info=PrithviGeoSpatialMAEProcessingInfo,
     dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
-                           SupportsV0Only):
+class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
+                           SupportsMultiModalWithRawInput):
     """ Prithvi Masked Autoencoder"""
 
     @classmethod
@@ -169,7 +184,7 @@ def _parse_and_validate_multimodal_data(
         if not isinstance(pixel_values, torch.Tensor):
             raise ValueError(f"Incorrect type of pixel_values. "
                              f"Got type: {type(pixel_values)}")
-        pixel_values = torch.unbind(pixel_values, dim=0)[0]
+        # pixel_values = torch.unbind(pixel_values, dim=0)[0]
 
         location_coords = kwargs.pop("location_coords", None)
         if not isinstance(location_coords, torch.Tensor):
@@ -181,6 +196,13 @@ def _parse_and_validate_multimodal_data(
 
         return pixel_values, location_coords
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        # This model does not consume input tokens, so there are no real
+        # embeddings to compute. However, the prompt must carry at least
+        # one (dummy) token id, and the dummy embedding tensor returned
+        # here has to match that size.
+        return torch.empty(input_ids.shape)
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -202,7 +224,10 @@ def pooler(
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Optional[PoolerOutput]:
-        return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
+        return PoolerOutput([
+            PoolingSequenceGroupOutput(hidden_state)
+            for hidden_state in hidden_states
+        ])
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index b100fe77e37..20ead20f0f9 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -23,8 +23,8 @@
 from .interfaces import (has_inner_state, has_noops, is_attention_free,
                          is_hybrid, supports_cross_encoding,
-                         supports_multimodal, supports_pp,
-                         supports_transcription, supports_v0_only)
+                         supports_multimodal, supports_multimodal_raw_input,
+                         supports_pp, supports_transcription, supports_v0_only)
 from .interfaces_base import is_text_generation_model
 
 logger = init_logger(__name__)
@@ -275,6 +275,7 @@ class _ModelInfo:
     is_pooling_model: bool
     supports_cross_encoding: bool
     supports_multimodal: bool
+    supports_multimodal_raw_input: bool
     supports_pp: bool
     has_inner_state: bool
     is_attention_free: bool
@@ -291,6 +292,7 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
             is_pooling_model=True,  # Can convert any model into a pooling model
             supports_cross_encoding=supports_cross_encoding(model),
             supports_multimodal=supports_multimodal(model),
+            supports_multimodal_raw_input=supports_multimodal_raw_input(model),
             supports_pp=supports_pp(model),
             has_inner_state=has_inner_state(model),
             is_attention_free=is_attention_free(model),
@@ -528,6 +530,13 @@ def is_multimodal_model(
         model_cls, _ = self.inspect_model_cls(architectures)
         return model_cls.supports_multimodal
 
+    def supports_multimodal_raw_input(
+        self,
+        architectures: Union[str, list[str]],
+    ) -> bool:
+        model_cls, _ = self.inspect_model_cls(architectures)
+        return model_cls.supports_multimodal_raw_input
+
     def is_pp_supported_model(
         self,
         architectures: Union[str, list[str]],
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index 27aaa661c35..c44fcacd246 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -266,7 +266,7 @@ def create_processor(
         if not model_config.is_multimodal_model:
             raise ValueError(f"{model_config.model} is not a multimodal model")
 
-        if tokenizer is None:
+        if tokenizer is None and not model_config.skip_tokenizer_init:
             tokenizer = cached_tokenizer_from_config(model_config)
         if disable_cache is None:
             mm_config = model_config.get_multimodal_config()
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 08bb0efb2f3..18b74f016b1 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import Optional
@@ -64,7 +65,131 @@ def new_empty(self) -> "KVCacheBlocks":
         return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))))
 
 
-class KVCacheManager:
+class KVCacheManagerInterface(ABC):
+
+    @property
+    @abstractmethod
+    def usage(self) -> float:
+        raise NotImplementedError
+
+    @abstractmethod
+    def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_computed_blocks(self,
+                            request: Request) -> tuple[KVCacheBlocks, int]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def allocate_slots(
+        self,
+        request: Request,
+        num_new_tokens: int,
+        num_new_computed_tokens: int = 0,
+        new_computed_blocks: Optional[KVCacheBlocks] = None,
+        num_draft_tokens: int = 0,
+        num_lookahead_tokens: int = 0,
+        delay_cache_blocks: bool = False,
+    ) -> Optional[KVCacheBlocks]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def free(self, request: Request) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def reset_prefix_cache(self) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_num_common_prefix_blocks(
+        self,
+        request: Request,
+        num_running_requests: int,
+    ) -> list[int]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def free_block_hashes(self, request: Request) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def take_events(self) -> list[KVCacheEvent]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def create_empty_block_list(self) -> KVCacheBlocks:
+        raise NotImplementedError
+
+
+class DummyKVCacheManager(KVCacheManagerInterface):
+
+    @property
+    def usage(self) -> float:
+        return 0.0
+
+    def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
+        return None
+
+    def get_computed_blocks(self,
+                            request: Request) -> tuple[KVCacheBlocks, int]:
+        return (KVCacheBlocks([]), 0)
+
+    def allocate_slots(
+        self,
+        request: Request,
+        num_new_tokens: int,
+        num_new_computed_tokens: int = 0,
+        new_computed_blocks: Optional[KVCacheBlocks] = None,
+        num_draft_tokens: int = 0,
+        num_lookahead_tokens: int = 0,
+        delay_cache_blocks: bool = False,
+    ) -> Optional[KVCacheBlocks]:
+        # If we do not return a KV cache block, requests are unschedulable.
+        return KVCacheBlocks(tuple([KVCacheBlock(block_id=0)]))
+
+    def free(self, request: Request) -> None:
+        pass
+
+    def reset_prefix_cache(self) -> bool:
+        return True
+
+    def get_num_common_prefix_blocks(
+        self,
+        request: Request,
+        num_running_requests: int,
+    ) -> list[int]:
+        return []
+
+    def free_block_hashes(self, request: Request) -> None:
+        pass
+
+    def take_events(self) -> list[KVCacheEvent]:
+        return []
+
+    def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
+        """Get the block ids of a request."""
+        return tuple([])
+
+    def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
+        """Cache the blocks for the request, if enabled."""
+        pass
+
+    def create_empty_block_list(self) -> KVCacheBlocks:
+        """Creates a new KVCacheBlocks instance with no blocks."""
+        return KVCacheBlocks(tuple([]))
+
+
+class KVCacheManager(KVCacheManagerInterface):
 
     def __init__(
         self,
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index fe552db74e2..7f752e34698 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -19,7 +19,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                 compute_encoder_budget)
-from vllm.v1.core.kv_cache_manager import KVCacheManager
+from vllm.v1.core.kv_cache_manager import DummyKVCacheManager, KVCacheManager
 from vllm.v1.core.sched.interface import SchedulerInterface
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
@@ -92,7 +92,8 @@ def __init__(
         )
 
         num_gpu_blocks = self.cache_config.num_gpu_blocks
-        assert num_gpu_blocks is not None and num_gpu_blocks > 0
+        # num_gpu_blocks can be zero for attention-free models.
+        assert num_gpu_blocks is not None
 
         self.block_size = self.cache_config.block_size
 
@@ -151,15 +152,18 @@ def __init__(
         self.num_lookahead_tokens = self.num_spec_tokens
 
         # Create the KV cache manager.
-        self.kv_cache_manager = KVCacheManager(
-            kv_cache_config=kv_cache_config,
-            max_model_len=self.max_model_len,
-            enable_caching=self.cache_config.enable_prefix_caching,
-            caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
-            use_eagle=self.use_eagle,
-            log_stats=self.log_stats,
-            enable_kv_cache_events=self.enable_kv_cache_events,
-        )
+        if self.cache_config.is_attention_free:
+            self.kv_cache_manager = DummyKVCacheManager()
+        else:
+            self.kv_cache_manager = KVCacheManager(
+                kv_cache_config=kv_cache_config,
+                max_model_len=self.max_model_len,
+                enable_caching=self.cache_config.enable_prefix_caching,
+                caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
+                use_eagle=self.use_eagle,
+                log_stats=self.log_stats,
+                enable_kv_cache_events=self.enable_kv_cache_events,
+            )
 
         self.use_pp = self.parallel_config.pipeline_parallel_size > 1
 
     def schedule(self) -> SchedulerOutput:
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index e2fdf6f8a11..0901eef9234 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -134,26 +134,37 @@ def _initialize_kv_caches(
             self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
         start = time.time()
 
-        # Get all kv cache needed by the model
-        kv_cache_specs = self.model_executor.get_kv_cache_specs()
-
-        # Profiles the peak memory usage of the model to determine how much
-        # memory can be allocated for kv cache.
-        available_gpu_memory = self.model_executor.determine_available_memory()
-
-        assert len(kv_cache_specs) == len(available_gpu_memory)
-        # Get the kv cache tensor size
-        kv_cache_configs = [
-            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
-                                available_gpu_memory_one_worker)
-            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
-            zip(kv_cache_specs, available_gpu_memory)
-        ]
-
-        # Since we use a shared centralized controller, we need the
-        # `kv_cache_config` to be consistent across all workers to make sure
-        # all the memory operators can be applied to all workers.
-        unify_kv_cache_configs(kv_cache_configs)
+        if vllm_config.model_config.is_attention_free:
+            # No need to initialize anything related to the KV cache if the
+            # model is attention-free.
+            kv_cache_specs = []
+            kv_cache_configs = [
+                KVCacheConfig(num_blocks=0,
+                              kv_cache_tensors=[],
+                              kv_cache_groups=[])
+            ]
+        else:
+            # Get all kv cache needed by the model
+            kv_cache_specs = self.model_executor.get_kv_cache_specs()
+
+            # Profiles the peak memory usage of the model to determine how
+            # much memory can be allocated for kv cache.
+            available_gpu_memory = (
+                self.model_executor.determine_available_memory())
+
+            assert len(kv_cache_specs) == len(available_gpu_memory)
+            # Get the kv cache tensor size
+            kv_cache_configs = [
+                get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
+                                    available_gpu_memory_one_worker)
+                for kv_cache_spec_one_worker, available_gpu_memory_one_worker
+                in zip(kv_cache_specs, available_gpu_memory)
+            ]
+
+            # Since we use a shared centralized controller, we need the
+            # `kv_cache_config` to be consistent across all workers to make
+            # sure all the memory operators can be applied to all workers.
+            unify_kv_cache_configs(kv_cache_configs)
 
         # All workers have the same kv_cache_config except layer names, so use
         # an arbitrary one to initialize the scheduler.
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index a2328c37ba0..e13a544b50f 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -82,11 +82,14 @@ def __init__(
             self.dp_group = None
         self.should_execute_dummy_batch = False
 
-        # Tokenizer (+ ensure liveness if running in another process).
-        self.tokenizer = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            lora_config=vllm_config.lora_config)
+        if not self.vllm_config.model_config.skip_tokenizer_init:
+            # Tokenizer (+ ensure liveness if running in another process).
+            self.tokenizer = init_tokenizer_from_configs(
+                model_config=vllm_config.model_config,
+                scheduler_config=vllm_config.scheduler_config,
+                lora_config=vllm_config.lora_config)
+        else:
+            self.tokenizer = None
 
         # Processor (convert Inputs --> EngineCoreRequests)
         self.processor = Processor(vllm_config=vllm_config,
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 2bcd61d1f0a..5d8e15f8bbb 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -327,14 +327,16 @@ def add_request(
         if request_id in self.request_states:
             raise ValueError(f"Request id {request_id} already running.")
 
-        req_state = RequestState.from_new_request(
-            tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
-            request=request,
-            prompt=prompt,
-            parent_req=parent_req,
-            request_index=request_index,
-            queue=queue,
-            log_stats=self.log_stats)
+        tokenizer = None if not self.tokenizer else \
+            self.tokenizer.get_lora_tokenizer(request.lora_request)
+
+        req_state = RequestState.from_new_request(tokenizer=tokenizer,
+                                                  request=request,
+                                                  prompt=prompt,
+                                                  parent_req=parent_req,
+                                                  request_index=request_index,
+                                                  queue=queue,
+                                                  log_stats=self.log_stats)
         self.request_states[request_id] = req_state
         self.lora_states.add_request(req_state)
         if parent_req:
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 9fc52543efd..26d767ac026 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -375,7 +375,10 @@ def _validate_model_input(
         prompt_type: Literal["encoder", "decoder"],
     ):
         model_config = self.model_config
-        tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
+        if model_config.skip_tokenizer_init:
+            tokenizer = None
+        else:
+            tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
 
         prompt_ids = prompt_inputs["prompt_token_ids"]
         if not prompt_ids:
@@ -384,9 +387,12 @@ def _validate_model_input(
             else:
                 raise ValueError(f"The {prompt_type} prompt cannot be empty")
 
-        max_input_id = max(prompt_ids, default=0)
-        if max_input_id > tokenizer.max_token_id:
-            raise ValueError(f"Token id {max_input_id} is out of vocabulary")
+        if (
+            tokenizer and (max_input_id := max(prompt_ids, default=0)) >
+            tokenizer.max_token_id
+        ):
+            raise ValueError(
+                f"Token id {max_input_id} is out of vocabulary")
 
         max_prompt_len = self.model_config.max_model_len
         if len(prompt_ids) > max_prompt_len:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4786d047acb..79150bfacc3 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -124,7 +124,9 @@ def __init__(
             cache_config.cache_dtype]
 
         self.is_multimodal_model = model_config.is_multimodal_model
-        self.is_pooling_model = model_config.pooler_config is not None
+        self.is_pooling_model = model_config.is_pooling_model
+        self.model_supports_multimodal_raw_input = (
+            model_config.model_supports_multimodal_raw_input)
         self.max_model_len = model_config.max_model_len
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs
@@ -327,6 +329,11 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         Args:
            scheduler_output: The scheduler output.
         """
+
+        # Nothing to be reordered when the model is attention-free.
+        if self.model_config.is_attention_free:
+            return
+
         self.attn_metadata_builders[0].reorder_batch(self.input_batch,
                                                      scheduler_output)
 
@@ -551,6 +558,39 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Refresh batch metadata with any pending updates.
         self.input_batch.refresh_metadata()
 
+    def _maybe_add_multimodal_kwargs(
+        self,
+        model_kwargs: dict[str, Any],
+        scheduler_output: "Optional[SchedulerOutput]" = None,
+        num_reqs: int = -1,
+    ):
+        if not self.model_supports_multimodal_raw_input:
+            return
+
+        # Multi-modal data.
+        if scheduler_output:
+            multi_modal_kwargs_list = []
+            for req in scheduler_output.scheduled_new_reqs:
+                req_mm_inputs = req.mm_inputs
+                if not isinstance(req_mm_inputs, list):
+                    req_mm_inputs = list(req_mm_inputs)
+                multi_modal_kwargs_list.extend(req_mm_inputs)
+            multi_modal_kwargs = MultiModalKwargs.batch(
+                multi_modal_kwargs_list)
+        else:
+            # The only case where SchedulerOutput is None is a dummy run,
+            # so generate some dummy data instead.
+            dummy_data = [
+                self.mm_registry.get_decoder_dummy_data(
+                    model_config=self.model_config,
+                    seq_len=1).multi_modal_data for _ in range(num_reqs)
+            ]
+            multi_modal_kwargs = MultiModalKwargs.batch(dummy_data)
+
+        model_kwargs.update(multi_modal_kwargs)
+
     def _get_cumsum_and_arange(
         self,
         num_tokens: np.ndarray,
@@ -1016,13 +1056,14 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
             curr_group_outputs = self.model.get_multimodal_embeddings(
                 **batched_mm_inputs)
 
-            sanity_check_mm_encoder_outputs(
-                curr_group_outputs,
-                expected_num_items=len(grouped_mm_inputs),
-            )
+            if curr_group_outputs:
+                sanity_check_mm_encoder_outputs(
+                    curr_group_outputs,
+                    expected_num_items=len(grouped_mm_inputs),
+                )
 
-            for output in curr_group_outputs:
-                encoder_outputs.append(output)
+                for output in curr_group_outputs:
+                    encoder_outputs.append(output)
 
         # Cache the encoder outputs.
         for (req_id, input_id, pos_info), output in zip(
@@ -1319,11 +1360,14 @@ def execute_model(
         else:
             mm_embeds = []
 
+        model_kwargs: dict[str, Any] = {}
         if self.is_multimodal_model and get_pp_group().is_first_rank:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
             input_ids = self.input_ids[:num_scheduled_tokens]
+            self._maybe_add_multimodal_kwargs(
+                model_kwargs=model_kwargs, scheduler_output=scheduler_output)
             if mm_embeds:
                 inputs_embeds = self.model.get_input_embeddings(
                     input_ids, mm_embeds)
@@ -1372,7 +1416,10 @@ def execute_model(
                 positions=positions,
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
-            )
+                **MultiModalKwargs.as_kwargs(
+                    model_kwargs,
+                    device=self.device,
+                ))
 
             self.maybe_wait_for_kv_save()
             finished_sending, finished_recving = (
@@ -2021,7 +2068,10 @@ def _dummy_run(
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             model = self.model
+            model_kwargs: dict[str, Any] = {}
             if self.is_multimodal_model:
+                self._maybe_add_multimodal_kwargs(model_kwargs=model_kwargs,
+                                                  num_reqs=num_reqs)
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
             else:
@@ -2050,12 +2100,13 @@ def _dummy_run(
                     self.vllm_config,
                     num_tokens=num_tokens,
                     num_tokens_across_dp=num_tokens_across_dp):
-                outputs = model(
-                    input_ids=input_ids,
-                    positions=positions,
-                    intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=inputs_embeds,
-                )
+                outputs = model(input_ids=input_ids,
+                                positions=positions,
+                                intermediate_tensors=intermediate_tensors,
+                                inputs_embeds=inputs_embeds,
+                                **MultiModalKwargs.as_kwargs(
+                                    model_kwargs, device=self.device))
+
             if self.use_aux_hidden_state_outputs:
                 hidden_states, _ = outputs
             else:
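# A minimal offline-inference sketch (not part of the patch) of how the raw
# multi-modal input path added above is expected to be exercised, mirroring
# examples/offline_inference/prithvi_geospatial_mae.py and the new test.
# The model ID, dtype, and tensor shapes are taken from the test file; the
# exact layout of the pooled output is an assumption about vLLM's pooling
# API rather than something this patch defines.
import torch

from vllm import LLM

llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    task="embed",
    skip_tokenizer_init=True,  # the model consumes no text, so no tokenizer
    dtype="float16",
    enforce_eager=True,
)

prompt = {
    # The engine still requires at least one (dummy) token id per request.
    "prompt_token_ids": [1],
    "multi_modal_data": {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    },
}

# The raw tensors above are forwarded to the model unchanged (no encoder
# embedding step), and the pooler returns one output per hidden state.
outputs = llm.encode(prompt)
print(len(outputs))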