From 953f66ae3ffba7dd4e99b9ab286877a5d161e434 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Thu, 5 Jun 2025 15:27:21 +0000 Subject: [PATCH 01/12] Support for attention free models in V1 Signed-off-by: Christian Pinto --- vllm/v1/core/kv_cache_manager.py | 47 ++++++++++++++++++++++++++++++ vllm/v1/core/sched/scheduler.py | 27 ++++++++++-------- vllm/v1/engine/core.py | 49 +++++++++++++++++++------------- 3 files changed, 91 insertions(+), 32 deletions(-) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 08bb0efb2f3..6b51f380acc 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -63,6 +63,53 @@ def new_empty(self) -> "KVCacheBlocks": """Creates a new KVCacheBlocks instance with no blocks.""" return KVCacheBlocks(tuple([] for _ in range(len(self.blocks)))) +class DummyKVCacheManager: + @property + def usage(self) -> float: + return 0.0 + + def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: + return None + + def get_computed_blocks(self, + request: Request) -> tuple[KVCacheBlocks, int]: + return(KVCacheBlocks([]), 0) + + def allocate_slots( + self, + request: Request, + num_new_tokens: int, + num_new_computed_tokens: int = 0, + new_computed_blocks: Optional[KVCacheBlocks] = None, + num_draft_tokens: int = 0, + num_lookahead_tokens: int = 0, + delay_cache_blocks: bool = False, + ) -> Optional[KVCacheBlocks]: + #if we do not return a KV cache block requests are unschedulable + return KVCacheBlocks([KVCacheBlock(block_id=0)]) + + def free(self, request: Request) -> None: + pass + + def reset_prefix_cache(self) -> bool: + return True + + def get_num_common_prefix_blocks( + self, + request: Request, + num_running_requests: int, + ) -> list[int]: + return [] + + def free_block_hashes(self, request: Request) -> None: + pass + + def take_events(self) -> list[KVCacheEvent]: + return [] + + def get_block_ids(self, request_id: str) -> list[list[int]]: + """Get the block ids of a request.""" + return [] class KVCacheManager: diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index fe552db74e2..f887d2e34fa 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -19,7 +19,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager, DummyKVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) @@ -92,7 +92,8 @@ def __init__( ) num_gpu_blocks = self.cache_config.num_gpu_blocks - assert num_gpu_blocks is not None and num_gpu_blocks > 0 + # num_gpu_blocks can be ero for attention free models + assert num_gpu_blocks is not None self.block_size = self.cache_config.block_size @@ -151,16 +152,18 @@ def __init__( self.num_lookahead_tokens = self.num_spec_tokens # Create the KV cache manager. 
- self.kv_cache_manager = KVCacheManager( - kv_cache_config=kv_cache_config, - max_model_len=self.max_model_len, - enable_caching=self.cache_config.enable_prefix_caching, - caching_hash_algo=self.cache_config.prefix_caching_hash_algo, - use_eagle=self.use_eagle, - log_stats=self.log_stats, - enable_kv_cache_events=self.enable_kv_cache_events, - ) - self.use_pp = self.parallel_config.pipeline_parallel_size > 1 + if self.cache_config.is_attention_free: + self.kv_cache_manager = DummyKVCacheManager() + else: + self.kv_cache_manager = KVCacheManager( + kv_cache_config=kv_cache_config, + max_model_len=self.max_model_len, + enable_caching=self.cache_config.enable_prefix_caching, + caching_hash_algo=self.cache_config.prefix_caching_hash_algo, + use_eagle=self.use_eagle, + log_stats=self.log_stats, + enable_kv_cache_events=self.enable_kv_cache_events, + ) def schedule(self) -> SchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e2fdf6f8a11..0c63ff6b1e2 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -134,26 +134,34 @@ def _initialize_kv_caches( self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: start = time.time() - # Get all kv cache needed by the model - kv_cache_specs = self.model_executor.get_kv_cache_specs() - - # Profiles the peak memory usage of the model to determine how much - # memory can be allocated for kv cache. - available_gpu_memory = self.model_executor.determine_available_memory() - - assert len(kv_cache_specs) == len(available_gpu_memory) - # Get the kv cache tensor size - kv_cache_configs = [ - get_kv_cache_config(vllm_config, kv_cache_spec_one_worker, - available_gpu_memory_one_worker) - for kv_cache_spec_one_worker, available_gpu_memory_one_worker in - zip(kv_cache_specs, available_gpu_memory) - ] - - # Since we use a shared centralized controller, we need the - # `kv_cache_config` to be consistent across all workers to make sure - # all the memory operators can be applied to all workers. - unify_kv_cache_configs(kv_cache_configs) + if vllm_config.model_config.is_attention_free: + # No need for initializing anything related to KV cache if the model + # is attention free. + kv_cache_specs = [] + kv_cache_configs = [ + KVCacheConfig(num_blocks=0, tensors={}, kv_cache_groups=[]) + ] + else: + # Get all kv cache needed by the model + kv_cache_specs = self.model_executor.get_kv_cache_specs() + + # Profiles the peak memory usage of the model to determine how much + # memory can be allocated for kv cache. + available_gpu_memory = self.model_executor.determine_available_memory() + + assert len(kv_cache_specs) == len(available_gpu_memory) + # Get the kv cache tensor size + kv_cache_configs = [ + get_kv_cache_config(vllm_config, kv_cache_spec_one_worker, + available_gpu_memory_one_worker) + for kv_cache_spec_one_worker, available_gpu_memory_one_worker in + zip(kv_cache_specs, available_gpu_memory) + ] + + # Since we use a shared centralized controller, we need the + # `kv_cache_config` to be consistent across all workers to make sure + # all the memory operators can be applied to all workers. + unify_kv_cache_configs(kv_cache_configs) # All workers have the same kv_cache_config except layer names, so use # an arbitrary one to initialize the scheduler. 
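A minimal, self-contained sketch of the control flow introduced above (every name prefixed with Toy is a stand-in for illustration, not a vLLM API): attention-free models skip KV-cache profiling entirely and receive a zero-block config, which the scheduler then pairs with DummyKVCacheManager instead of the real KVCacheManager.

    # Toy sketch only -- not part of the patch.
    from dataclasses import dataclass, field

    @dataclass
    class ToyKVCacheConfig:
        num_blocks: int = 0
        kv_cache_groups: list = field(default_factory=list)

    def toy_initialize_kv_caches(is_attention_free: bool,
                                 profiled_num_blocks: int) -> ToyKVCacheConfig:
        if is_attention_free:
            # No memory profiling and no allocation; the dummy manager later
            # reports usage == 0.0 and hands out a placeholder block so
            # requests stay schedulable.
            return ToyKVCacheConfig(num_blocks=0)
        return ToyKVCacheConfig(num_blocks=profiled_num_blocks)

    assert toy_initialize_kv_caches(True, 8192).num_blocks == 0
    assert toy_initialize_kv_caches(False, 8192).num_blocks == 8192
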
@@ -186,6 +194,7 @@ def add_request(self, request: EngineCoreRequest): request.mm_inputs = self.mm_input_cache_server.get_and_update_p1( request.mm_inputs, request.mm_hashes) + req = Request.from_engine_core_request(request) if req.use_structured_output: # Start grammar compilation asynchronously From f174bbff58e5c486e38fc0b38d0086ab1f243a37 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Thu, 5 Jun 2025 15:30:18 +0000 Subject: [PATCH 02/12] Better support for skip_tokenizer_init=True Signed-off-by: Christian Pinto --- vllm/config.py | 8 ++++++-- vllm/multimodal/registry.py | 2 +- vllm/v1/engine/llm_engine.py | 14 +++++++++----- vllm/v1/engine/output_processor.py | 5 ++++- vllm/v1/engine/processor.py | 12 ++++++++---- 5 files changed, 28 insertions(+), 13 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 226a1014fa7..474a84670a2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -612,6 +612,7 @@ def __post_init__(self) -> None: self.served_model_name = get_served_model_name(self.model, self.served_model_name) self.multimodal_config = self._init_multimodal_config() + self.model_supports_multimodal_raw_input = self._init_model_supports_multimodal_raw_input() if not self.skip_tokenizer_init: self._verify_tokenizer_mode() @@ -715,6 +716,9 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: return None + def _init_model_supports_multimodal_raw_input(self): + return self.registry.supports_multimodal_raw_input(self.architectures) + def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( self.model, self.revision) @@ -1120,10 +1124,10 @@ def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]: return self.get_hf_config_sliding_window() def get_vocab_size(self) -> int: - return self.hf_text_config.vocab_size + return getattr(self.hf_text_config, "vocab_size", 0) def get_hidden_size(self) -> int: - return self.hf_text_config.hidden_size + return getattr(self.hf_text_config, "hidden_size", 0) @property def is_deepseek_mla(self) -> bool: diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 27aaa661c35..c44fcacd246 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -266,7 +266,7 @@ def create_processor( if not model_config.is_multimodal_model: raise ValueError(f"{model_config.model} is not a multimodal model") - if tokenizer is None: + if tokenizer is None and not model_config.skip_tokenizer_init: tokenizer = cached_tokenizer_from_config(model_config) if disable_cache is None: mm_config = model_config.get_multimodal_config() diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index a2328c37ba0..c8259aac77b 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -82,11 +82,15 @@ def __init__( self.dp_group = None self.should_execute_dummy_batch = False - # Tokenizer (+ ensure liveness if running in another process). - self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) + + if not self.vllm_config.model_config.skip_tokenizer_init: + # Tokenizer (+ ensure liveness if running in another process). 
+ self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + lora_config=vllm_config.lora_config) + else: + self.tokenizer = None # Processor (convert Inputs --> EngineCoreRequests) self.processor = Processor(vllm_config=vllm_config, diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 2bcd61d1f0a..05ea6cc492e 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -327,8 +327,11 @@ def add_request( if request_id in self.request_states: raise ValueError(f"Request id {request_id} already running.") + tokenizer = None if not self.tokenizer else \ + self.tokenizer.get_lora_tokenizer(request.lora_request) + req_state = RequestState.from_new_request( - tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request), + tokenizer=tokenizer, request=request, prompt=prompt, parent_req=parent_req, diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 9fc52543efd..c79eb9e082d 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -375,7 +375,10 @@ def _validate_model_input( prompt_type: Literal["encoder", "decoder"], ): model_config = self.model_config - tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) + if model_config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) prompt_ids = prompt_inputs["prompt_token_ids"] if not prompt_ids: @@ -384,9 +387,10 @@ def _validate_model_input( else: raise ValueError(f"The {prompt_type} prompt cannot be empty") - max_input_id = max(prompt_ids, default=0) - if max_input_id > tokenizer.max_token_id: - raise ValueError(f"Token id {max_input_id} is out of vocabulary") + if tokenizer: + max_input_id = max(prompt_ids, default=0) + if max_input_id > tokenizer.max_token_id: + raise ValueError(f"Token id {max_input_id} is out of vocabulary") max_prompt_len = self.model_config.max_model_len if len(prompt_ids) > max_prompt_len: From f3ab1fbd87f7a334d9d2586de53d9c0f34150617 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Thu, 5 Jun 2025 15:31:55 +0000 Subject: [PATCH 03/12] Support for attention free models in V1 Signed-off-by: Christian Pinto --- .../models/prithvi_geospatial_mae.py | 39 +++++++++++++------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index a36f24bc80e..496ef0b9a30 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -26,13 +26,13 @@ from vllm.config import VllmConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (IsAttentionFree, - SupportsMultiModal, - SupportsV0Only) + SupportsMultiModalWithRawInput) from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs) +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalFieldElem, + MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem, + MultiModalSharedField, PlaceholderRange) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import 
(BaseMultiModalProcessor, BaseProcessingInfo, PromptUpdate) @@ -75,8 +75,8 @@ def _get_mm_fields_config( hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return dict( - pixel_values=MultiModalFieldConfig.batched("image"), - location_coords=MultiModalFieldConfig.batched("image"), + pixel_values=MultiModalFieldConfig.shared(batch_size=1, modality="image"), + location_coords=MultiModalFieldConfig.shared(batch_size=1, modality="image"), ) def _get_prompt_updates( @@ -99,14 +99,24 @@ def apply( for k, v in mm_data.items(): mm_kwargs[k] = v + mm_place_holders = { + "image": [PlaceholderRange(offset=0, length=0)] + } + + multimodal_kwargs_items = [ + MultiModalKwargsItem.from_elems( + [MultiModalFieldElem(modality="image", key=key, data=data, field=MultiModalSharedField(1)) + for key, data in mm_kwargs.items()] + ) + ] return MultiModalInputs( type="multimodal", prompt=prompt, prompt_token_ids=[1], - mm_kwargs=MultiModalKwargs(mm_kwargs), + mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items), mm_hashes=None, - mm_placeholders={}, + mm_placeholders=mm_place_holders, ) @@ -114,8 +124,7 @@ def apply( PrithviGeoSpatialMAEMultiModalProcessor, info=PrithviGeoSpatialMAEProcessingInfo, dummy_inputs=PrithviGeoSpatialMAEInputBuilder) -class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal, - SupportsV0Only): +class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModalWithRawInput): """ Prithvi Masked Autoencoder""" @classmethod @@ -180,7 +189,13 @@ def _parse_and_validate_multimodal_data( location_coords = None return pixel_values, location_coords - + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + # We do not really use any input tokens and therefore no embeddings to be calculated + # However, due to the mandatory token ids in the input prompt we pass one token and the + # size of the dummy embedding tensors must reflect that. + return torch.empty(input_ids.shape) + def forward( self, input_ids: Optional[torch.Tensor], @@ -202,7 +217,7 @@ def pooler( hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Optional[PoolerOutput]: - return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)]) + return PoolerOutput([PoolingSequenceGroupOutput(hidden_states[0])]) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: From ff86cb08ac4b58d22cc35141ce7dbd06457794b0 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Fri, 6 Jun 2025 07:47:15 +0000 Subject: [PATCH 04/12] Last few changes after rebasing to latest branch version Signed-off-by: Christian Pinto --- .../models/prithvi_geospatial_mae.py | 4 +-- vllm/v1/worker/gpu_model_runner.py | 34 +++++++++++++++---- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 496ef0b9a30..fbf32dd4933 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -62,8 +62,8 @@ def get_dummy_mm_data( # The size of pixel_values might change in the cases where we resize # the input but never exceeds the dimensions below. 
return { - "pixel_values": torch.full((1, 6, 512, 512), 1.0), - "location_coords": torch.full((1, 2), 1.0), + "pixel_values": torch.full((1, 6, 512, 512), 1.0, dtype=torch.float16), + "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), } diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4786d047acb..45c871f0671 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -124,7 +124,7 @@ def __init__( cache_config.cache_dtype] self.is_multimodal_model = model_config.is_multimodal_model - self.is_pooling_model = model_config.pooler_config is not None + self.model_supports_multimodal_raw_input = model_config.model_supports_multimodal_raw_input self.max_model_len = model_config.max_model_len self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -327,6 +327,11 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: Args: scheduler_output: The scheduler output. """ + + # nothing to be reordered when the mdoel is attention free + if self.model_config.is_attention_free: + return False + self.attn_metadata_builders[0].reorder_batch(self.input_batch, scheduler_output) @@ -1016,13 +1021,14 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): curr_group_outputs = self.model.get_multimodal_embeddings( **batched_mm_inputs) - sanity_check_mm_encoder_outputs( - curr_group_outputs, - expected_num_items=len(grouped_mm_inputs), - ) + if curr_group_outputs: + sanity_check_mm_encoder_outputs( + curr_group_outputs, + expected_num_items=len(grouped_mm_inputs), + ) - for output in curr_group_outputs: - encoder_outputs.append(output) + for output in curr_group_outputs: + encoder_outputs.append(output) # Cache the encoder outputs. for (req_id, input_id, pos_info), output in zip( @@ -1324,6 +1330,9 @@ def execute_model( # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. input_ids = self.input_ids[:num_scheduled_tokens] + self._maybe_add_model_args(num_scheduled_tokens, + model_kwargs, scheduler_output) + if mm_embeds: inputs_embeds = self.model.get_input_embeddings( input_ids, mm_embeds) @@ -1339,6 +1348,7 @@ def execute_model( # multimodal models, it is not desirable for performance since # then the embedding layer is not included in the CUDA graph. 
input_ids = self.input_ids[:num_input_tokens] + self._maybe_add_model_args(num_input_tokens, model_kwargs, scheduler_output) inputs_embeds = None if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] @@ -1372,6 +1382,10 @@ def execute_model( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_kwargs, + device=self.device, + ) ) self.maybe_wait_for_kv_save() @@ -2021,6 +2035,8 @@ def _dummy_run( with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): model = self.model + model_kwargs: dict[str, Any] = {} + self._maybe_add_model_args(num_tokens, model_kwargs) if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -2055,7 +2071,11 @@ def _dummy_run( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_kwargs, + device=self.device) ) + if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs else: From 769d8dd2c06ec8d17080edad2aaad4ad7d8b6b40 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Thu, 5 Jun 2025 15:31:03 +0000 Subject: [PATCH 05/12] Support passing raw multimodal data to model Signed-off-by: Christian Pinto --- vllm/model_executor/models/interfaces.py | 36 +++++++++++++++++++++ vllm/model_executor/models/registry.py | 14 +++++++-- vllm/v1/worker/gpu_model_runner.py | 40 ++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index a018bd5d09d..a3778cc89a2 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -129,6 +129,42 @@ def supports_multimodal( return isinstance(model, SupportsMultiModal) +@runtime_checkable +class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): + """The interface required for all multi-modal models.""" + + supports_multimodal_raw_input: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports multi-modal inputs and processes + them in their raw form and not embeddings. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + +@runtime_checkable +class _SupportsMultiModalWithRawInput(Protocol): + supports_multimodal_raw_input: ClassVar[Literal[True]] + + +@overload +def supports_multimodal_raw_input(model: object) -> TypeIs[SupportsMultiModalWithRawInput]: + ... + + +@overload +def supports_multimodal_raw_input(model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]: + ... 
+ + +def supports_multimodal_raw_input( + model: Union[type[object], object] +) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], TypeIs[SupportsMultiModalWithRawInput]]: + if isinstance(model, type): + return isinstance(model, _SupportsMultiModalWithRawInput) + + return isinstance(model, SupportsMultiModalWithRawInput) @runtime_checkable class SupportsLoRA(Protocol): diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b100fe77e37..398f4d9556f 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -23,8 +23,9 @@ from .interfaces import (has_inner_state, has_noops, is_attention_free, is_hybrid, supports_cross_encoding, - supports_multimodal, supports_pp, - supports_transcription, supports_v0_only) + supports_multimodal, supports_multimodal_raw_input, + supports_pp, supports_transcription, + supports_v0_only) from .interfaces_base import is_text_generation_model logger = init_logger(__name__) @@ -275,6 +276,7 @@ class _ModelInfo: is_pooling_model: bool supports_cross_encoding: bool supports_multimodal: bool + supports_multimodal_raw_input: bool supports_pp: bool has_inner_state: bool is_attention_free: bool @@ -291,6 +293,7 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo": is_pooling_model=True, # Can convert any model into a pooling model supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), + supports_multimodal_raw_input=supports_multimodal_raw_input(model), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), @@ -527,6 +530,13 @@ def is_multimodal_model( ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.supports_multimodal + + def supports_multimodal_raw_input( + self, + architectures: Union[str, list[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_multimodal_raw_input def is_pp_supported_model( self, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 45c871f0671..9f5109de34e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -556,6 +556,46 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() + def _add_multimodal_inputs_to_model_args(self, model_kwargs: dict[str, Any], + scheduler_output: "SchedulerOutput"): + # Multi-modal data. + if scheduler_output: + multi_modal_kwargs_list = [] + for req in scheduler_output.scheduled_new_reqs: + req_mm_inputs = req.mm_inputs + if not isinstance(req_mm_inputs, list): + req_mm_inputs = list(req_mm_inputs) + multi_modal_kwargs_list.extend(req_mm_inputs) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + else: + # The only case where SchedulerOtput is None is for a dummy run, let's get some dummy data. 
+ dummy_data = self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len =1) + multi_modal_kwargs = MultiModalKwargs.batch([dummy_data.multi_modal_data]) + + model_kwargs.update(multi_modal_kwargs) + + def _maybe_add_model_args(self, num_tokens: int, + model_kwargs: dict[str,Any], + scheduler_output: "SchedulerOutput"=None): + + if self.supports_token_type_ids: + model_kwargs["token_type_ids"] =\ + self.get_token_type_ids()[:num_tokens] + + if self.model_supports_multimodal_raw_input: + self._add_multimodal_inputs_to_model_args(model_kwargs, scheduler_output) + + def _maybe_compute_attn_prefix( + self, + scheduler_output: "SchedulerOutput", + ) -> list[int]: + return [0] * len(self.kv_cache_config.kv_cache_groups) + + def _maybe_prepare_additional_inputs(self, + scheduler_output: "SchedulerOutput", + token_indices: torch.Tensor): + pass + def _get_cumsum_and_arange( self, num_tokens: np.ndarray, From 9cc76b14087876c5690a6f4fd251e33e4f7ac764 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Fri, 6 Jun 2025 15:05:04 +0000 Subject: [PATCH 06/12] latest changes to align with the original branch Signed-off-by: Christian Pinto --- vllm/config.py | 1 + .../models/prithvi_geospatial_mae.py | 6 ++-- vllm/v1/worker/gpu_model_runner.py | 33 ++++++++++--------- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 474a84670a2..c0f411ff73f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -612,6 +612,7 @@ def __post_init__(self) -> None: self.served_model_name = get_served_model_name(self.model, self.served_model_name) self.multimodal_config = self._init_multimodal_config() + self.is_pooling_model = self.registry.is_pooling_model(self.architectures) self.model_supports_multimodal_raw_input = self._init_model_supports_multimodal_raw_input() if not self.skip_tokenizer_init: self._verify_tokenizer_mode() diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index fbf32dd4933..7f7a6b76c71 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -62,7 +62,7 @@ def get_dummy_mm_data( # The size of pixel_values might change in the cases where we resize # the input but never exceeds the dimensions below. return { - "pixel_values": torch.full((1, 6, 512, 512), 1.0, dtype=torch.float16), + "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16), "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), } @@ -178,7 +178,7 @@ def _parse_and_validate_multimodal_data( if not isinstance(pixel_values, torch.Tensor): raise ValueError(f"Incorrect type of pixel_values. 
" f"Got type: {type(pixel_values)}") - pixel_values = torch.unbind(pixel_values, dim=0)[0] + # pixel_values = torch.unbind(pixel_values, dim=0)[0] location_coords = kwargs.pop("location_coords", None) if not isinstance(location_coords, torch.Tensor): @@ -217,7 +217,7 @@ def pooler( hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Optional[PoolerOutput]: - return PoolerOutput([PoolingSequenceGroupOutput(hidden_states[0])]) + return PoolerOutput([PoolingSequenceGroupOutput(hidden_state) for hidden_state in hidden_states]) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9f5109de34e..c5dce455152 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -124,6 +124,7 @@ def __init__( cache_config.cache_dtype] self.is_multimodal_model = model_config.is_multimodal_model + self.is_pooling_model = model_config.is_pooling_model self.model_supports_multimodal_raw_input = model_config.model_supports_multimodal_raw_input self.max_model_len = model_config.max_model_len self.max_num_tokens = scheduler_config.max_num_batched_tokens @@ -557,7 +558,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.refresh_metadata() def _add_multimodal_inputs_to_model_args(self, model_kwargs: dict[str, Any], - scheduler_output: "SchedulerOutput"): + scheduler_output: "SchedulerOutput", + num_reqs: int=-1): # Multi-modal data. if scheduler_output: multi_modal_kwargs_list = [] @@ -569,21 +571,20 @@ def _add_multimodal_inputs_to_model_args(self, model_kwargs: dict[str, Any], multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) else: # The only case where SchedulerOtput is None is for a dummy run, let's get some dummy data. - dummy_data = self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len =1) - multi_modal_kwargs = MultiModalKwargs.batch([dummy_data.multi_modal_data]) + dummy_data = [self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len =1).multi_modal_data for i in range(num_reqs)] + # dummy_data = self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len =1) + # multi_modal_kwargs = MultiModalKwargs.batch([dummy_data.multi_modal_data]) + multi_modal_kwargs = MultiModalKwargs.batch(dummy_data) model_kwargs.update(multi_modal_kwargs) - def _maybe_add_model_args(self, num_tokens: int, + def _maybe_add_multimodal_kwargs(self, model_kwargs: dict[str,Any], - scheduler_output: "SchedulerOutput"=None): - - if self.supports_token_type_ids: - model_kwargs["token_type_ids"] =\ - self.get_token_type_ids()[:num_tokens] + scheduler_output: "SchedulerOutput"=None, + num_reqs: int=-1): if self.model_supports_multimodal_raw_input: - self._add_multimodal_inputs_to_model_args(model_kwargs, scheduler_output) + self._add_multimodal_inputs_to_model_args(model_kwargs, scheduler_output, num_reqs) def _maybe_compute_attn_prefix( self, @@ -1364,15 +1365,15 @@ def execute_model( mm_embeds = self._gather_mm_embeddings(scheduler_output) else: mm_embeds = [] - + + model_kwargs: dict[str, Any] = {} if self.is_multimodal_model and get_pp_group().is_first_rank: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. 
input_ids = self.input_ids[:num_scheduled_tokens] - self._maybe_add_model_args(num_scheduled_tokens, - model_kwargs, scheduler_output) - + self._maybe_add_multimodal_kwargs(model_kwargs=model_kwargs, + scheduler_output=scheduler_output) if mm_embeds: inputs_embeds = self.model.get_input_embeddings( input_ids, mm_embeds) @@ -1388,7 +1389,6 @@ def execute_model( # multimodal models, it is not desirable for performance since # then the embedding layer is not included in the CUDA graph. input_ids = self.input_ids[:num_input_tokens] - self._maybe_add_model_args(num_input_tokens, model_kwargs, scheduler_output) inputs_embeds = None if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] @@ -2076,8 +2076,9 @@ def _dummy_run( num_scheduled_tokens): model = self.model model_kwargs: dict[str, Any] = {} - self._maybe_add_model_args(num_tokens, model_kwargs) if self.is_multimodal_model: + self._maybe_add_multimodal_kwargs(model_kwargs=model_kwargs, + num_reqs=num_reqs) input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] else: From f8226da6db90768e2168a6c50265321d9c20c82f Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Tue, 24 Jun 2025 14:25:03 +0000 Subject: [PATCH 07/12] Latest changes to aadpt to upstream master Signed-off-by: Christian Pinto --- examples/offline_inference/prithvi_geospatial_mae.py | 3 ++- vllm/v1/engine/core.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index 567c448a8c9..e0169eafbab 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -143,7 +143,8 @@ def __init__(self): self.model = LLM( model=os.path.join(os.path.dirname(__file__), "./model"), skip_tokenizer_init=True, - dtype="float32", + dtype="float16", + enforce_eager=True ) def run(self, input_data, location_coords): diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 0c63ff6b1e2..395a3f8c018 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -139,7 +139,7 @@ def _initialize_kv_caches( # is attention free. 
kv_cache_specs = [] kv_cache_configs = [ - KVCacheConfig(num_blocks=0, tensors={}, kv_cache_groups=[]) + KVCacheConfig(num_blocks=0, kv_cache_tensors={}, kv_cache_groups=[]) ] else: # Get all kv cache needed by the model From c70826621b9562a83efa8c709d8a17a840383417 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Wed, 25 Jun 2025 10:52:35 +0000 Subject: [PATCH 08/12] Some reformatting to make the pre-commit hooks succeed Signed-off-by: Christian Pinto --- .../prithvi_geospatial_mae.py | 2 +- .../multimodal/pooling/test_prithvi_mae.py | 44 +++++++++++++ vllm/config.py | 6 +- vllm/model_executor/models/interfaces.py | 12 +++- .../models/prithvi_geospatial_mae.py | 52 +++++++++------ vllm/model_executor/models/registry.py | 5 +- vllm/v1/core/kv_cache_manager.py | 13 +++- vllm/v1/core/sched/scheduler.py | 2 +- vllm/v1/engine/core.py | 14 ++-- vllm/v1/engine/llm_engine.py | 3 +- vllm/v1/engine/output_processor.py | 15 ++--- vllm/v1/engine/processor.py | 3 +- vllm/v1/worker/gpu_model_runner.py | 64 ++++++++++--------- 13 files changed, 157 insertions(+), 78 deletions(-) create mode 100644 tests/models/multimodal/pooling/test_prithvi_mae.py diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index e0169eafbab..e36d01c249f 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -144,7 +144,7 @@ def __init__(self): model=os.path.join(os.path.dirname(__file__), "./model"), skip_tokenizer_init=True, dtype="float16", - enforce_eager=True + enforce_eager=True, ) def run(self, input_data, location_coords): diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py new file mode 100644 index 00000000000..870edad3632 --- /dev/null +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from ....conftest import VllmRunner + +def generate_test_mm_data(): + mm_data = { + "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16), + "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), + } + return mm_data + +def _run_test( + vllm_runner: type[VllmRunner], + model: str, +) -> None: + + mm_data = generate_test_mm_data() + prompt = { + # This model deals with no text input + "prompt_token_ids": [1], + "multi_modal_data": mm_data + } + with vllm_runner(model, task="embed", + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True) as vllm_model: + output = vllm_model.encode(prompt) + +MODELS=["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] +@pytest.mark.parametrize("model", MODELS) +def test_models_image( + hf_runner, + vllm_runner, + image_assets, + model: str, +) -> None: + _run_test( + vllm_runner, + model, + ) \ No newline at end of file diff --git a/vllm/config.py b/vllm/config.py index c0f411ff73f..a24c4ecbcba 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -612,8 +612,10 @@ def __post_init__(self) -> None: self.served_model_name = get_served_model_name(self.model, self.served_model_name) self.multimodal_config = self._init_multimodal_config() - self.is_pooling_model = self.registry.is_pooling_model(self.architectures) - self.model_supports_multimodal_raw_input = self._init_model_supports_multimodal_raw_input() + self.is_pooling_model = self.registry.is_pooling_model( + self.architectures) + 
self.model_supports_multimodal_raw_input = ( + self._init_model_supports_multimodal_raw_input()) if not self.skip_tokenizer_init: self._verify_tokenizer_mode() diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index a3778cc89a2..d6b96f2893b 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -129,6 +129,7 @@ def supports_multimodal( return isinstance(model, SupportsMultiModal) + @runtime_checkable class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): """The interface required for all multi-modal models.""" @@ -143,29 +144,34 @@ class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): MRO of your model class. """ + @runtime_checkable class _SupportsMultiModalWithRawInput(Protocol): supports_multimodal_raw_input: ClassVar[Literal[True]] @overload -def supports_multimodal_raw_input(model: object) -> TypeIs[SupportsMultiModalWithRawInput]: +def supports_multimodal_raw_input( + model: object) -> TypeIs[SupportsMultiModalWithRawInput]: ... @overload -def supports_multimodal_raw_input(model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]: +def supports_multimodal_raw_input( + model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]: ... def supports_multimodal_raw_input( model: Union[type[object], object] -) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], TypeIs[SupportsMultiModalWithRawInput]]: +) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], + TypeIs[SupportsMultiModalWithRawInput]]: if isinstance(model, type): return isinstance(model, _SupportsMultiModalWithRawInput) return isinstance(model, SupportsMultiModalWithRawInput) + @runtime_checkable class SupportsLoRA(Protocol): """The interface required for all models that support LoRA.""" diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 7f7a6b76c71..cc34721cf41 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -25,13 +25,14 @@ from vllm.config import VllmConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (IsAttentionFree, - SupportsMultiModalWithRawInput) +from vllm.model_executor.models.interfaces import ( + IsAttentionFree, SupportsMultiModalWithRawInput) from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalFieldElem, - MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem, +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalFieldElem, MultiModalInputs, + MultiModalKwargs, MultiModalKwargsItem, MultiModalSharedField, PlaceholderRange) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -62,7 +63,8 @@ def get_dummy_mm_data( # The size of pixel_values might change in the cases where we resize # the input but never exceeds the dimensions below. 
return { - "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16), + "pixel_values": torch.full((6, 512, 512), 1.0, + dtype=torch.float16), "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), } @@ -75,8 +77,10 @@ def _get_mm_fields_config( hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return dict( - pixel_values=MultiModalFieldConfig.shared(batch_size=1, modality="image"), - location_coords=MultiModalFieldConfig.shared(batch_size=1, modality="image"), + pixel_values=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), + location_coords=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), ) def _get_prompt_updates( @@ -99,15 +103,16 @@ def apply( for k, v in mm_data.items(): mm_kwargs[k] = v - mm_place_holders = { - "image": [PlaceholderRange(offset=0, length=0)] - } + mm_place_holders = {"image": [PlaceholderRange(offset=0, length=0)]} multimodal_kwargs_items = [ - MultiModalKwargsItem.from_elems( - [MultiModalFieldElem(modality="image", key=key, data=data, field=MultiModalSharedField(1)) - for key, data in mm_kwargs.items()] - ) + MultiModalKwargsItem.from_elems([ + MultiModalFieldElem(modality="image", + key=key, + data=data, + field=MultiModalSharedField(1)) + for key, data in mm_kwargs.items() + ]) ] return MultiModalInputs( @@ -124,7 +129,8 @@ def apply( PrithviGeoSpatialMAEMultiModalProcessor, info=PrithviGeoSpatialMAEProcessingInfo, dummy_inputs=PrithviGeoSpatialMAEInputBuilder) -class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModalWithRawInput): +class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, + SupportsMultiModalWithRawInput): """ Prithvi Masked Autoencoder""" @classmethod @@ -189,13 +195,14 @@ def _parse_and_validate_multimodal_data( location_coords = None return pixel_values, location_coords - + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - # We do not really use any input tokens and therefore no embeddings to be calculated - # However, due to the mandatory token ids in the input prompt we pass one token and the - # size of the dummy embedding tensors must reflect that. + # We do not really use any input tokens and therefore no embeddings + # to be calculated. However, due to the mandatory token ids in + # the input prompt we pass one token and the size of the dummy + # embedding tensors must reflect that. 
return torch.empty(input_ids.shape) - + def forward( self, input_ids: Optional[torch.Tensor], @@ -217,7 +224,10 @@ def pooler( hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Optional[PoolerOutput]: - return PoolerOutput([PoolingSequenceGroupOutput(hidden_state) for hidden_state in hidden_states]) + return PoolerOutput([ + PoolingSequenceGroupOutput(hidden_state) + for hidden_state in hidden_states + ]) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 398f4d9556f..20ead20f0f9 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -24,8 +24,7 @@ from .interfaces import (has_inner_state, has_noops, is_attention_free, is_hybrid, supports_cross_encoding, supports_multimodal, supports_multimodal_raw_input, - supports_pp, supports_transcription, - supports_v0_only) + supports_pp, supports_transcription, supports_v0_only) from .interfaces_base import is_text_generation_model logger = init_logger(__name__) @@ -530,7 +529,7 @@ def is_multimodal_model( ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.supports_multimodal - + def supports_multimodal_raw_input( self, architectures: Union[str, list[str]], diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 6b51f380acc..d3271beeb29 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -63,7 +63,9 @@ def new_empty(self) -> "KVCacheBlocks": """Creates a new KVCacheBlocks instance with no blocks.""" return KVCacheBlocks(tuple([] for _ in range(len(self.blocks)))) + class DummyKVCacheManager: + @property def usage(self) -> float: return 0.0 @@ -73,7 +75,7 @@ def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: - return(KVCacheBlocks([]), 0) + return (KVCacheBlocks([]), 0) def allocate_slots( self, @@ -111,6 +113,15 @@ def get_block_ids(self, request_id: str) -> list[list[int]]: """Get the block ids of a request.""" return [] + def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: + """Cache the blocks for the request, if enabled.""" + pass + + def create_empty_block_list(self) -> KVCacheBlocks: + """Creates a new KVCacheBlocks instance with no blocks.""" + return (KVCacheBlocks([]), 0) + + class KVCacheManager: def __init__( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f887d2e34fa..101175596bb 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -19,7 +19,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager, DummyKVCacheManager +from vllm.v1.core.kv_cache_manager import DummyKVCacheManager, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 395a3f8c018..f941f56031b 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -139,23 +139,26 @@ def _initialize_kv_caches( # is attention free. 
kv_cache_specs = [] kv_cache_configs = [ - KVCacheConfig(num_blocks=0, kv_cache_tensors={}, kv_cache_groups=[]) - ] + KVCacheConfig(num_blocks=0, + kv_cache_tensors={}, + kv_cache_groups=[]) + ] else: # Get all kv cache needed by the model kv_cache_specs = self.model_executor.get_kv_cache_specs() # Profiles the peak memory usage of the model to determine how much # memory can be allocated for kv cache. - available_gpu_memory = self.model_executor.determine_available_memory() + available_gpu_memory = ( + self.model_executor.determine_available_memory()) assert len(kv_cache_specs) == len(available_gpu_memory) # Get the kv cache tensor size kv_cache_configs = [ get_kv_cache_config(vllm_config, kv_cache_spec_one_worker, available_gpu_memory_one_worker) - for kv_cache_spec_one_worker, available_gpu_memory_one_worker in - zip(kv_cache_specs, available_gpu_memory) + for kv_cache_spec_one_worker, available_gpu_memory_one_worker + in zip(kv_cache_specs, available_gpu_memory) ] # Since we use a shared centralized controller, we need the @@ -194,7 +197,6 @@ def add_request(self, request: EngineCoreRequest): request.mm_inputs = self.mm_input_cache_server.get_and_update_p1( request.mm_inputs, request.mm_hashes) - req = Request.from_engine_core_request(request) if req.use_structured_output: # Start grammar compilation asynchronously diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index c8259aac77b..e13a544b50f 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -82,8 +82,7 @@ def __init__( self.dp_group = None self.should_execute_dummy_batch = False - - if not self.vllm_config.model_config.skip_tokenizer_init: + if not self.vllm_config.model_config.skip_tokenizer_init: # Tokenizer (+ ensure liveness if running in another process). 
self.tokenizer = init_tokenizer_from_configs( model_config=vllm_config.model_config, diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 05ea6cc492e..3be6c482121 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -330,14 +330,13 @@ def add_request( tokenizer = None if not self.tokenizer else \ self.tokenizer.get_lora_tokenizer(request.lora_request) - req_state = RequestState.from_new_request( - tokenizer=tokenizer, - request=request, - prompt=prompt, - parent_req=parent_req, - request_index=request_index, - queue=queue, - log_stats=self.log_stats) + req_state = RequestState.from_new_request(tokenizer=tokenizer, + request=request, + prompt=prompt, + parent_req=parent_req, + request_index=request_index, + queue=queue, + log_stats=self.log_stats) self.request_states[request_id] = req_state self.lora_states.add_request(req_state) if parent_req: diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index c79eb9e082d..d776ce7585c 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -390,7 +390,8 @@ def _validate_model_input( if tokenizer: max_input_id = max(prompt_ids, default=0) if max_input_id > tokenizer.max_token_id: - raise ValueError(f"Token id {max_input_id} is out of vocabulary") + raise ValueError( + f"Token id {max_input_id} is out of vocabulary") max_prompt_len = self.model_config.max_model_len if len(prompt_ids) > max_prompt_len: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c5dce455152..5d15e1aeb1e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -125,7 +125,8 @@ def __init__( self.is_multimodal_model = model_config.is_multimodal_model self.is_pooling_model = model_config.is_pooling_model - self.model_supports_multimodal_raw_input = model_config.model_supports_multimodal_raw_input + self.model_supports_multimodal_raw_input = ( + model_config.model_supports_multimodal_raw_input) self.max_model_len = model_config.max_model_len self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -557,9 +558,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() - def _add_multimodal_inputs_to_model_args(self, model_kwargs: dict[str, Any], - scheduler_output: "SchedulerOutput", - num_reqs: int=-1): + def _add_multimodal_inputs_to_model_args( + self, + model_kwargs: dict[str, Any], + scheduler_output: "SchedulerOutput", + num_reqs: int = -1): # Multi-modal data. if scheduler_output: multi_modal_kwargs_list = [] @@ -568,23 +571,30 @@ def _add_multimodal_inputs_to_model_args(self, model_kwargs: dict[str, Any], if not isinstance(req_mm_inputs, list): req_mm_inputs = list(req_mm_inputs) multi_modal_kwargs_list.extend(req_mm_inputs) - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + multi_modal_kwargs = MultiModalKwargs.batch( + multi_modal_kwargs_list) else: - # The only case where SchedulerOtput is None is for a dummy run, let's get some dummy data. 
- dummy_data = [self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len =1).multi_modal_data for i in range(num_reqs)] - # dummy_data = self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len =1) - # multi_modal_kwargs = MultiModalKwargs.batch([dummy_data.multi_modal_data]) + # The only case where SchedulerOtput is None is for a dummy run, + # let's get some dummy data. + dummy_data = [ + self.mm_registry.get_decoder_dummy_data( + model_config=self.model_config, seq_len=1).multi_modal_data + for i in range(num_reqs) + ] multi_modal_kwargs = MultiModalKwargs.batch(dummy_data) - + model_kwargs.update(multi_modal_kwargs) - def _maybe_add_multimodal_kwargs(self, - model_kwargs: dict[str,Any], - scheduler_output: "SchedulerOutput"=None, - num_reqs: int=-1): + def _maybe_add_multimodal_kwargs( + self, + model_kwargs: dict[str, Any], + scheduler_output: "SchedulerOutput" = None, + num_reqs: int = -1): if self.model_supports_multimodal_raw_input: - self._add_multimodal_inputs_to_model_args(model_kwargs, scheduler_output, num_reqs) + self._add_multimodal_inputs_to_model_args(model_kwargs, + scheduler_output, + num_reqs) def _maybe_compute_attn_prefix( self, @@ -1365,15 +1375,15 @@ def execute_model( mm_embeds = self._gather_mm_embeddings(scheduler_output) else: mm_embeds = [] - + model_kwargs: dict[str, Any] = {} if self.is_multimodal_model and get_pp_group().is_first_rank: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. input_ids = self.input_ids[:num_scheduled_tokens] - self._maybe_add_multimodal_kwargs(model_kwargs=model_kwargs, - scheduler_output=scheduler_output) + self._maybe_add_multimodal_kwargs( + model_kwargs=model_kwargs, scheduler_output=scheduler_output) if mm_embeds: inputs_embeds = self.model.get_input_embeddings( input_ids, mm_embeds) @@ -1425,8 +1435,7 @@ def execute_model( **MultiModalKwargs.as_kwargs( model_kwargs, device=self.device, - ) - ) + )) self.maybe_wait_for_kv_save() finished_sending, finished_recving = ( @@ -2107,15 +2116,12 @@ def _dummy_run( self.vllm_config, num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp): - outputs = model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **MultiModalKwargs.as_kwargs( - model_kwargs, - device=self.device) - ) + outputs = model(input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_kwargs, device=self.device)) if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs From 0dba4cd1c052c0a5024f68f5f6e9f0c0f3f306cc Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Wed, 25 Jun 2025 12:54:25 +0000 Subject: [PATCH 09/12] Few more changes to solve some other pre-commit hooks failures Signed-off-by: Christian Pinto --- .../multimodal/pooling/test_prithvi_mae.py | 24 +++--- vllm/v1/core/kv_cache_manager.py | 78 +++++++++++++++++-- vllm/v1/core/sched/scheduler.py | 4 +- vllm/v1/engine/core.py | 2 +- 4 files changed, 89 insertions(+), 19 deletions(-) diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py index 870edad3632..7350f2990c1 100644 --- a/tests/models/multimodal/pooling/test_prithvi_mae.py +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -6,17 +6,19 @@ from ....conftest 
import VllmRunner + def generate_test_mm_data(): mm_data = { "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16), "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), } return mm_data - + + def _run_test( vllm_runner: type[VllmRunner], model: str, -) -> None: +) -> None: mm_data = generate_test_mm_data() prompt = { @@ -24,13 +26,15 @@ def _run_test( "prompt_token_ids": [1], "multi_modal_data": mm_data } - with vllm_runner(model, task="embed", - dtype=torch.float16, - enforce_eager=True, - skip_tokenizer_init=True) as vllm_model: - output = vllm_model.encode(prompt) - -MODELS=["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] + with vllm_runner(model, + task="embed", + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True) as vllm_model: + vllm_model.encode(prompt) + +MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] + @pytest.mark.parametrize("model", MODELS) def test_models_image( hf_runner, @@ -41,4 +45,4 @@ def test_models_image( _run_test( vllm_runner, model, - ) \ No newline at end of file + ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index d3271beeb29..2f6708c62fa 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass from typing import Optional @@ -64,7 +65,72 @@ def new_empty(self) -> "KVCacheBlocks": return KVCacheBlocks(tuple([] for _ in range(len(self.blocks)))) -class DummyKVCacheManager: +class KVCacheManagerInterface(ABC): + + @abstractmethod + def usage(self) -> float: + raise NotImplementedError + + @abstractmethod + def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]: + raise NotImplementedError + + @abstractmethod + def get_computed_blocks(self, + request: Request) -> tuple[KVCacheBlocks, int]: + raise NotImplementedError + + @abstractmethod + def allocate_slots( + self, + request: Request, + num_new_tokens: int, + num_new_computed_tokens: int = 0, + new_computed_blocks: Optional[KVCacheBlocks] = None, + num_draft_tokens: int = 0, + num_lookahead_tokens: int = 0, + delay_cache_blocks: bool = False, + ) -> Optional[KVCacheBlocks]: + raise NotImplementedError + + @abstractmethod + def free(self, request: Request) -> None: + raise NotImplementedError + + @abstractmethod + def reset_prefix_cache(self) -> bool: + raise NotImplementedError + + @abstractmethod + def get_num_common_prefix_blocks( + self, + request: Request, + num_running_requests: int, + ) -> list[int]: + raise NotImplementedError + + @abstractmethod + def free_block_hashes(self, request: Request) -> None: + raise NotImplementedError + + @abstractmethod + def take_events(self) -> list[KVCacheEvent]: + raise NotImplementedError + + @abstractmethod + def get_block_ids(self, request_id: str) -> tuple[list[int], ...]: + raise NotImplementedError + + @abstractmethod + def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: + raise NotImplementedError + + @abstractmethod + def create_empty_block_list(self) -> KVCacheBlocks: + raise NotImplementedError + + +class DummyKVCacheManager(KVCacheManagerInterface): @property def usage(self) -> float: @@ -88,7 +154,7 @@ def allocate_slots( delay_cache_blocks: bool = False, ) -> Optional[KVCacheBlocks]: #if we do not return a KV cache block requests are unschedulable - return 
KVCacheBlocks([KVCacheBlock(block_id=0)]) + return KVCacheBlocks(tuple([KVCacheBlock(block_id=0)])) def free(self, request: Request) -> None: pass @@ -109,9 +175,9 @@ def free_block_hashes(self, request: Request) -> None: def take_events(self) -> list[KVCacheEvent]: return [] - def get_block_ids(self, request_id: str) -> list[list[int]]: + def get_block_ids(self, request_id: str) -> tuple[list[int], ...]: """Get the block ids of a request.""" - return [] + return tuple([]) def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: """Cache the blocks for the request, if enabled.""" @@ -119,10 +185,10 @@ def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: def create_empty_block_list(self) -> KVCacheBlocks: """Creates a new KVCacheBlocks instance with no blocks.""" - return (KVCacheBlocks([]), 0) + return KVCacheBlocks(tuple([])) -class KVCacheManager: +class KVCacheManager(KVCacheManagerInterface): def __init__( self, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 101175596bb..b35cf1f34c5 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -493,8 +493,8 @@ def schedule(self) -> SchedulerOutput: if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_block_ids[request.request_id] = ( - self.kv_cache_manager.get_block_ids(request.request_id)) + req_to_new_block_ids[request.request_id] = \ + self.kv_cache_manager.get_block_ids(request.request_id) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index f941f56031b..0901eef9234 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -140,7 +140,7 @@ def _initialize_kv_caches( kv_cache_specs = [] kv_cache_configs = [ KVCacheConfig(num_blocks=0, - kv_cache_tensors={}, + kv_cache_tensors=[], kv_cache_groups=[]) ] else: From e0ae3118c6d07c0e1abb9a1bc71ae42c3938bf20 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Thu, 26 Jun 2025 08:31:59 +0000 Subject: [PATCH 10/12] Some style changes - Improved formatting around - made is_pooling_model a @property in ModelConfig Signed-off-by: Christian Pinto --- tests/models/multimodal/pooling/test_prithvi_mae.py | 2 ++ vllm/config.py | 6 ++++-- vllm/v1/core/kv_cache_manager.py | 1 + 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py index 7350f2990c1..8179aa064ef 100644 --- a/tests/models/multimodal/pooling/test_prithvi_mae.py +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -33,8 +33,10 @@ def _run_test( skip_tokenizer_init=True) as vllm_model: vllm_model.encode(prompt) + MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] + @pytest.mark.parametrize("model", MODELS) def test_models_image( hf_runner, diff --git a/vllm/config.py b/vllm/config.py index a24c4ecbcba..f799e971431 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -612,8 +612,6 @@ def __post_init__(self) -> None: self.served_model_name = get_served_model_name(self.model, self.served_model_name) self.multimodal_config = self._init_multimodal_config() - self.is_pooling_model = self.registry.is_pooling_model( - self.architectures) self.model_supports_multimodal_raw_input = ( self._init_model_supports_multimodal_raw_input()) if not self.skip_tokenizer_init: @@ -1424,6 +1422,10 @@ def uses_mrope(self) -> 
bool:
     @property
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None
+
+    @property
+    def is_pooling_model(self) -> bool:
+        return self.registry.is_pooling_model(self.architectures)
 
     @property
     def is_cross_encoder(self) -> bool:
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 2f6708c62fa..18b74f016b1 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -67,6 +67,7 @@ def new_empty(self) -> "KVCacheBlocks":
 
 class KVCacheManagerInterface(ABC):
 
+    @property
     @abstractmethod
     def usage(self) -> float:
         raise NotImplementedError

From aab89561fb47cfadcd95c01189c6a13930c43c42 Mon Sep 17 00:00:00 2001
From: Christian Pinto
Date: Fri, 27 Jun 2025 08:48:10 +0000
Subject: [PATCH 11/12] Simple code refactoring

- Remove unused functions
- Merge functions not called anywhere else

Signed-off-by: Christian Pinto
---
 vllm/v1/worker/gpu_model_runner.py | 34 ++++++++----------------------
 1 file changed, 9 insertions(+), 25 deletions(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5d15e1aeb1e..30931792f69 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -558,11 +558,16 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # Refresh batch metadata with any pending updates.
         self.input_batch.refresh_metadata()
 
-    def _add_multimodal_inputs_to_model_args(
+    def _maybe_add_multimodal_kwargs(
             self,
             model_kwargs: dict[str, Any],
-            scheduler_output: "SchedulerOutput",
-            num_reqs: int = -1):
+            scheduler_output: "SchedulerOutput" = None,
+            num_reqs: int = -1,
+    ):
+
+        if not self.model_supports_multimodal_raw_input:
+            return
+
         # Multi-modal data.
         if scheduler_output:
             multi_modal_kwargs_list = []
@@ -585,28 +590,7 @@ def _add_multimodal_inputs_to_model_args(
 
         model_kwargs.update(multi_modal_kwargs)
 
-    def _maybe_add_multimodal_kwargs(
-            self,
-            model_kwargs: dict[str, Any],
-            scheduler_output: "SchedulerOutput" = None,
-            num_reqs: int = -1):
-
-        if self.model_supports_multimodal_raw_input:
-            self._add_multimodal_inputs_to_model_args(model_kwargs,
-                                                      scheduler_output,
-                                                      num_reqs)
-
-    def _maybe_compute_attn_prefix(
-        self,
-        scheduler_output: "SchedulerOutput",
-    ) -> list[int]:
-        return [0] * len(self.kv_cache_config.kv_cache_groups)
-
-    def _maybe_prepare_additional_inputs(self,
-                                         scheduler_output: "SchedulerOutput",
-                                         token_indices: torch.Tensor):
-        pass
-
+
     def _get_cumsum_and_arange(
         self,
         num_tokens: np.ndarray,

From 3b7729ec50e2c29b6cb42803a16a75bfe25f4699 Mon Sep 17 00:00:00 2001
From: Christian Pinto
Date: Wed, 9 Jul 2025 08:01:41 +0000
Subject: [PATCH 12/12] Some format changes to improve readability and reduce diff sizes

Signed-off-by: Christian Pinto
---
 vllm/v1/core/sched/scheduler.py    | 7 ++++---
 vllm/v1/engine/output_processor.py | 2 +-
 vllm/v1/engine/processor.py        | 7 ++++---
 vllm/v1/worker/gpu_model_runner.py | 8 ++++----
 4 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index b35cf1f34c5..7f752e34698 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -92,7 +92,7 @@ def __init__(
         )
 
         num_gpu_blocks = self.cache_config.num_gpu_blocks
-        # num_gpu_blocks can be ero for attention free models
+        # num_gpu_blocks can be zero for attention free models
         assert num_gpu_blocks is not None
 
         self.block_size = self.cache_config.block_size
@@ -164,6 +164,7 @@
                 log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events, ) + self.use_pp = self.parallel_config.pipeline_parallel_size > 1 def schedule(self) -> SchedulerOutput: # NOTE(woosuk) on the scheduling algorithm: @@ -493,8 +494,8 @@ def schedule(self) -> SchedulerOutput: if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_block_ids[request.request_id] = \ - self.kv_cache_manager.get_block_ids(request.request_id) + req_to_new_block_ids[request.request_id] = ( + self.kv_cache_manager.get_block_ids(request.request_id)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 3be6c482121..5d8e15f8bbb 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -336,7 +336,7 @@ def add_request( parent_req=parent_req, request_index=request_index, queue=queue, - log_stats=self.log_stats) + log_stats=self.log_stats,) self.request_states[request_id] = req_state self.lora_states.add_request(req_state) if parent_req: diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index d776ce7585c..26d767ac026 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -387,9 +387,10 @@ def _validate_model_input( else: raise ValueError(f"The {prompt_type} prompt cannot be empty") - if tokenizer: - max_input_id = max(prompt_ids, default=0) - if max_input_id > tokenizer.max_token_id: + if ( + tokenizer and (max_input_id:=max(prompt_ids, default=0)) > + tokenizer.max_token_id + ): raise ValueError( f"Token id {max_input_id} is out of vocabulary") diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 30931792f69..79150bfacc3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -330,9 +330,9 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: scheduler_output: The scheduler output. """ - # nothing to be reordered when the mdoel is attention free + # nothing to be reordered when the model is attention free if self.model_config.is_attention_free: - return False + return self.attn_metadata_builders[0].reorder_batch(self.input_batch, scheduler_output) @@ -561,7 +561,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: def _maybe_add_multimodal_kwargs( self, model_kwargs: dict[str, Any], - scheduler_output: "SchedulerOutput" = None, + scheduler_output: "Optional[SchedulerOutput]" = None, num_reqs: int = -1, ): @@ -2105,7 +2105,7 @@ def _dummy_run( intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, **MultiModalKwargs.as_kwargs( - model_kwargs, device=self.device)) + model_kwargs, device=self.device),) if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs
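
Note (illustration only, not part of any patch above): the behaviour this series enables — scheduling an attention-free pooling model without a real KV cache — can be exercised end to end with a short offline script. The sketch below mirrors tests/models/multimodal/pooling/test_prithvi_mae.py but goes through the public vllm.LLM entry point instead of the test-only VllmRunner; the model name and constructor arguments are carried over from that test and may need adjusting for a given vLLM build.

# offline_prithvi_example.py -- minimal sketch, assumptions noted inline.
import torch
from vllm import LLM


def generate_test_mm_data():
    # Same dummy raster input shapes as used in test_prithvi_mae.py.
    return {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    }


def main():
    # With these patches, num_gpu_blocks is 0 for attention-free models and the
    # scheduler falls back to DummyKVCacheManager, so no KV cache is profiled.
    llm = LLM(
        model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
        task="embed",              # pooling/embedding task, no autoregressive decode
        dtype="float16",
        enforce_eager=True,
        skip_tokenizer_init=True,  # the model consumes raw multi-modal inputs
    )
    prompt = {
        "prompt_token_ids": [1],
        "multi_modal_data": generate_test_mm_data(),
    }
    # encode() returns one pooling output per prompt.
    outputs = llm.encode(prompt)
    print(outputs[0].outputs)


if __name__ == "__main__":
    main()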