From c623c60deb634734a585b255ff3f00782d409d10 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Thu, 5 Jun 2025 15:30:18 +0000 Subject: [PATCH 01/14] Better support for skip_tokenizer_init=True Signed-off-by: Christian Pinto --- vllm/config.py | 8 ++++++-- vllm/multimodal/registry.py | 2 +- vllm/v1/engine/llm_engine.py | 14 +++++++++----- vllm/v1/engine/output_processor.py | 5 ++++- vllm/v1/engine/processor.py | 12 ++++++++---- 5 files changed, 28 insertions(+), 13 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 766d7708625..4145819474a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -642,6 +642,7 @@ def __post_init__(self) -> None: self.original_max_model_len = self.max_model_len self.max_model_len = self.get_and_verify_max_len(self.max_model_len) self.multimodal_config = self._init_multimodal_config() + self.model_supports_multimodal_raw_input = self._init_model_supports_multimodal_raw_input() if not self.skip_tokenizer_init: self._verify_tokenizer_mode() @@ -753,6 +754,9 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: return None + def _init_model_supports_multimodal_raw_input(self): + return self.registry.supports_multimodal_raw_input(self.architectures) + def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( self.model, self.revision) @@ -1201,10 +1205,10 @@ def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]: return self.get_hf_config_sliding_window() def get_vocab_size(self) -> int: - return self.hf_text_config.vocab_size + return getattr(self.hf_text_config, "vocab_size", 0) def get_hidden_size(self) -> int: - return self.hf_text_config.hidden_size + return getattr(self.hf_text_config, "hidden_size", 0) @property def is_deepseek_mla(self) -> bool: diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 27aaa661c35..c44fcacd246 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -266,7 +266,7 @@ def create_processor( if not model_config.is_multimodal_model: raise ValueError(f"{model_config.model} is not a multimodal model") - if tokenizer is None: + if tokenizer is None and not model_config.skip_tokenizer_init: tokenizer = cached_tokenizer_from_config(model_config) if disable_cache is None: mm_config = model_config.get_multimodal_config() diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index a2328c37ba0..c8259aac77b 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -82,11 +82,15 @@ def __init__( self.dp_group = None self.should_execute_dummy_batch = False - # Tokenizer (+ ensure liveness if running in another process). - self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) + + if not self.vllm_config.model_config.skip_tokenizer_init: + # Tokenizer (+ ensure liveness if running in another process). 
+ self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + lora_config=vllm_config.lora_config) + else: + self.tokenizer = None # Processor (convert Inputs --> EngineCoreRequests) self.processor = Processor(vllm_config=vllm_config, diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 2bcd61d1f0a..05ea6cc492e 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -327,8 +327,11 @@ def add_request( if request_id in self.request_states: raise ValueError(f"Request id {request_id} already running.") + tokenizer = None if not self.tokenizer else \ + self.tokenizer.get_lora_tokenizer(request.lora_request) + req_state = RequestState.from_new_request( - tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request), + tokenizer=tokenizer, request=request, prompt=prompt, parent_req=parent_req, diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 7af4ed54a22..5dd51f6b774 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -380,7 +380,10 @@ def _validate_model_input( prompt_type: Literal["encoder", "decoder"], ): model_config = self.model_config - tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) + if model_config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) prompt_ids = prompt_inputs["prompt_token_ids"] if not prompt_ids: @@ -389,9 +392,10 @@ def _validate_model_input( else: raise ValueError(f"The {prompt_type} prompt cannot be empty") - max_input_id = max(prompt_ids, default=0) - if max_input_id > tokenizer.max_token_id: - raise ValueError(f"Token id {max_input_id} is out of vocabulary") + if tokenizer: + max_input_id = max(prompt_ids, default=0) + if max_input_id > tokenizer.max_token_id: + raise ValueError(f"Token id {max_input_id} is out of vocabulary") max_prompt_len = self.model_config.max_model_len if len(prompt_ids) > max_prompt_len: From 146b3b2b6219ae969be03a6e938faad079a281f5 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Thu, 5 Jun 2025 15:31:55 +0000 Subject: [PATCH 02/14] Support for attention free models in V1 Signed-off-by: Christian Pinto --- .../models/prithvi_geospatial_mae.py | 39 +++++++++++++------ 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index a36f24bc80e..496ef0b9a30 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -26,13 +26,13 @@ from vllm.config import VllmConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import (IsAttentionFree, - SupportsMultiModal, - SupportsV0Only) + SupportsMultiModalWithRawInput) from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs) +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalFieldElem, + MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem, + MultiModalSharedField, PlaceholderRange) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import 
(BaseMultiModalProcessor, BaseProcessingInfo, PromptUpdate) @@ -75,8 +75,8 @@ def _get_mm_fields_config( hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return dict( - pixel_values=MultiModalFieldConfig.batched("image"), - location_coords=MultiModalFieldConfig.batched("image"), + pixel_values=MultiModalFieldConfig.shared(batch_size=1, modality="image"), + location_coords=MultiModalFieldConfig.shared(batch_size=1, modality="image"), ) def _get_prompt_updates( @@ -99,14 +99,24 @@ def apply( for k, v in mm_data.items(): mm_kwargs[k] = v + mm_place_holders = { + "image": [PlaceholderRange(offset=0, length=0)] + } + + multimodal_kwargs_items = [ + MultiModalKwargsItem.from_elems( + [MultiModalFieldElem(modality="image", key=key, data=data, field=MultiModalSharedField(1)) + for key, data in mm_kwargs.items()] + ) + ] return MultiModalInputs( type="multimodal", prompt=prompt, prompt_token_ids=[1], - mm_kwargs=MultiModalKwargs(mm_kwargs), + mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items), mm_hashes=None, - mm_placeholders={}, + mm_placeholders=mm_place_holders, ) @@ -114,8 +124,7 @@ def apply( PrithviGeoSpatialMAEMultiModalProcessor, info=PrithviGeoSpatialMAEProcessingInfo, dummy_inputs=PrithviGeoSpatialMAEInputBuilder) -class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal, - SupportsV0Only): +class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModalWithRawInput): """ Prithvi Masked Autoencoder""" @classmethod @@ -180,7 +189,13 @@ def _parse_and_validate_multimodal_data( location_coords = None return pixel_values, location_coords - + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + # We do not really use any input tokens and therefore no embeddings to be calculated + # However, due to the mandatory token ids in the input prompt we pass one token and the + # size of the dummy embedding tensors must reflect that. + return torch.empty(input_ids.shape) + def forward( self, input_ids: Optional[torch.Tensor], @@ -202,7 +217,7 @@ def pooler( hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Optional[PoolerOutput]: - return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)]) + return PoolerOutput([PoolingSequenceGroupOutput(hidden_states[0])]) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: From 7392c450661f92c2fdef8bc13206ac53089f41b1 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Fri, 6 Jun 2025 07:47:15 +0000 Subject: [PATCH 03/14] Last few changes after rebasing to latest branch version Signed-off-by: Christian Pinto --- .../models/prithvi_geospatial_mae.py | 4 +-- vllm/v1/worker/gpu_model_runner.py | 34 +++++++++++++++---- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 496ef0b9a30..fbf32dd4933 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -62,8 +62,8 @@ def get_dummy_mm_data( # The size of pixel_values might change in the cases where we resize # the input but never exceeds the dimensions below. 
         return {
-            "pixel_values": torch.full((1, 6, 512, 512), 1.0),
-            "location_coords": torch.full((1, 2), 1.0),
+            "pixel_values": torch.full((1, 6, 512, 512), 1.0, dtype=torch.float16),
+            "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
         }
 
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index af216539c90..6ddf2e5f75b 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -123,7 +123,7 @@ def __init__(
             cache_config.cache_dtype]
 
         self.is_multimodal_model = model_config.is_multimodal_model
-        self.is_pooling_model = model_config.pooler_config is not None
+        self.model_supports_multimodal_raw_input = model_config.model_supports_multimodal_raw_input
         self.max_model_len = model_config.max_model_len
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs
@@ -326,6 +326,11 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         Args:
             scheduler_output: The scheduler output.
         """
+
+        # nothing to be reordered when the model is attention free
+        if self.model_config.is_attention_free:
+            return False
+
         self.attn_metadata_builders[0].reorder_batch(self.input_batch,
                                                      scheduler_output)
 
@@ -1019,13 +1024,14 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
             curr_group_outputs = self.model.get_multimodal_embeddings(
                 **batched_mm_inputs)
 
-            sanity_check_mm_encoder_outputs(
-                curr_group_outputs,
-                expected_num_items=len(grouped_mm_inputs),
-            )
+            if curr_group_outputs:
+                sanity_check_mm_encoder_outputs(
+                    curr_group_outputs,
+                    expected_num_items=len(grouped_mm_inputs),
+                )
 
-            for output in curr_group_outputs:
-                encoder_outputs.append(output)
+                for output in curr_group_outputs:
+                    encoder_outputs.append(output)
 
         # Cache the encoder outputs.
         for (req_id, input_id, pos_info), output in zip(
@@ -1324,6 +1330,9 @@ def execute_model(
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
             input_ids = self.input_ids[:num_scheduled_tokens]
+            self._maybe_add_model_args(num_scheduled_tokens,
+                                       model_kwargs, scheduler_output)
+
             if mm_embeds:
                 inputs_embeds = self.model.get_input_embeddings(
                     input_ids, mm_embeds)
@@ -1339,6 +1348,7 @@ def execute_model(
             # multimodal models, it is not desirable for performance since
             # then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens] + self._maybe_add_model_args(num_input_tokens, model_kwargs, scheduler_output) inputs_embeds = None if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] @@ -1372,6 +1382,10 @@ def execute_model( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_kwargs, + device=self.device, + ) ) self.maybe_wait_for_kv_save() @@ -1998,6 +2012,8 @@ def _dummy_run( with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): model = self.model + model_kwargs: dict[str, Any] = {} + self._maybe_add_model_args(num_tokens, model_kwargs) if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -2032,7 +2048,11 @@ def _dummy_run( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_kwargs, + device=self.device) ) + if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs else: From 9a06b552eb481100580bd938eb6637766cdd00e0 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Thu, 5 Jun 2025 15:31:03 +0000 Subject: [PATCH 04/14] Support passing raw multimodal data to model Signed-off-by: Christian Pinto --- vllm/model_executor/models/interfaces.py | 36 +++++++++++++++++++++ vllm/model_executor/models/registry.py | 14 +++++++-- vllm/v1/worker/gpu_model_runner.py | 40 ++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 92ecb8972d5..9b26c678994 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -145,6 +145,42 @@ def supports_multimodal( return isinstance(model, SupportsMultiModal) +@runtime_checkable +class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): + """The interface required for all multi-modal models.""" + + supports_multimodal_raw_input: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports multi-modal inputs and processes + them in their raw form and not embeddings. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + +@runtime_checkable +class _SupportsMultiModalWithRawInput(Protocol): + supports_multimodal_raw_input: ClassVar[Literal[True]] + + +@overload +def supports_multimodal_raw_input(model: object) -> TypeIs[SupportsMultiModalWithRawInput]: + ... + + +@overload +def supports_multimodal_raw_input(model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]: + ... 
+ + +def supports_multimodal_raw_input( + model: Union[type[object], object] +) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], TypeIs[SupportsMultiModalWithRawInput]]: + if isinstance(model, type): + return isinstance(model, _SupportsMultiModalWithRawInput) + + return isinstance(model, SupportsMultiModalWithRawInput) @runtime_checkable class SupportsScoreTemplate(Protocol): diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b7f9638d322..61bb597d541 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -22,8 +22,9 @@ from .interfaces import (has_inner_state, has_noops, is_attention_free, is_hybrid, supports_cross_encoding, - supports_multimodal, supports_pp, - supports_transcription, supports_v0_only) + supports_multimodal, supports_multimodal_raw_input, + supports_pp, supports_transcription, + supports_v0_only) from .interfaces_base import is_text_generation_model logger = init_logger(__name__) @@ -281,6 +282,7 @@ class _ModelInfo: is_pooling_model: bool supports_cross_encoding: bool supports_multimodal: bool + supports_multimodal_raw_input: bool supports_pp: bool has_inner_state: bool is_attention_free: bool @@ -298,6 +300,7 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo": is_pooling_model=True, # Can convert any model into a pooling model supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), + supports_multimodal_raw_input=supports_multimodal_raw_input(model), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), @@ -536,6 +539,13 @@ def is_multimodal_model( ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.supports_multimodal + + def supports_multimodal_raw_input( + self, + architectures: Union[str, list[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_multimodal_raw_input def is_pp_supported_model( self, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6ddf2e5f75b..00056967063 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -559,6 +559,46 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() + def _add_multimodal_inputs_to_model_args(self, model_kwargs: dict[str, Any], + scheduler_output: "SchedulerOutput"): + # Multi-modal data. + if scheduler_output: + multi_modal_kwargs_list = [] + for req in scheduler_output.scheduled_new_reqs: + req_mm_inputs = req.mm_inputs + if not isinstance(req_mm_inputs, list): + req_mm_inputs = list(req_mm_inputs) + multi_modal_kwargs_list.extend(req_mm_inputs) + multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + else: + # The only case where SchedulerOtput is None is for a dummy run, let's get some dummy data. 
+ dummy_data = self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len =1) + multi_modal_kwargs = MultiModalKwargs.batch([dummy_data.multi_modal_data]) + + model_kwargs.update(multi_modal_kwargs) + + def _maybe_add_model_args(self, num_tokens: int, + model_kwargs: dict[str,Any], + scheduler_output: "SchedulerOutput"=None): + + if self.supports_token_type_ids: + model_kwargs["token_type_ids"] =\ + self.get_token_type_ids()[:num_tokens] + + if self.model_supports_multimodal_raw_input: + self._add_multimodal_inputs_to_model_args(model_kwargs, scheduler_output) + + def _maybe_compute_attn_prefix( + self, + scheduler_output: "SchedulerOutput", + ) -> list[int]: + return [0] * len(self.kv_cache_config.kv_cache_groups) + + def _maybe_prepare_additional_inputs(self, + scheduler_output: "SchedulerOutput", + token_indices: torch.Tensor): + pass + def _get_cumsum_and_arange( self, num_tokens: np.ndarray, From 9aa55337744993c1fe771db8bfce9287d68a96fd Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Fri, 6 Jun 2025 15:05:04 +0000 Subject: [PATCH 05/14] latest changes to align with the original branch Signed-off-by: Christian Pinto --- vllm/config.py | 1 + .../models/prithvi_geospatial_mae.py | 6 ++-- vllm/v1/worker/gpu_model_runner.py | 33 ++++++++++--------- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4145819474a..a22e1fc7a6d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -642,6 +642,7 @@ def __post_init__(self) -> None: self.original_max_model_len = self.max_model_len self.max_model_len = self.get_and_verify_max_len(self.max_model_len) self.multimodal_config = self._init_multimodal_config() + self.is_pooling_model = self.registry.is_pooling_model(self.architectures) self.model_supports_multimodal_raw_input = self._init_model_supports_multimodal_raw_input() if not self.skip_tokenizer_init: self._verify_tokenizer_mode() diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index fbf32dd4933..7f7a6b76c71 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -62,7 +62,7 @@ def get_dummy_mm_data( # The size of pixel_values might change in the cases where we resize # the input but never exceeds the dimensions below. return { - "pixel_values": torch.full((1, 6, 512, 512), 1.0, dtype=torch.float16), + "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16), "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), } @@ -178,7 +178,7 @@ def _parse_and_validate_multimodal_data( if not isinstance(pixel_values, torch.Tensor): raise ValueError(f"Incorrect type of pixel_values. 
" f"Got type: {type(pixel_values)}") - pixel_values = torch.unbind(pixel_values, dim=0)[0] + # pixel_values = torch.unbind(pixel_values, dim=0)[0] location_coords = kwargs.pop("location_coords", None) if not isinstance(location_coords, torch.Tensor): @@ -217,7 +217,7 @@ def pooler( hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Optional[PoolerOutput]: - return PoolerOutput([PoolingSequenceGroupOutput(hidden_states[0])]) + return PoolerOutput([PoolingSequenceGroupOutput(hidden_state) for hidden_state in hidden_states]) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 00056967063..1b85bc527b8 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -123,6 +123,7 @@ def __init__( cache_config.cache_dtype] self.is_multimodal_model = model_config.is_multimodal_model + self.is_pooling_model = model_config.is_pooling_model self.model_supports_multimodal_raw_input = model_config.model_supports_multimodal_raw_input self.max_model_len = model_config.max_model_len self.max_num_tokens = scheduler_config.max_num_batched_tokens @@ -560,7 +561,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.refresh_metadata() def _add_multimodal_inputs_to_model_args(self, model_kwargs: dict[str, Any], - scheduler_output: "SchedulerOutput"): + scheduler_output: "SchedulerOutput", + num_reqs: int=-1): # Multi-modal data. if scheduler_output: multi_modal_kwargs_list = [] @@ -572,21 +574,20 @@ def _add_multimodal_inputs_to_model_args(self, model_kwargs: dict[str, Any], multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) else: # The only case where SchedulerOtput is None is for a dummy run, let's get some dummy data. - dummy_data = self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len =1) - multi_modal_kwargs = MultiModalKwargs.batch([dummy_data.multi_modal_data]) + dummy_data = [self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len =1).multi_modal_data for i in range(num_reqs)] + # dummy_data = self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len =1) + # multi_modal_kwargs = MultiModalKwargs.batch([dummy_data.multi_modal_data]) + multi_modal_kwargs = MultiModalKwargs.batch(dummy_data) model_kwargs.update(multi_modal_kwargs) - def _maybe_add_model_args(self, num_tokens: int, + def _maybe_add_multimodal_kwargs(self, model_kwargs: dict[str,Any], - scheduler_output: "SchedulerOutput"=None): - - if self.supports_token_type_ids: - model_kwargs["token_type_ids"] =\ - self.get_token_type_ids()[:num_tokens] + scheduler_output: "SchedulerOutput"=None, + num_reqs: int=-1): if self.model_supports_multimodal_raw_input: - self._add_multimodal_inputs_to_model_args(model_kwargs, scheduler_output) + self._add_multimodal_inputs_to_model_args(model_kwargs, scheduler_output, num_reqs) def _maybe_compute_attn_prefix( self, @@ -1364,15 +1365,15 @@ def execute_model( mm_embeds = self._gather_mm_embeddings(scheduler_output) else: mm_embeds = [] - + + model_kwargs: dict[str, Any] = {} if self.is_multimodal_model and get_pp_group().is_first_rank: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. 
             input_ids = self.input_ids[:num_scheduled_tokens]
-            self._maybe_add_model_args(num_scheduled_tokens,
-                                       model_kwargs, scheduler_output)
-
+            self._maybe_add_multimodal_kwargs(model_kwargs=model_kwargs,
+                                              scheduler_output=scheduler_output)
             if mm_embeds:
                 inputs_embeds = self.model.get_input_embeddings(
                     input_ids, mm_embeds)
@@ -1388,7 +1389,6 @@ def execute_model(
             # multimodal models, it is not desirable for performance since
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids[:num_input_tokens]
-            self._maybe_add_model_args(num_input_tokens, model_kwargs, scheduler_output)
             inputs_embeds = None
         if self.uses_mrope:
             positions = self.mrope_positions[:, :num_input_tokens]
@@ -2053,8 +2053,9 @@ def _dummy_run(
                 num_scheduled_tokens):
             model = self.model
             model_kwargs: dict[str, Any] = {}
-            self._maybe_add_model_args(num_tokens, model_kwargs)
             if self.is_multimodal_model:
+                self._maybe_add_multimodal_kwargs(model_kwargs=model_kwargs,
+                                                  num_reqs=num_reqs)
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
             else:

From 5ac66e7deaba5fd509cfa402b77b98df4e01c984 Mon Sep 17 00:00:00 2001
From: Christian Pinto
Date: Tue, 24 Jun 2025 14:25:03 +0000
Subject: [PATCH 06/14] Latest changes to adapt to upstream master

Signed-off-by: Christian Pinto
---
 .../offline_inference/prithvi_geospatial_mae.py |  3 ++-
 vllm/v1/engine/core.py                          | 13 +++++++++++--
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py
index 567c448a8c9..e0169eafbab 100644
--- a/examples/offline_inference/prithvi_geospatial_mae.py
+++ b/examples/offline_inference/prithvi_geospatial_mae.py
@@ -143,7 +143,8 @@ def __init__(self):
         self.model = LLM(
             model=os.path.join(os.path.dirname(__file__), "./model"),
             skip_tokenizer_init=True,
-            dtype="float32",
+            dtype="float16",
+            enforce_eager=True
         )
 
     def run(self, input_data, location_coords):
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index f5c59bef478..bdd5bb8ec43 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -134,8 +134,17 @@ def _initialize_kv_caches(
             self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
         start = time.time()
 
-        # Get all kv cache needed by the model
-        kv_cache_specs = self.model_executor.get_kv_cache_specs()
+        #TODO: CP start from here
+        if vllm_config.model_config.is_attention_free:
+            # No need for initializing anything related to KV cache if the model
+            # is attention free.
+            kv_cache_specs = []
+            kv_cache_configs = [
+                KVCacheConfig(num_blocks=0, kv_cache_tensors={}, kv_cache_groups=[])
+            ]
+        else:
+            # Get all kv cache needed by the model
+            kv_cache_specs = self.model_executor.get_kv_cache_specs()
 
         # Profiles the peak memory usage of the model to determine how much
         # memory can be allocated for kv cache.
From 8f27e28c27c566c8505ff589b015cc3366081cf1 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Wed, 25 Jun 2025 10:52:35 +0000 Subject: [PATCH 07/14] Some reformatting to make the pre-commit hooks succeed Signed-off-by: Christian Pinto --- .../prithvi_geospatial_mae.py | 2 +- .../multimodal/pooling/test_prithvi_mae.py | 44 +++++++++++++ vllm/config.py | 6 +- vllm/model_executor/models/interfaces.py | 9 ++- .../models/prithvi_geospatial_mae.py | 52 +++++++++------ vllm/model_executor/models/registry.py | 5 +- vllm/v1/engine/core.py | 17 ++--- vllm/v1/engine/llm_engine.py | 3 +- vllm/v1/engine/output_processor.py | 15 ++--- vllm/v1/engine/processor.py | 3 +- vllm/v1/worker/gpu_model_runner.py | 64 ++++++++++--------- 11 files changed, 137 insertions(+), 83 deletions(-) create mode 100644 tests/models/multimodal/pooling/test_prithvi_mae.py diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index e0169eafbab..e36d01c249f 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -144,7 +144,7 @@ def __init__(self): model=os.path.join(os.path.dirname(__file__), "./model"), skip_tokenizer_init=True, dtype="float16", - enforce_eager=True + enforce_eager=True, ) def run(self, input_data, location_coords): diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py new file mode 100644 index 00000000000..870edad3632 --- /dev/null +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from ....conftest import VllmRunner + +def generate_test_mm_data(): + mm_data = { + "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16), + "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), + } + return mm_data + +def _run_test( + vllm_runner: type[VllmRunner], + model: str, +) -> None: + + mm_data = generate_test_mm_data() + prompt = { + # This model deals with no text input + "prompt_token_ids": [1], + "multi_modal_data": mm_data + } + with vllm_runner(model, task="embed", + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True) as vllm_model: + output = vllm_model.encode(prompt) + +MODELS=["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] +@pytest.mark.parametrize("model", MODELS) +def test_models_image( + hf_runner, + vllm_runner, + image_assets, + model: str, +) -> None: + _run_test( + vllm_runner, + model, + ) \ No newline at end of file diff --git a/vllm/config.py b/vllm/config.py index a22e1fc7a6d..047d1984b85 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -642,8 +642,10 @@ def __post_init__(self) -> None: self.original_max_model_len = self.max_model_len self.max_model_len = self.get_and_verify_max_len(self.max_model_len) self.multimodal_config = self._init_multimodal_config() - self.is_pooling_model = self.registry.is_pooling_model(self.architectures) - self.model_supports_multimodal_raw_input = self._init_model_supports_multimodal_raw_input() + self.is_pooling_model = self.registry.is_pooling_model( + self.architectures) + self.model_supports_multimodal_raw_input = ( + self._init_model_supports_multimodal_raw_input()) if not self.skip_tokenizer_init: self._verify_tokenizer_mode() diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 
9b26c678994..010d6dc4351 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -165,18 +165,21 @@ class _SupportsMultiModalWithRawInput(Protocol): @overload -def supports_multimodal_raw_input(model: object) -> TypeIs[SupportsMultiModalWithRawInput]: +def supports_multimodal_raw_input( + model: object) -> TypeIs[SupportsMultiModalWithRawInput]: ... @overload -def supports_multimodal_raw_input(model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]: +def supports_multimodal_raw_input( + model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]: ... def supports_multimodal_raw_input( model: Union[type[object], object] -) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], TypeIs[SupportsMultiModalWithRawInput]]: +) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], + TypeIs[SupportsMultiModalWithRawInput]]: if isinstance(model, type): return isinstance(model, _SupportsMultiModalWithRawInput) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 7f7a6b76c71..cc34721cf41 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -25,13 +25,14 @@ from vllm.config import VllmConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (IsAttentionFree, - SupportsMultiModalWithRawInput) +from vllm.model_executor.models.interfaces import ( + IsAttentionFree, SupportsMultiModalWithRawInput) from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalFieldElem, - MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem, +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalFieldElem, MultiModalInputs, + MultiModalKwargs, MultiModalKwargsItem, MultiModalSharedField, PlaceholderRange) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, @@ -62,7 +63,8 @@ def get_dummy_mm_data( # The size of pixel_values might change in the cases where we resize # the input but never exceeds the dimensions below. 
return { - "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16), + "pixel_values": torch.full((6, 512, 512), 1.0, + dtype=torch.float16), "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), } @@ -75,8 +77,10 @@ def _get_mm_fields_config( hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return dict( - pixel_values=MultiModalFieldConfig.shared(batch_size=1, modality="image"), - location_coords=MultiModalFieldConfig.shared(batch_size=1, modality="image"), + pixel_values=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), + location_coords=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), ) def _get_prompt_updates( @@ -99,15 +103,16 @@ def apply( for k, v in mm_data.items(): mm_kwargs[k] = v - mm_place_holders = { - "image": [PlaceholderRange(offset=0, length=0)] - } + mm_place_holders = {"image": [PlaceholderRange(offset=0, length=0)]} multimodal_kwargs_items = [ - MultiModalKwargsItem.from_elems( - [MultiModalFieldElem(modality="image", key=key, data=data, field=MultiModalSharedField(1)) - for key, data in mm_kwargs.items()] - ) + MultiModalKwargsItem.from_elems([ + MultiModalFieldElem(modality="image", + key=key, + data=data, + field=MultiModalSharedField(1)) + for key, data in mm_kwargs.items() + ]) ] return MultiModalInputs( @@ -124,7 +129,8 @@ def apply( PrithviGeoSpatialMAEMultiModalProcessor, info=PrithviGeoSpatialMAEProcessingInfo, dummy_inputs=PrithviGeoSpatialMAEInputBuilder) -class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModalWithRawInput): +class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, + SupportsMultiModalWithRawInput): """ Prithvi Masked Autoencoder""" @classmethod @@ -189,13 +195,14 @@ def _parse_and_validate_multimodal_data( location_coords = None return pixel_values, location_coords - + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - # We do not really use any input tokens and therefore no embeddings to be calculated - # However, due to the mandatory token ids in the input prompt we pass one token and the - # size of the dummy embedding tensors must reflect that. + # We do not really use any input tokens and therefore no embeddings + # to be calculated. However, due to the mandatory token ids in + # the input prompt we pass one token and the size of the dummy + # embedding tensors must reflect that. 
return torch.empty(input_ids.shape) - + def forward( self, input_ids: Optional[torch.Tensor], @@ -217,7 +224,10 @@ def pooler( hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Optional[PoolerOutput]: - return PoolerOutput([PoolingSequenceGroupOutput(hidden_state) for hidden_state in hidden_states]) + return PoolerOutput([ + PoolingSequenceGroupOutput(hidden_state) + for hidden_state in hidden_states + ]) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 61bb597d541..6e706a6eb0a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -23,8 +23,7 @@ from .interfaces import (has_inner_state, has_noops, is_attention_free, is_hybrid, supports_cross_encoding, supports_multimodal, supports_multimodal_raw_input, - supports_pp, supports_transcription, - supports_v0_only) + supports_pp, supports_transcription, supports_v0_only) from .interfaces_base import is_text_generation_model logger = init_logger(__name__) @@ -539,7 +538,7 @@ def is_multimodal_model( ) -> bool: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.supports_multimodal - + def supports_multimodal_raw_input( self, architectures: Union[str, list[str]], diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index bdd5bb8ec43..eb9daa37603 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -134,17 +134,8 @@ def _initialize_kv_caches( self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]: start = time.time() - #TODO: CP start from here - if vllm_config.model_config.is_attention_free: - # No need for initializing anything related to KV cache if the model - # is attention free. - kv_cache_specs = [] - kv_cache_configs = [ - KVCacheConfig(num_blocks=0, kv_cache_tensors={}, kv_cache_groups=[]) - ] - else: - # Get all kv cache needed by the model - kv_cache_specs = self.model_executor.get_kv_cache_specs() + # Get all kv cache needed by the model + kv_cache_specs = self.model_executor.get_kv_cache_specs() # Profiles the peak memory usage of the model to determine how much # memory can be allocated for kv cache. @@ -161,8 +152,8 @@ def _initialize_kv_caches( kv_cache_configs = [ get_kv_cache_config(vllm_config, kv_cache_spec_one_worker, available_gpu_memory_one_worker) - for kv_cache_spec_one_worker, available_gpu_memory_one_worker in - zip(kv_cache_specs, available_gpu_memory) + for kv_cache_spec_one_worker, available_gpu_memory_one_worker + in zip(kv_cache_specs, available_gpu_memory) ] # Since we use a shared centralized controller, we need the diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index c8259aac77b..e13a544b50f 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -82,8 +82,7 @@ def __init__( self.dp_group = None self.should_execute_dummy_batch = False - - if not self.vllm_config.model_config.skip_tokenizer_init: + if not self.vllm_config.model_config.skip_tokenizer_init: # Tokenizer (+ ensure liveness if running in another process). 
self.tokenizer = init_tokenizer_from_configs( model_config=vllm_config.model_config, diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 05ea6cc492e..3be6c482121 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -330,14 +330,13 @@ def add_request( tokenizer = None if not self.tokenizer else \ self.tokenizer.get_lora_tokenizer(request.lora_request) - req_state = RequestState.from_new_request( - tokenizer=tokenizer, - request=request, - prompt=prompt, - parent_req=parent_req, - request_index=request_index, - queue=queue, - log_stats=self.log_stats) + req_state = RequestState.from_new_request(tokenizer=tokenizer, + request=request, + prompt=prompt, + parent_req=parent_req, + request_index=request_index, + queue=queue, + log_stats=self.log_stats) self.request_states[request_id] = req_state self.lora_states.add_request(req_state) if parent_req: diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 5dd51f6b774..60e97c78a5f 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -395,7 +395,8 @@ def _validate_model_input( if tokenizer: max_input_id = max(prompt_ids, default=0) if max_input_id > tokenizer.max_token_id: - raise ValueError(f"Token id {max_input_id} is out of vocabulary") + raise ValueError( + f"Token id {max_input_id} is out of vocabulary") max_prompt_len = self.model_config.max_model_len if len(prompt_ids) > max_prompt_len: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1b85bc527b8..56b2d8e4d70 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -124,7 +124,8 @@ def __init__( self.is_multimodal_model = model_config.is_multimodal_model self.is_pooling_model = model_config.is_pooling_model - self.model_supports_multimodal_raw_input = model_config.model_supports_multimodal_raw_input + self.model_supports_multimodal_raw_input = ( + model_config.model_supports_multimodal_raw_input) self.max_model_len = model_config.max_model_len self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -560,9 +561,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() - def _add_multimodal_inputs_to_model_args(self, model_kwargs: dict[str, Any], - scheduler_output: "SchedulerOutput", - num_reqs: int=-1): + def _add_multimodal_inputs_to_model_args( + self, + model_kwargs: dict[str, Any], + scheduler_output: "SchedulerOutput", + num_reqs: int = -1): # Multi-modal data. if scheduler_output: multi_modal_kwargs_list = [] @@ -571,23 +574,30 @@ def _add_multimodal_inputs_to_model_args(self, model_kwargs: dict[str, Any], if not isinstance(req_mm_inputs, list): req_mm_inputs = list(req_mm_inputs) multi_modal_kwargs_list.extend(req_mm_inputs) - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) + multi_modal_kwargs = MultiModalKwargs.batch( + multi_modal_kwargs_list) else: - # The only case where SchedulerOtput is None is for a dummy run, let's get some dummy data. 
- dummy_data = [self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len =1).multi_modal_data for i in range(num_reqs)] - # dummy_data = self.mm_registry.get_decoder_dummy_data(model_config=self.model_config, seq_len =1) - # multi_modal_kwargs = MultiModalKwargs.batch([dummy_data.multi_modal_data]) + # The only case where SchedulerOtput is None is for a dummy run, + # let's get some dummy data. + dummy_data = [ + self.mm_registry.get_decoder_dummy_data( + model_config=self.model_config, seq_len=1).multi_modal_data + for i in range(num_reqs) + ] multi_modal_kwargs = MultiModalKwargs.batch(dummy_data) - + model_kwargs.update(multi_modal_kwargs) - def _maybe_add_multimodal_kwargs(self, - model_kwargs: dict[str,Any], - scheduler_output: "SchedulerOutput"=None, - num_reqs: int=-1): + def _maybe_add_multimodal_kwargs( + self, + model_kwargs: dict[str, Any], + scheduler_output: "SchedulerOutput" = None, + num_reqs: int = -1): if self.model_supports_multimodal_raw_input: - self._add_multimodal_inputs_to_model_args(model_kwargs, scheduler_output, num_reqs) + self._add_multimodal_inputs_to_model_args(model_kwargs, + scheduler_output, + num_reqs) def _maybe_compute_attn_prefix( self, @@ -1365,15 +1375,15 @@ def execute_model( mm_embeds = self._gather_mm_embeddings(scheduler_output) else: mm_embeds = [] - + model_kwargs: dict[str, Any] = {} if self.is_multimodal_model and get_pp_group().is_first_rank: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. input_ids = self.input_ids[:num_scheduled_tokens] - self._maybe_add_multimodal_kwargs(model_kwargs=model_kwargs, - scheduler_output=scheduler_output) + self._maybe_add_multimodal_kwargs( + model_kwargs=model_kwargs, scheduler_output=scheduler_output) if mm_embeds: inputs_embeds = self.model.get_input_embeddings( input_ids, mm_embeds) @@ -1425,8 +1435,7 @@ def execute_model( **MultiModalKwargs.as_kwargs( model_kwargs, device=self.device, - ) - ) + )) self.maybe_wait_for_kv_save() @@ -2084,15 +2093,12 @@ def _dummy_run( self.vllm_config, num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp): - outputs = model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **MultiModalKwargs.as_kwargs( - model_kwargs, - device=self.device) - ) + outputs = model(input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_kwargs, device=self.device)) if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs From 5fe55fd97f68d67f3ab54ea05db8ec47c09f8e98 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Wed, 25 Jun 2025 12:54:25 +0000 Subject: [PATCH 08/14] Few more changes to solve some other pre-commit hooks failures Signed-off-by: Christian Pinto --- .../multimodal/pooling/test_prithvi_mae.py | 24 +++++++++++-------- vllm/v1/core/kv_cache_manager.py | 2 +- vllm/v1/core/sched/scheduler.py | 4 ++-- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py index 870edad3632..7350f2990c1 100644 --- a/tests/models/multimodal/pooling/test_prithvi_mae.py +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -6,17 +6,19 @@ from ....conftest import VllmRunner + def generate_test_mm_data(): mm_data = { 
"pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16), "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), } return mm_data - + + def _run_test( vllm_runner: type[VllmRunner], model: str, -) -> None: +) -> None: mm_data = generate_test_mm_data() prompt = { @@ -24,13 +26,15 @@ def _run_test( "prompt_token_ids": [1], "multi_modal_data": mm_data } - with vllm_runner(model, task="embed", - dtype=torch.float16, - enforce_eager=True, - skip_tokenizer_init=True) as vllm_model: - output = vllm_model.encode(prompt) - -MODELS=["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] + with vllm_runner(model, + task="embed", + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True) as vllm_model: + vllm_model.encode(prompt) + +MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] + @pytest.mark.parametrize("model", MODELS) def test_models_image( hf_runner, @@ -41,4 +45,4 @@ def test_models_image( _run_test( vllm_runner, model, - ) \ No newline at end of file + ) diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index e820a0ad6d5..8195020e97d 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass from typing import Optional @@ -65,7 +66,6 @@ def new_empty(self) -> "KVCacheBlocks": class KVCacheManager: - def __init__( self, kv_cache_config: KVCacheConfig, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 446f98034cb..8b85b37b1a9 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -488,8 +488,8 @@ def schedule(self) -> SchedulerOutput: if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_block_ids[request.request_id] = ( - self.kv_cache_manager.get_block_ids(request.request_id)) + req_to_new_block_ids[request.request_id] = \ + self.kv_cache_manager.get_block_ids(request.request_id) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING From c992ea311e908075f39af883d1fa0a6855ca3db8 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Thu, 26 Jun 2025 08:31:59 +0000 Subject: [PATCH 09/14] Some style changes - Improved formatting around - made is_pooling_model a @property in ModelConfig Signed-off-by: Christian Pinto --- tests/models/multimodal/pooling/test_prithvi_mae.py | 2 ++ vllm/config.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py index 7350f2990c1..8179aa064ef 100644 --- a/tests/models/multimodal/pooling/test_prithvi_mae.py +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -33,8 +33,10 @@ def _run_test( skip_tokenizer_init=True) as vllm_model: vllm_model.encode(prompt) + MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] + @pytest.mark.parametrize("model", MODELS) def test_models_image( hf_runner, diff --git a/vllm/config.py b/vllm/config.py index 047d1984b85..de4a2a3f218 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -642,8 +642,6 @@ def __post_init__(self) -> None: self.original_max_model_len = self.max_model_len self.max_model_len = self.get_and_verify_max_len(self.max_model_len) self.multimodal_config = 
self._init_multimodal_config() - self.is_pooling_model = self.registry.is_pooling_model( - self.architectures) self.model_supports_multimodal_raw_input = ( self._init_model_supports_multimodal_raw_input()) if not self.skip_tokenizer_init: @@ -1516,6 +1514,10 @@ def uses_mrope(self) -> bool: @property def is_multimodal_model(self) -> bool: return self.multimodal_config is not None + + @property + def is_pooling_model(self) -> bool: + return self.registry.is_pooling_model(self.architectures) @property def is_cross_encoder(self) -> bool: From 137ec297399a1a93571130cf9c699d441ab5d0fc Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Fri, 27 Jun 2025 08:48:10 +0000 Subject: [PATCH 10/14] Simple code refactoring - Remove unused functions - merged functions not called anywhere else Signed-off-by: Christian Pinto --- vllm/v1/worker/gpu_model_runner.py | 34 ++++++++---------------------- 1 file changed, 9 insertions(+), 25 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 56b2d8e4d70..384de89bad0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -561,11 +561,16 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() - def _add_multimodal_inputs_to_model_args( + def _maybe_add_multimodal_kwargs( self, model_kwargs: dict[str, Any], - scheduler_output: "SchedulerOutput", - num_reqs: int = -1): + scheduler_output: "SchedulerOutput" = None, + num_reqs: int = -1, + ): + + if not self.model_supports_multimodal_raw_input: + return + # Multi-modal data. if scheduler_output: multi_modal_kwargs_list = [] @@ -588,28 +593,7 @@ def _add_multimodal_inputs_to_model_args( model_kwargs.update(multi_modal_kwargs) - def _maybe_add_multimodal_kwargs( - self, - model_kwargs: dict[str, Any], - scheduler_output: "SchedulerOutput" = None, - num_reqs: int = -1): - - if self.model_supports_multimodal_raw_input: - self._add_multimodal_inputs_to_model_args(model_kwargs, - scheduler_output, - num_reqs) - - def _maybe_compute_attn_prefix( - self, - scheduler_output: "SchedulerOutput", - ) -> list[int]: - return [0] * len(self.kv_cache_config.kv_cache_groups) - - def _maybe_prepare_additional_inputs(self, - scheduler_output: "SchedulerOutput", - token_indices: torch.Tensor): - pass - + def _get_cumsum_and_arange( self, num_tokens: np.ndarray, From e59d7dc49046285e3e3031f383bc0dafd86901f0 Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Wed, 16 Jul 2025 10:31:54 +0000 Subject: [PATCH 11/14] Rebased to master Signed-off-by: Christian Pinto --- .../prithvi_geospatial_mae.py | 437 ++++++++---------- .../multimodal/pooling/test_prithvi_mae.py | 2 +- vllm/config.py | 4 - .../models/prithvi_geospatial_mae.py | 3 +- vllm/v1/core/kv_cache_manager.py | 2 +- vllm/v1/core/sched/scheduler.py | 4 +- vllm/v1/engine/core.py | 4 +- vllm/v1/engine/output_processor.py | 15 +- vllm/v1/engine/processor.py | 10 +- vllm/v1/worker/gpu_model_runner.py | 17 +- 10 files changed, 215 insertions(+), 283 deletions(-) diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index e36d01c249f..314b6739b51 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -1,188 +1,144 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This is a demo script showing how 
to use the -PrithviGeospatialMAE model with vLLM -This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa - -Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa - -The requirements for running this script are: -- Installing [terratorch, albumentations, rasterio] in your python environment -- downloading the model weights in a 'model' folder local to the script - (temporary measure until the proper config.json file is uploaded to HF) -- download an input example image (India_900498_S2Hand.tif) and place it in - the same folder with the script (or specify with the --data_file argument) - -Run the example: -python prithvi_geospatial_mae.py - -""" # noqa: E501 - -import argparse -import datetime import os -from typing import Union - import albumentations +from terratorch.datamodules import Sen1Floods11NonGeoDataModule +import argparse +from typing import List, Union +import re +import datetime import numpy as np import rasterio -import regex as re import torch +import time from einops import rearrange -from terratorch.datamodules import Sen1Floods11NonGeoDataModule - +from typing import Tuple from vllm import LLM +torch.set_default_dtype(torch.float16) + NO_DATA = -9999 NO_DATA_FLOAT = 0.0001 OFFSET = 0 PERCENTILE = 99 -model_config = """{ - "architectures": ["PrithviGeoSpatialMAE"], - "num_classes": 0, - "pretrained_cfg": { - "task_args": { - "task": "SemanticSegmentationTask", - "model_factory": "EncoderDecoderFactory", - "loss": "ce", - "ignore_index": -1, - "lr": 0.001, - "freeze_backbone": false, - "freeze_decoder": false, - "plot_on_val": 10, - "optimizer": "AdamW", - "scheduler": "CosineAnnealingLR" - }, - "model_args": { - "backbone_pretrained": false, - "backbone": "prithvi_eo_v2_300_tl", - "decoder": "UperNetDecoder", - "decoder_channels": 256, - "decoder_scale_modules": true, - "num_classes": 2, - "rescale": true, - "backbone_bands": [ - "BLUE", - "GREEN", - "RED", - "NIR_NARROW", - "SWIR_1", - "SWIR_2" - ], - "head_dropout": 0.1, - "necks": [ - { - "name": "SelectIndices", - "indices": [ - 5, - 11, - 17, - 23 - ] - }, - { - "name": "ReshapeTokensToImage" - } - ] - }, - "optimizer_params" : { - "lr": 5.0e-05, - "betas": [0.9, 0.999], - "eps": [1.0e-08], - "weight_decay": 0.05, - "amsgrad": false, - "maximize": false, - "capturable": false, - "differentiable": false - }, - "scheduler_params" : { - "T_max": 50, - "eta_min": 0, - "last_epoch": -1, - "verbose": "deprecated" - } - }, - - - "torch_dtype": "float32" -} -""" - -# Temporarily creating the "config.json" for the model. 
-# This is going to disappear once the correct config.json is available on HF -with open( - os.path.join(os.path.dirname(__file__), "./model/config.json"), "w" -) as config_file: - config_file.write(model_config) - datamodule_config = { - "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"], - "batch_size": 16, - "constant_scale": 0.0001, - "data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11", - "drop_last": True, - "no_data_replace": 0.0, - "no_label_replace": -1, - "num_workers": 8, - "test_transform": [ - albumentations.Resize( - always_apply=False, height=448, interpolation=1, p=1, width=448 - ), - albumentations.pytorch.ToTensorV2( - transpose_mask=False, always_apply=True, p=1.0 - ), - ], + 'bands': ['BLUE', + 'GREEN', + 'RED', + 'NIR_NARROW', + 'SWIR_1', + 'SWIR_2'], + 'batch_size': 16, + 'constant_scale': 0.0001, + 'data_root': '/dccstor/geofm-finetuning/datasets/sen1floods11', + 'drop_last': True, + 'no_data_replace': 0.0, + 'no_label_replace': -1, + 'num_workers': 8, + 'test_transform': [albumentations.Resize(always_apply=False, + height=448, + interpolation=1, + p=1, + width=448), + albumentations.pytorch.ToTensorV2( + transpose_mask=False, + always_apply=True, + p=1.0 + )], } - class PrithviMAE: - def __init__(self): - print("Initializing PrithviMAE model") - self.model = LLM( - model=os.path.join(os.path.dirname(__file__), "./model"), - skip_tokenizer_init=True, - dtype="float16", - enforce_eager=True, - ) - - def run(self, input_data, location_coords): + def __init__(self,model): + print("Initializing Terratorch model") + #self.model = LLM(model=os.path.join(os.path.dirname(__file__), "./model"), skip_tokenizer_init=True, dtype="float32") + self.model = LLM(model=model,skip_tokenizer_init=True, dtype="float16",enforce_eager=True) + + def patchify(self, pixel_values): + """ + Args: + pixel_values (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`): + Pixel values. + + Returns: + torch.FloatTensor of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: + Patchified pixel values. + """ + patch_size_t, patch_size_h, patch_size_w = self.get_patch_size() + num_channels = self.get_num_channels() + + # patchify + patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)', + c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w) + + + return patchified_pixel_values + + def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None): + """ + Args: + patchified_pixel_values (`torch.FloatTensor` of shape + `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: + Patchified pixel values. + image_size (`Tuple[int, int]`, *optional*): + Original image size. + + Returns: + `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`: + Pixel values. 
+ """ + patch_size_t, patch_size_h, patch_size_w = self.get_patch_size() + image_size = to_2tuple(image_size) if image_size is not None else self.get_img_size() + original_height, original_width = image_size + num_patches_h = original_height // patch_size_h + num_patches_w = original_width // patch_size_w + num_channels = self.get_num_channels() + + pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)', + c=num_channels, h=num_patches_h, w=num_patches_w, + s=patch_size_t, p=patch_size_h, q=patch_size_w) + return pixel_values + + def run(self, input_data, temporal_coords, location_coords): print("################ Running inference on vLLM ##############") # merge the inputs into one data structure + if input_data is not None and input_data.dtype == torch.float32 : + input_data= input_data.to(torch.float16) + input_data = input_data[0] + mm_data = { "pixel_values": torch.empty(0) if input_data is None else input_data, - "location_coords": torch.empty(0) - if location_coords is None - else location_coords, + "location_coords": torch.empty(0) if location_coords is None else location_coords } - prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} + prompt = { + "prompt_token_ids": [1], + "multi_modal_data": mm_data + } + start = time.time() outputs = self.model.encode(prompt, use_tqdm=False) - print("################ Inference done (it took seconds) ##############") + end = time.time() + elapsed = end - start + print(f"################ Inference done (it took {round(elapsed,2)} seconds) ##############") return outputs[0].outputs.data - def generate_datamodule(): - datamodule = Sen1Floods11NonGeoDataModule( - data_root=datamodule_config["data_root"], - batch_size=datamodule_config["batch_size"], - num_workers=datamodule_config["num_workers"], - bands=datamodule_config["bands"], - drop_last=datamodule_config["drop_last"], - test_transform=datamodule_config["test_transform"], - ) + - return datamodule + datamodule = Sen1Floods11NonGeoDataModule(data_root=datamodule_config['data_root'], + batch_size=datamodule_config["batch_size"], + num_workers=datamodule_config["num_workers"], + bands=datamodule_config["bands"], + drop_last=datamodule_config["drop_last"], + test_transform=datamodule_config["test_transform" + ""]) + return datamodule def process_channel_group(orig_img, channels): """ Args: - orig_img: torch.Tensor representing original image (reference) - with shape = (bands, H, W). + orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W). channels: list of indices representing RGB channels. Returns: @@ -193,6 +149,7 @@ def process_channel_group(orig_img, channels): valid_mask = torch.ones_like(orig_img, dtype=torch.bool) valid_mask[orig_img == NO_DATA_FLOAT] = False + # Rescale (enhancing contrast) max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE)) min_value = OFFSET @@ -221,7 +178,7 @@ def read_geotiff(file_path: str): meta = src.meta try: coords = src.lnglat() - except Exception: + except: # Cannot read coords coords = None @@ -252,19 +209,17 @@ def _convert_np_uint8(float_image: torch.Tensor): def load_example( - file_paths: list[str], - mean: list[float] = None, - std: list[float] = None, + file_paths: List[str], + mean: List[float] = None, + std: List[float] = None, indices: Union[list[int], None] = None, ): """Build an input example by loading images in *file_paths*. Args: file_paths: list of file paths . 
- mean: list containing mean values for each band in the images - in *file_paths*. - std: list containing std values for each band in the images - in *file_paths*. + mean: list containing mean values for each band in the images in *file_paths*. + std: list containing std values for each band in the images in *file_paths*. Returns: np.array containing created example @@ -292,38 +247,26 @@ def load_example( location_coords.append(coords) try: - match = re.search(r"(\d{7,8}T\d{6})", file) + match = re.search(r'(\d{7,8}T\d{6})', file) if match: year = int(match.group(1)[:4]) - julian_day = match.group(1).split("T")[0][4:] + julian_day = match.group(1).split('T')[0][4:] if len(julian_day) == 3: julian_day = int(julian_day) else: - julian_day = ( - datetime.datetime.strptime(julian_day, "%m%d") - .timetuple() - .tm_yday - ) + julian_day = datetime.datetime.strptime(julian_day, '%m%d').timetuple().tm_yday temporal_coords.append([year, julian_day]) except Exception as e: - print(f"Could not extract timestamp for {file} ({e})") + print(f'Could not extract timestamp for {file} ({e})') imgs = np.stack(imgs, axis=0) # num_frames, H, W, C - imgs = np.moveaxis(imgs, -1, 0).astype("float32") + imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W imgs = np.expand_dims(imgs, axis=0) # add batch di return imgs, temporal_coords, location_coords, metas -def run_model( - input_data, - temporal_coords, - location_coords, - model, - datamodule, - img_size, - lightning_model=None, -): +def run_model(input_data, temporal_coords, location_coords, model, datamodule, img_size, lightning_model=None): # Reflect pad if not divisible by img_size original_h, original_w = input_data.shape[-2:] pad_h = (img_size - (original_h % img_size)) % img_size @@ -333,8 +276,10 @@ def run_model( ) # Build sliding window + batch_size = 1 - batch = torch.tensor(input_data, device="cpu") + #batch = torch.tensor(input_data, device="cpu") + batch = torch.tensor(input_data) windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size) h1, w1 = windows.shape[3:5] windows = rearrange( @@ -345,39 +290,39 @@ def run_model( num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1 windows = torch.tensor_split(windows, num_batches, dim=0) - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device('cuda') + else: + device = torch.device('cpu') if temporal_coords: - temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0) + #temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0) + temporal_coords = torch.tensor(temporal_coords).unsqueeze(0) else: temporal_coords = None if location_coords: - location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0) + #location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0) + location_coords = torch.tensor(location_coords[0]).unsqueeze(0) else: location_coords = None - # Run model + # Run Prithvi-EO-V2-300M-TL-Sen1Floods11 pred_imgs = [] for x in windows: # Apply standardization - x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0)) - x = datamodule.aug(x)["image"] + x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1,2,0)) + x = datamodule.aug(x)['image'] with torch.no_grad(): - x = x.to(device) - pred = model.run(x, location_coords=location_coords) + pred = model.run(x, temporal_coords=temporal_coords, location_coords=location_coords) if lightning_model: 
- pred_lightning = lightning_model( - x, temporal_coords=temporal_coords, location_coords=location_coords - ) + pred_lightning = lightning_model(x, temporal_coords=temporal_coords, location_coords=location_coords) pred_lightning = pred_lightning.output.detach().cpu() if not torch.equal(pred, pred_lightning): print("Inference output is not equal") y_hat = pred.argmax(dim=1) - y_hat = torch.nn.functional.interpolate( - y_hat.unsqueeze(1).float(), size=img_size, mode="nearest" - ) + y_hat = torch.nn.functional.interpolate(y_hat.unsqueeze(1).float(), size=img_size, mode="nearest") pred_imgs.append(y_hat) @@ -403,57 +348,25 @@ def run_model( return pred_imgs - -def parse_args(): - parser = argparse.ArgumentParser("MAE run inference", add_help=False) - - parser.add_argument( - "--data_file", - type=str, - default="./India_900498_S2Hand.tif", - help="Path to the file.", - ) - parser.add_argument( - "--output_dir", - type=str, - default="output", - help="Path to the directory where to save outputs.", - ) - parser.add_argument( - "--input_indices", - default=[1, 2, 3, 8, 11, 12], - type=int, - nargs="+", - help="0-based indices of the six Prithvi channels to be selected from the " - "input. By default selects [1,2,3,8,11,12] for S2L1C data.", - ) - parser.add_argument( - "--rgb_outputs", - action="store_true", - help="If present, output files will only contain RGB channels. " - "Otherwise, all bands will be saved.", - ) - - def main( data_file: str, + model: str, output_dir: str, rgb_outputs: bool, input_indices: list[int] = None, ): os.makedirs(output_dir, exist_ok=True) - # Load model --------------------------------------------------------------- + # Load Prithvi-EO-V2-300M-TL-Sen1Floods11 --------------------------------------------------------------------------------- - model_obj = PrithviMAE() + model_obj = PrithviMAE(model=model) datamodule = generate_datamodule() - img_size = 256 # Size of Sen1Floods11 + img_size = 512 # Size of Sen1Floods11 - # Loading data ------------------------------------------------------------- + # Loading data --------------------------------------------------------------------------------- input_data, temporal_coords, location_coords, meta_data = load_example( - file_paths=[data_file], - indices=input_indices, + file_paths=[data_file], indices=input_indices, ) meta_data = meta_data[0] # only one image @@ -461,21 +374,17 @@ def main( if input_data.mean() > 1: input_data = input_data / 10000 # Convert to range 0-1 - # Running model ------------------------------------------------------------ + # Running Prithvi-EO-V2-300M-TL-Sen1Floods11 -------------------------------------------------------------------------------- - channels = [ - datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"] - ] # BGR -> RGB + channels = [datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]] # BGR -> RGB - pred = run_model( - input_data, temporal_coords, location_coords, model_obj, datamodule, img_size - ) + # lightning_model = LightningInferenceModel.from_config(config, checkpoint) + pred = run_model(input_data, temporal_coords, location_coords, + model_obj, datamodule, img_size) # Save pred meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0) - pred_file = os.path.join( - output_dir, f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff" - ) + pred_file = os.path.join(output_dir, f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") save_geotiff(_convert_np_uint8(pred), pred_file, meta_data) # Save image + pred @@ 
-488,14 +397,13 @@ def main( orig_img=torch.Tensor(input_data[0, :, 0, ...]), channels=channels, ) + rgb_orig= rgb_orig.to(torch.float32) - pred[pred == 0.0] = np.nan + pred[pred == 0.] = np.nan img_pred = rgb_orig * 0.7 + pred * 0.3 img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()] - img_pred_file = os.path.join( - output_dir, f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff" - ) + img_pred_file = os.path.join(output_dir, f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") save_geotiff( image=_convert_np_uint8(img_pred), output_path=img_pred_file, @@ -504,18 +412,51 @@ def main( # Save image rgb if rgb_outputs: - rgb_file = os.path.join( - output_dir, - f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff", - ) + rgb_file = os.path.join(output_dir, f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") save_geotiff( image=_convert_np_uint8(rgb_orig), output_path=rgb_file, meta=meta_data, ) + print("Done!") + if __name__ == "__main__": - args = parse_args() + parser = argparse.ArgumentParser("MAE run inference", add_help=False) + + parser.add_argument( + "--data_file", + type=str, + default="/workspace/scripts/demo_flooding/examples/India_900498_S2Hand.tif", + help="Path to the file.", + ) + + parser.add_argument( + "--model", + type=str, + default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", + help="Path to a checkpoint file to load from.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Path to the directory where to save outputs.", + ) + parser.add_argument( + "--input_indices", + default=[1,2,3,8,11,12], + type=int, + nargs="+", + help="0-based indices of the six Prithvi channels to be selected from the input. By default selects [1,2,3,8,11,12] for S2L1C data.", + ) + parser.add_argument( + "--rgb_outputs", + action="store_true", + help="If present, output files will only contain RGB channels. " + "Otherwise, all bands will be saved.", + ) + args = parser.parse_args() main(**vars(args)) diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py index 8179aa064ef..3ff419f3d7b 100644 --- a/tests/models/multimodal/pooling/test_prithvi_mae.py +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -36,7 +36,7 @@ def _run_test( MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] - +@pytest.mark.core_model @pytest.mark.parametrize("model", MODELS) def test_models_image( hf_runner, diff --git a/vllm/config.py b/vllm/config.py index de4a2a3f218..c5f91e4c217 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1514,10 +1514,6 @@ def uses_mrope(self) -> bool: @property def is_multimodal_model(self) -> bool: return self.multimodal_config is not None - - @property - def is_pooling_model(self) -> bool: - return self.registry.is_pooling_model(self.architectures) @property def is_cross_encoder(self) -> bool: diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index cc34721cf41..26f5e594f30 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -184,7 +184,6 @@ def _parse_and_validate_multimodal_data( if not isinstance(pixel_values, torch.Tensor): raise ValueError(f"Incorrect type of pixel_values. 
" f"Got type: {type(pixel_values)}") - # pixel_values = torch.unbind(pixel_values, dim=0)[0] location_coords = kwargs.pop("location_coords", None) if not isinstance(location_coords, torch.Tensor): @@ -201,7 +200,7 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: # to be calculated. However, due to the mandatory token ids in # the input prompt we pass one token and the size of the dummy # embedding tensors must reflect that. - return torch.empty(input_ids.shape) + return torch.empty((input_ids.shape[0], 0)) def forward( self, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 8195020e97d..e820a0ad6d5 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass from typing import Optional @@ -66,6 +65,7 @@ def new_empty(self) -> "KVCacheBlocks": class KVCacheManager: + def __init__( self, kv_cache_config: KVCacheConfig, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 8b85b37b1a9..446f98034cb 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -488,8 +488,8 @@ def schedule(self) -> SchedulerOutput: if self.lora_config and request.lora_request: scheduled_loras.add(request.lora_request.lora_int_id) - req_to_new_block_ids[request.request_id] = \ - self.kv_cache_manager.get_block_ids(request.request_id) + req_to_new_block_ids[request.request_id] = ( + self.kv_cache_manager.get_block_ids(request.request_id)) num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index eb9daa37603..f5c59bef478 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -152,8 +152,8 @@ def _initialize_kv_caches( kv_cache_configs = [ get_kv_cache_config(vllm_config, kv_cache_spec_one_worker, available_gpu_memory_one_worker) - for kv_cache_spec_one_worker, available_gpu_memory_one_worker - in zip(kv_cache_specs, available_gpu_memory) + for kv_cache_spec_one_worker, available_gpu_memory_one_worker in + zip(kv_cache_specs, available_gpu_memory) ] # Since we use a shared centralized controller, we need the diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 3be6c482121..05ea6cc492e 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -330,13 +330,14 @@ def add_request( tokenizer = None if not self.tokenizer else \ self.tokenizer.get_lora_tokenizer(request.lora_request) - req_state = RequestState.from_new_request(tokenizer=tokenizer, - request=request, - prompt=prompt, - parent_req=parent_req, - request_index=request_index, - queue=queue, - log_stats=self.log_stats) + req_state = RequestState.from_new_request( + tokenizer=tokenizer, + request=request, + prompt=prompt, + parent_req=parent_req, + request_index=request_index, + queue=queue, + log_stats=self.log_stats) self.request_states[request_id] = req_state self.lora_states.add_request(req_state) if parent_req: diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 60e97c78a5f..45ac61fb30b 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -384,6 +384,10 @@ def _validate_model_input( tokenizer = None else: tokenizer = 
self.tokenizer.get_lora_tokenizer(lora_request) + max_input_id = max(prompt_ids, default=0) + if max_input_id > tokenizer.max_token_id: + raise ValueError( + f"Token id {max_input_id} is out of vocabulary") prompt_ids = prompt_inputs["prompt_token_ids"] if not prompt_ids: @@ -392,12 +396,6 @@ def _validate_model_input( else: raise ValueError(f"The {prompt_type} prompt cannot be empty") - if tokenizer: - max_input_id = max(prompt_ids, default=0) - if max_input_id > tokenizer.max_token_id: - raise ValueError( - f"Token id {max_input_id} is out of vocabulary") - max_prompt_len = self.model_config.max_model_len if len(prompt_ids) > max_prompt_len: if prompt_type == "encoder" and model_config.is_multimodal_model: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 384de89bad0..87bcabf99a0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -123,7 +123,7 @@ def __init__( cache_config.cache_dtype] self.is_multimodal_model = model_config.is_multimodal_model - self.is_pooling_model = model_config.is_pooling_model + self.is_pooling_model = model_config.pooler_config is not None self.model_supports_multimodal_raw_input = ( model_config.model_supports_multimodal_raw_input) self.max_model_len = model_config.max_model_len @@ -328,8 +328,6 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: Args: scheduler_output: The scheduler output. """ - - # nothing to be reordered when the mdoel is attention free if self.model_config.is_attention_free: return False @@ -1059,14 +1057,13 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): curr_group_outputs = self.model.get_multimodal_embeddings( **batched_mm_inputs) - if curr_group_outputs: - sanity_check_mm_encoder_outputs( - curr_group_outputs, - expected_num_items=len(grouped_mm_inputs), - ) + sanity_check_mm_encoder_outputs( + curr_group_outputs, + expected_num_items=len(grouped_mm_inputs), + ) - for output in curr_group_outputs: - encoder_outputs.append(output) + for output in curr_group_outputs: + encoder_outputs.append(output) # Cache the encoder outputs. 
for (req_id, input_id, pos_info), output in zip( From 9afe699afbe927042d28d69a0a7759ec3d0ae7ed Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Wed, 16 Jul 2025 12:05:57 +0000 Subject: [PATCH 12/14] Ensure pre-commit checks succeed Signed-off-by: Christian Pinto --- .../prithvi_geospatial_mae.py | 269 ++++++++---------- .../multimodal/pooling/test_prithvi_mae.py | 1 + vllm/model_executor/models/interfaces.py | 3 + vllm/v1/engine/output_processor.py | 15 +- vllm/v1/engine/processor.py | 10 +- vllm/v1/worker/gpu_model_runner.py | 11 +- 6 files changed, 133 insertions(+), 176 deletions(-) diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index 314b6739b51..4fdc7a3cf70 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -1,16 +1,18 @@ -import os -import albumentations -from terratorch.datamodules import Sen1Floods11NonGeoDataModule +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse -from typing import List, Union -import re import datetime +import os +import re +from typing import Union + +import albumentations import numpy as np import rasterio import torch -import time from einops import rearrange -from typing import Tuple +from terratorch.datamodules import Sen1Floods11NonGeoDataModule + from vllm import LLM torch.set_default_dtype(torch.float16) @@ -21,135 +23,77 @@ PERCENTILE = 99 datamodule_config = { - 'bands': ['BLUE', - 'GREEN', - 'RED', - 'NIR_NARROW', - 'SWIR_1', - 'SWIR_2'], - 'batch_size': 16, - 'constant_scale': 0.0001, - 'data_root': '/dccstor/geofm-finetuning/datasets/sen1floods11', - 'drop_last': True, - 'no_data_replace': 0.0, - 'no_label_replace': -1, - 'num_workers': 8, - 'test_transform': [albumentations.Resize(always_apply=False, - height=448, - interpolation=1, - p=1, - width=448), - albumentations.pytorch.ToTensorV2( - transpose_mask=False, - always_apply=True, - p=1.0 - )], + "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"], + "batch_size": 16, + "constant_scale": 0.0001, + "data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11", + "drop_last": True, + "no_data_replace": 0.0, + "no_label_replace": -1, + "num_workers": 8, + "test_transform": [ + albumentations.Resize( + always_apply=False, height=448, interpolation=1, p=1, width=448 + ), + albumentations.pytorch.ToTensorV2( + transpose_mask=False, always_apply=True, p=1.0 + ), + ], } + class PrithviMAE: - def __init__(self,model): - print("Initializing Terratorch model") - #self.model = LLM(model=os.path.join(os.path.dirname(__file__), "./model"), skip_tokenizer_init=True, dtype="float32") - self.model = LLM(model=model,skip_tokenizer_init=True, dtype="float16",enforce_eager=True) - - def patchify(self, pixel_values): - """ - Args: - pixel_values (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`): - Pixel values. - - Returns: - torch.FloatTensor of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: - Patchified pixel values. 
- """ - patch_size_t, patch_size_h, patch_size_w = self.get_patch_size() - num_channels = self.get_num_channels() - - # patchify - patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)', - c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w) - - - return patchified_pixel_values - - def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None): - """ - Args: - patchified_pixel_values (`torch.FloatTensor` of shape - `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`: - Patchified pixel values. - image_size (`Tuple[int, int]`, *optional*): - Original image size. - - Returns: - `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`: - Pixel values. - """ - patch_size_t, patch_size_h, patch_size_w = self.get_patch_size() - image_size = to_2tuple(image_size) if image_size is not None else self.get_img_size() - original_height, original_width = image_size - num_patches_h = original_height // patch_size_h - num_patches_w = original_width // patch_size_w - num_channels = self.get_num_channels() - - pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)', - c=num_channels, h=num_patches_h, w=num_patches_w, - s=patch_size_t, p=patch_size_h, q=patch_size_w) - return pixel_values - - def run(self, input_data, temporal_coords, location_coords): - print("################ Running inference on vLLM ##############") + def __init__(self, model): + self.model = LLM( + model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True + ) + + def run(self, input_data, location_coords): # merge the inputs into one data structure - if input_data is not None and input_data.dtype == torch.float32 : - input_data= input_data.to(torch.float16) + if input_data is not None and input_data.dtype == torch.float32: + input_data = input_data.to(torch.float16) input_data = input_data[0] mm_data = { - "pixel_values": torch.empty(0) if input_data is None else input_data, - "location_coords": torch.empty(0) if location_coords is None else location_coords + "pixel_values": input_data, + "location_coords": location_coords, } - prompt = { - "prompt_token_ids": [1], - "multi_modal_data": mm_data - } - - start = time.time() + prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} outputs = self.model.encode(prompt, use_tqdm=False) - end = time.time() - elapsed = end - start - print(f"################ Inference done (it took {round(elapsed,2)} seconds) ##############") return outputs[0].outputs.data -def generate_datamodule(): - - datamodule = Sen1Floods11NonGeoDataModule(data_root=datamodule_config['data_root'], - batch_size=datamodule_config["batch_size"], - num_workers=datamodule_config["num_workers"], - bands=datamodule_config["bands"], - drop_last=datamodule_config["drop_last"], - test_transform=datamodule_config["test_transform" - ""]) +def generate_datamodule(): + datamodule = Sen1Floods11NonGeoDataModule( + data_root=datamodule_config["data_root"], + batch_size=datamodule_config["batch_size"], + num_workers=datamodule_config["num_workers"], + bands=datamodule_config["bands"], + drop_last=datamodule_config["drop_last"], + test_transform=datamodule_config["test_transform"], + ) return datamodule + def process_channel_group(orig_img, channels): """ Args: - orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W). + orig_img: torch.Tensor representing original image (reference) + with shape = (bands, H, W). 
channels: list of indices representing RGB channels. Returns: - torch.Tensor with shape (num_channels, height, width) for original image + torch.Tensor with shape (num_channels, height, width) + for original image """ orig_img = orig_img[channels, ...] valid_mask = torch.ones_like(orig_img, dtype=torch.bool) valid_mask[orig_img == NO_DATA_FLOAT] = False - # Rescale (enhancing contrast) max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE)) min_value = OFFSET @@ -178,7 +122,7 @@ def read_geotiff(file_path: str): meta = src.meta try: coords = src.lnglat() - except: + except Exception: # Cannot read coords coords = None @@ -209,17 +153,19 @@ def _convert_np_uint8(float_image: torch.Tensor): def load_example( - file_paths: List[str], - mean: List[float] = None, - std: List[float] = None, + file_paths: list[str], + mean: list[float] = None, + std: list[float] = None, indices: Union[list[int], None] = None, ): """Build an input example by loading images in *file_paths*. Args: file_paths: list of file paths . - mean: list containing mean values for each band in the images in *file_paths*. - std: list containing std values for each band in the images in *file_paths*. + mean: list containing mean values for each band in the + images in *file_paths*. + std: list containing std values for each band in the + images in *file_paths*. Returns: np.array containing created example @@ -247,17 +193,21 @@ def load_example( location_coords.append(coords) try: - match = re.search(r'(\d{7,8}T\d{6})', file) + match = re.search(r"(\d{7,8}T\d{6})", file) if match: year = int(match.group(1)[:4]) - julian_day = match.group(1).split('T')[0][4:] + julian_day = match.group(1).split("T")[0][4:] if len(julian_day) == 3: julian_day = int(julian_day) else: - julian_day = datetime.datetime.strptime(julian_day, '%m%d').timetuple().tm_yday + julian_day = ( + datetime.datetime.strptime(julian_day, "%m%d") + .timetuple() + .tm_yday + ) temporal_coords.append([year, julian_day]) except Exception as e: - print(f'Could not extract timestamp for {file} ({e})') + print(f"Could not extract timestamp for {file} ({e})") imgs = np.stack(imgs, axis=0) # num_frames, H, W, C imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W @@ -266,7 +216,15 @@ def load_example( return imgs, temporal_coords, location_coords, metas -def run_model(input_data, temporal_coords, location_coords, model, datamodule, img_size, lightning_model=None): +def run_model( + input_data, + temporal_coords, + location_coords, + model, + datamodule, + img_size, + lightning_model=None, +): # Reflect pad if not divisible by img_size original_h, original_w = input_data.shape[-2:] pad_h = (img_size - (original_h % img_size)) % img_size @@ -278,7 +236,7 @@ def run_model(input_data, temporal_coords, location_coords, model, datamodule, i # Build sliding window batch_size = 1 - #batch = torch.tensor(input_data, device="cpu") + # batch = torch.tensor(input_data, device="cpu") batch = torch.tensor(input_data) windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size) h1, w1 = windows.shape[3:5] @@ -290,18 +248,11 @@ def run_model(input_data, temporal_coords, location_coords, model, datamodule, i num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1 windows = torch.tensor_split(windows, num_batches, dim=0) - if torch.cuda.is_available(): - device = torch.device('cuda') - else: - device = torch.device('cpu') - if temporal_coords: - #temporal_coords = torch.tensor(temporal_coords, 
device=device).unsqueeze(0) temporal_coords = torch.tensor(temporal_coords).unsqueeze(0) else: temporal_coords = None if location_coords: - #location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0) location_coords = torch.tensor(location_coords[0]).unsqueeze(0) else: location_coords = None @@ -310,19 +261,16 @@ def run_model(input_data, temporal_coords, location_coords, model, datamodule, i pred_imgs = [] for x in windows: # Apply standardization - x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1,2,0)) - x = datamodule.aug(x)['image'] + x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0)) + x = datamodule.aug(x)["image"] with torch.no_grad(): - pred = model.run(x, temporal_coords=temporal_coords, location_coords=location_coords) - if lightning_model: - pred_lightning = lightning_model(x, temporal_coords=temporal_coords, location_coords=location_coords) - pred_lightning = pred_lightning.output.detach().cpu() - if not torch.equal(pred, pred_lightning): - print("Inference output is not equal") + pred = model.run(x, location_coords=location_coords) y_hat = pred.argmax(dim=1) - y_hat = torch.nn.functional.interpolate(y_hat.unsqueeze(1).float(), size=img_size, mode="nearest") + y_hat = torch.nn.functional.interpolate( + y_hat.unsqueeze(1).float(), size=img_size, mode="nearest" + ) pred_imgs.append(y_hat) @@ -348,6 +296,7 @@ def run_model(input_data, temporal_coords, location_coords, model, datamodule, i return pred_imgs + def main( data_file: str, model: str, @@ -357,16 +306,13 @@ def main( ): os.makedirs(output_dir, exist_ok=True) - # Load Prithvi-EO-V2-300M-TL-Sen1Floods11 --------------------------------------------------------------------------------- - model_obj = PrithviMAE(model=model) datamodule = generate_datamodule() img_size = 512 # Size of Sen1Floods11 - # Loading data --------------------------------------------------------------------------------- - input_data, temporal_coords, location_coords, meta_data = load_example( - file_paths=[data_file], indices=input_indices, + file_paths=[data_file], + indices=input_indices, ) meta_data = meta_data[0] # only one image @@ -374,17 +320,18 @@ def main( if input_data.mean() > 1: input_data = input_data / 10000 # Convert to range 0-1 - # Running Prithvi-EO-V2-300M-TL-Sen1Floods11 -------------------------------------------------------------------------------- - - channels = [datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]] # BGR -> RGB - - # lightning_model = LightningInferenceModel.from_config(config, checkpoint) - pred = run_model(input_data, temporal_coords, location_coords, - model_obj, datamodule, img_size) + channels = [ + datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"] + ] # BGR -> RGB + pred = run_model( + input_data, temporal_coords, location_coords, model_obj, datamodule, img_size + ) # Save pred meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0) - pred_file = os.path.join(output_dir, f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") + pred_file = os.path.join( + output_dir, f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff" + ) save_geotiff(_convert_np_uint8(pred), pred_file, meta_data) # Save image + pred @@ -397,13 +344,15 @@ def main( orig_img=torch.Tensor(input_data[0, :, 0, ...]), channels=channels, ) - rgb_orig= rgb_orig.to(torch.float32) + rgb_orig = rgb_orig.to(torch.float32) - pred[pred == 0.] 
= np.nan + pred[pred == 0.0] = np.nan img_pred = rgb_orig * 0.7 + pred * 0.3 img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()] - img_pred_file = os.path.join(output_dir, f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") + img_pred_file = os.path.join( + output_dir, f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff" + ) save_geotiff( image=_convert_np_uint8(img_pred), output_path=img_pred_file, @@ -412,15 +361,17 @@ def main( # Save image rgb if rgb_outputs: - rgb_file = os.path.join(output_dir, f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") + name_suffix = os.path.splitext(os.path.basename(data_file))[0] + rgb_file = os.path.join( + output_dir, + f"original_rgb_{name_suffix}.tiff", + ) save_geotiff( image=_convert_np_uint8(rgb_orig), output_path=rgb_file, meta=meta_data, ) - print("Done!") - if __name__ == "__main__": parser = argparse.ArgumentParser("MAE run inference", add_help=False) @@ -428,10 +379,9 @@ def main( parser.add_argument( "--data_file", type=str, - default="/workspace/scripts/demo_flooding/examples/India_900498_S2Hand.tif", + default="./India_900498_S2Hand.tif", help="Path to the file.", ) - parser.add_argument( "--model", type=str, @@ -446,10 +396,13 @@ def main( ) parser.add_argument( "--input_indices", - default=[1,2,3,8,11,12], + default=[1, 2, 3, 8, 11, 12], type=int, nargs="+", - help="0-based indices of the six Prithvi channels to be selected from the input. By default selects [1,2,3,8,11,12] for S2L1C data.", + help=""" + 0-based indices of the six Prithvi channels to be selected from the input. + By default selects [1,2,3,8,11,12] for S2L1C data. + """, ) parser.add_argument( "--rgb_outputs", diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py index 3ff419f3d7b..55b7a9af2e4 100644 --- a/tests/models/multimodal/pooling/test_prithvi_mae.py +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -36,6 +36,7 @@ def _run_test( MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] + @pytest.mark.core_model @pytest.mark.parametrize("model", MODELS) def test_models_image( diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 010d6dc4351..439e766a38b 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -145,6 +145,7 @@ def supports_multimodal( return isinstance(model, SupportsMultiModal) + @runtime_checkable class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): """The interface required for all multi-modal models.""" @@ -159,6 +160,7 @@ class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): MRO of your model class. 
""" + @runtime_checkable class _SupportsMultiModalWithRawInput(Protocol): supports_multimodal_raw_input: ClassVar[Literal[True]] @@ -185,6 +187,7 @@ def supports_multimodal_raw_input( return isinstance(model, SupportsMultiModalWithRawInput) + @runtime_checkable class SupportsScoreTemplate(Protocol): """The interface required for all models that support score template.""" diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 05ea6cc492e..3be6c482121 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -330,14 +330,13 @@ def add_request( tokenizer = None if not self.tokenizer else \ self.tokenizer.get_lora_tokenizer(request.lora_request) - req_state = RequestState.from_new_request( - tokenizer=tokenizer, - request=request, - prompt=prompt, - parent_req=parent_req, - request_index=request_index, - queue=queue, - log_stats=self.log_stats) + req_state = RequestState.from_new_request(tokenizer=tokenizer, + request=request, + prompt=prompt, + parent_req=parent_req, + request_index=request_index, + queue=queue, + log_stats=self.log_stats) self.request_states[request_id] = req_state self.lora_states.add_request(req_state) if parent_req: diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 45ac61fb30b..60e97c78a5f 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -384,10 +384,6 @@ def _validate_model_input( tokenizer = None else: tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) - max_input_id = max(prompt_ids, default=0) - if max_input_id > tokenizer.max_token_id: - raise ValueError( - f"Token id {max_input_id} is out of vocabulary") prompt_ids = prompt_inputs["prompt_token_ids"] if not prompt_ids: @@ -396,6 +392,12 @@ def _validate_model_input( else: raise ValueError(f"The {prompt_type} prompt cannot be empty") + if tokenizer: + max_input_id = max(prompt_ids, default=0) + if max_input_id > tokenizer.max_token_id: + raise ValueError( + f"Token id {max_input_id} is out of vocabulary") + max_prompt_len = self.model_config.max_model_len if len(prompt_ids) > max_prompt_len: if prompt_type == "encoder" and model_config.is_multimodal_model: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 87bcabf99a0..5cf11920271 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -329,7 +329,7 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: scheduler_output: The scheduler output. 
""" if self.model_config.is_attention_free: - return False + return self.attn_metadata_builders[0].reorder_batch(self.input_batch, scheduler_output) @@ -560,10 +560,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.refresh_metadata() def _maybe_add_multimodal_kwargs( - self, - model_kwargs: dict[str, Any], - scheduler_output: "SchedulerOutput" = None, - num_reqs: int = -1, + self, + model_kwargs: dict[str, Any], + scheduler_output: Optional["SchedulerOutput"] = None, + num_reqs: int = -1, ): if not self.model_supports_multimodal_raw_input: @@ -591,7 +591,6 @@ def _maybe_add_multimodal_kwargs( model_kwargs.update(multi_modal_kwargs) - def _get_cumsum_and_arange( self, num_tokens: np.ndarray, From 7b3081b320c76255645e09f5dbf06d4728f3cdcf Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Wed, 16 Jul 2025 12:29:43 +0000 Subject: [PATCH 13/14] Few more style edits Signed-off-by: Christian Pinto --- vllm/v1/engine/processor.py | 7 +++---- vllm/v1/worker/gpu_model_runner.py | 19 ++++++++++++------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 60e97c78a5f..72c00796cb0 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -380,10 +380,9 @@ def _validate_model_input( prompt_type: Literal["encoder", "decoder"], ): model_config = self.model_config - if model_config.skip_tokenizer_init: - tokenizer = None - else: - tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) + + tokenizer = (None if model_config.skip_tokenizer_init else + self.tokenizer.get_lora_tokenizer(lora_request)) prompt_ids = prompt_inputs["prompt_token_ids"] if not prompt_ids: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5cf11920271..a6e65c28d17 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1415,7 +1415,8 @@ def execute_model( **MultiModalKwargs.as_kwargs( model_kwargs, device=self.device, - )) + ), + ) self.maybe_wait_for_kv_save() @@ -2073,12 +2074,16 @@ def _dummy_run( self.vllm_config, num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp): - outputs = model(input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **MultiModalKwargs.as_kwargs( - model_kwargs, device=self.device)) + outputs = model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_kwargs, + device=self.device, + ), + ) if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs From c0f390774422f25f96747bb0aa5a7d2c4e2cf08e Mon Sep 17 00:00:00 2001 From: Christian Pinto Date: Wed, 16 Jul 2025 12:59:05 +0000 Subject: [PATCH 14/14] Skip tokenizer init in async LLM engine Signed-off-by: Christian Pinto --- vllm/v1/engine/async_llm.py | 13 ++++++++----- vllm/v1/engine/llm_engine.py | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 3754570dfaa..af7bd7f7ce6 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -102,11 +102,14 @@ def __init__( custom_stat_loggers=stat_loggers, ) - # Tokenizer (+ ensure liveness if running in another process). 
- self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) + if not self.model_config.skip_tokenizer_init: + # Tokenizer (+ ensure liveness if running in another process). + self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + lora_config=vllm_config.lora_config) + else: + self.tokenizer = None # Processor (converts Inputs --> EngineCoreRequests). self.processor = Processor( diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index e13a544b50f..61675876ed2 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -82,7 +82,7 @@ def __init__( self.dp_group = None self.should_execute_dummy_batch = False - if not self.vllm_config.model_config.skip_tokenizer_init: + if not self.model_config.skip_tokenizer_init: # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( model_config=vllm_config.model_config,
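
Taken together, these patches let an attention-free, tokenizer-less pooling model (Prithvi-EO) run through the V1 engine with its raw multimodal tensors handed directly to the model. A minimal usage sketch distilled from the example script in this series follows; the tensor shapes, the coordinate values, and the random input data are illustrative assumptions for this note, not requirements of the vLLM API:

import torch

from vllm import LLM

# Tokenizer initialization is skipped throughout the V1 engine
# (skip_tokenizer_init handling added in this series).
llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    skip_tokenizer_init=True,
    dtype="float16",
    enforce_eager=True,
)

# Raw multimodal kwargs reach the model unchanged (SupportsMultiModalWithRawInput).
# Assumed shapes for illustration: 6 spectral bands, 1 time frame, one 512x512 tile,
# plus a single lon/lat pair.
mm_data = {
    "pixel_values": torch.randn(6, 1, 512, 512, dtype=torch.float16),
    "location_coords": torch.tensor([[77.59, 12.97]], dtype=torch.float16),
}

# prompt_token_ids is still mandatory, so a single dummy token id stands in
# for the text prompt, exactly as in the example script.
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}

outputs = llm.encode(prompt, use_tqdm=False)
pred = outputs[0].outputs.data  # raw pooled output (segmentation logits)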