diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index 567c448a8c9..4fdc7a3cf70 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -1,122 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This is a demo script showing how to use the -PrithviGeospatialMAE model with vLLM -This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa - -Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa - -The requirements for running this script are: -- Installing [terratorch, albumentations, rasterio] in your python environment -- downloading the model weights in a 'model' folder local to the script - (temporary measure until the proper config.json file is uploaded to HF) -- download an input example image (India_900498_S2Hand.tif) and place it in - the same folder with the script (or specify with the --data_file argument) - -Run the example: -python prithvi_geospatial_mae.py - -""" # noqa: E501 - import argparse import datetime import os +import re from typing import Union import albumentations import numpy as np import rasterio -import regex as re import torch from einops import rearrange from terratorch.datamodules import Sen1Floods11NonGeoDataModule from vllm import LLM +torch.set_default_dtype(torch.float16) + NO_DATA = -9999 NO_DATA_FLOAT = 0.0001 OFFSET = 0 PERCENTILE = 99 -model_config = """{ - "architectures": ["PrithviGeoSpatialMAE"], - "num_classes": 0, - "pretrained_cfg": { - "task_args": { - "task": "SemanticSegmentationTask", - "model_factory": "EncoderDecoderFactory", - "loss": "ce", - "ignore_index": -1, - "lr": 0.001, - "freeze_backbone": false, - "freeze_decoder": false, - "plot_on_val": 10, - "optimizer": "AdamW", - "scheduler": "CosineAnnealingLR" - }, - "model_args": { - "backbone_pretrained": false, - "backbone": "prithvi_eo_v2_300_tl", - "decoder": "UperNetDecoder", - "decoder_channels": 256, - "decoder_scale_modules": true, - "num_classes": 2, - "rescale": true, - "backbone_bands": [ - "BLUE", - "GREEN", - "RED", - "NIR_NARROW", - "SWIR_1", - "SWIR_2" - ], - "head_dropout": 0.1, - "necks": [ - { - "name": "SelectIndices", - "indices": [ - 5, - 11, - 17, - 23 - ] - }, - { - "name": "ReshapeTokensToImage" - } - ] - }, - "optimizer_params" : { - "lr": 5.0e-05, - "betas": [0.9, 0.999], - "eps": [1.0e-08], - "weight_decay": 0.05, - "amsgrad": false, - "maximize": false, - "capturable": false, - "differentiable": false - }, - "scheduler_params" : { - "T_max": 50, - "eta_min": 0, - "last_epoch": -1, - "verbose": "deprecated" - } - }, - - - "torch_dtype": "float32" -} -""" - -# Temporarily creating the "config.json" for the model. 
-# This is going to disappear once the correct config.json is available on HF -with open( - os.path.join(os.path.dirname(__file__), "./model/config.json"), "w" -) as config_file: - config_file.write(model_config) - datamodule_config = { "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"], "batch_size": 16, @@ -138,28 +43,24 @@ class PrithviMAE: - def __init__(self): - print("Initializing PrithviMAE model") + def __init__(self, model): self.model = LLM( - model=os.path.join(os.path.dirname(__file__), "./model"), - skip_tokenizer_init=True, - dtype="float32", + model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True ) def run(self, input_data, location_coords): - print("################ Running inference on vLLM ##############") # merge the inputs into one data structure + if input_data is not None and input_data.dtype == torch.float32: + input_data = input_data.to(torch.float16) + input_data = input_data[0] + mm_data = { - "pixel_values": torch.empty(0) if input_data is None else input_data, - "location_coords": torch.empty(0) - if location_coords is None - else location_coords, + "pixel_values": input_data, + "location_coords": location_coords, } prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} - outputs = self.model.encode(prompt, use_tqdm=False) - print("################ Inference done (it took seconds) ##############") return outputs[0].outputs.data @@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels): """ Args: orig_img: torch.Tensor representing original image (reference) - with shape = (bands, H, W). + with shape = (bands, H, W). channels: list of indices representing RGB channels. Returns: - torch.Tensor with shape (num_channels, height, width) for original image + torch.Tensor with shape (num_channels, height, width) + for original image """ orig_img = orig_img[channels, ...] @@ -260,10 +162,10 @@ def load_example( Args: file_paths: list of file paths . - mean: list containing mean values for each band in the images - in *file_paths*. - std: list containing std values for each band in the images - in *file_paths*. + mean: list containing mean values for each band in the + images in *file_paths*. + std: list containing std values for each band in the + images in *file_paths*. 
Returns: np.array containing created example @@ -308,7 +210,7 @@ def load_example( print(f"Could not extract timestamp for {file} ({e})") imgs = np.stack(imgs, axis=0) # num_frames, H, W, C - imgs = np.moveaxis(imgs, -1, 0).astype("float32") + imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W imgs = np.expand_dims(imgs, axis=0) # add batch di return imgs, temporal_coords, location_coords, metas @@ -332,8 +234,10 @@ def run_model( ) # Build sliding window + batch_size = 1 - batch = torch.tensor(input_data, device="cpu") + # batch = torch.tensor(input_data, device="cpu") + batch = torch.tensor(input_data) windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size) h1, w1 = windows.shape[3:5] windows = rearrange( @@ -344,18 +248,16 @@ def run_model( num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1 windows = torch.tensor_split(windows, num_batches, dim=0) - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - if temporal_coords: - temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0) + temporal_coords = torch.tensor(temporal_coords).unsqueeze(0) else: temporal_coords = None if location_coords: - location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0) + location_coords = torch.tensor(location_coords[0]).unsqueeze(0) else: location_coords = None - # Run model + # Run Prithvi-EO-V2-300M-TL-Sen1Floods11 pred_imgs = [] for x in windows: # Apply standardization @@ -363,15 +265,7 @@ def run_model( x = datamodule.aug(x)["image"] with torch.no_grad(): - x = x.to(device) pred = model.run(x, location_coords=location_coords) - if lightning_model: - pred_lightning = lightning_model( - x, temporal_coords=temporal_coords, location_coords=location_coords - ) - pred_lightning = pred_lightning.output.detach().cpu() - if not torch.equal(pred, pred_lightning): - print("Inference output is not equal") y_hat = pred.argmax(dim=1) y_hat = torch.nn.functional.interpolate( @@ -403,52 +297,18 @@ def run_model( return pred_imgs -def parse_args(): - parser = argparse.ArgumentParser("MAE run inference", add_help=False) - - parser.add_argument( - "--data_file", - type=str, - default="./India_900498_S2Hand.tif", - help="Path to the file.", - ) - parser.add_argument( - "--output_dir", - type=str, - default="output", - help="Path to the directory where to save outputs.", - ) - parser.add_argument( - "--input_indices", - default=[1, 2, 3, 8, 11, 12], - type=int, - nargs="+", - help="0-based indices of the six Prithvi channels to be selected from the " - "input. By default selects [1,2,3,8,11,12] for S2L1C data.", - ) - parser.add_argument( - "--rgb_outputs", - action="store_true", - help="If present, output files will only contain RGB channels. 
" - "Otherwise, all bands will be saved.", - ) - - def main( data_file: str, + model: str, output_dir: str, rgb_outputs: bool, input_indices: list[int] = None, ): os.makedirs(output_dir, exist_ok=True) - # Load model --------------------------------------------------------------- - - model_obj = PrithviMAE() + model_obj = PrithviMAE(model=model) datamodule = generate_datamodule() - img_size = 256 # Size of Sen1Floods11 - - # Loading data ------------------------------------------------------------- + img_size = 512 # Size of Sen1Floods11 input_data, temporal_coords, location_coords, meta_data = load_example( file_paths=[data_file], @@ -460,8 +320,6 @@ def main( if input_data.mean() > 1: input_data = input_data / 10000 # Convert to range 0-1 - # Running model ------------------------------------------------------------ - channels = [ datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"] ] # BGR -> RGB @@ -469,7 +327,6 @@ def main( pred = run_model( input_data, temporal_coords, location_coords, model_obj, datamodule, img_size ) - # Save pred meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0) pred_file = os.path.join( @@ -487,6 +344,7 @@ def main( orig_img=torch.Tensor(input_data[0, :, 0, ...]), channels=channels, ) + rgb_orig = rgb_orig.to(torch.float32) pred[pred == 0.0] = np.nan img_pred = rgb_orig * 0.7 + pred * 0.3 @@ -503,9 +361,10 @@ def main( # Save image rgb if rgb_outputs: + name_suffix = os.path.splitext(os.path.basename(data_file))[0] rgb_file = os.path.join( output_dir, - f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff", + f"original_rgb_{name_suffix}.tiff", ) save_geotiff( image=_convert_np_uint8(rgb_orig), @@ -515,6 +374,42 @@ def main( if __name__ == "__main__": - args = parse_args() + parser = argparse.ArgumentParser("MAE run inference", add_help=False) + + parser.add_argument( + "--data_file", + type=str, + default="./India_900498_S2Hand.tif", + help="Path to the file.", + ) + parser.add_argument( + "--model", + type=str, + default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", + help="Path to a checkpoint file to load from.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Path to the directory where to save outputs.", + ) + parser.add_argument( + "--input_indices", + default=[1, 2, 3, 8, 11, 12], + type=int, + nargs="+", + help=""" + 0-based indices of the six Prithvi channels to be selected from the input. + By default selects [1,2,3,8,11,12] for S2L1C data. + """, + ) + parser.add_argument( + "--rgb_outputs", + action="store_true", + help="If present, output files will only contain RGB channels. 
" + "Otherwise, all bands will be saved.", + ) + args = parser.parse_args() main(**vars(args)) diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py new file mode 100644 index 00000000000..55b7a9af2e4 --- /dev/null +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from ....conftest import VllmRunner + + +def generate_test_mm_data(): + mm_data = { + "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16), + "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), + } + return mm_data + + +def _run_test( + vllm_runner: type[VllmRunner], + model: str, +) -> None: + + mm_data = generate_test_mm_data() + prompt = { + # This model deals with no text input + "prompt_token_ids": [1], + "multi_modal_data": mm_data + } + with vllm_runner(model, + task="embed", + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True) as vllm_model: + vllm_model.encode(prompt) + + +MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", MODELS) +def test_models_image( + hf_runner, + vllm_runner, + image_assets, + model: str, +) -> None: + _run_test( + vllm_runner, + model, + ) diff --git a/vllm/config.py b/vllm/config.py index 766d7708625..c5f91e4c217 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -642,6 +642,8 @@ def __post_init__(self) -> None: self.original_max_model_len = self.max_model_len self.max_model_len = self.get_and_verify_max_len(self.max_model_len) self.multimodal_config = self._init_multimodal_config() + self.model_supports_multimodal_raw_input = ( + self._init_model_supports_multimodal_raw_input()) if not self.skip_tokenizer_init: self._verify_tokenizer_mode() @@ -753,6 +755,9 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: return None + def _init_model_supports_multimodal_raw_input(self): + return self.registry.supports_multimodal_raw_input(self.architectures) + def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( self.model, self.revision) @@ -1201,10 +1206,10 @@ def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]: return self.get_hf_config_sliding_window() def get_vocab_size(self) -> int: - return self.hf_text_config.vocab_size + return getattr(self.hf_text_config, "vocab_size", 0) def get_hidden_size(self) -> int: - return self.hf_text_config.hidden_size + return getattr(self.hf_text_config, "hidden_size", 0) @property def is_deepseek_mla(self) -> bool: diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 92ecb8972d5..439e766a38b 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -146,6 +146,48 @@ def supports_multimodal( return isinstance(model, SupportsMultiModal) +@runtime_checkable +class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): + """The interface required for all multi-modal models.""" + + supports_multimodal_raw_input: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports multi-modal inputs and processes + them in their raw form and not embeddings. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. 
+ """ + + +@runtime_checkable +class _SupportsMultiModalWithRawInput(Protocol): + supports_multimodal_raw_input: ClassVar[Literal[True]] + + +@overload +def supports_multimodal_raw_input( + model: object) -> TypeIs[SupportsMultiModalWithRawInput]: + ... + + +@overload +def supports_multimodal_raw_input( + model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]: + ... + + +def supports_multimodal_raw_input( + model: Union[type[object], object] +) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], + TypeIs[SupportsMultiModalWithRawInput]]: + if isinstance(model, type): + return isinstance(model, _SupportsMultiModalWithRawInput) + + return isinstance(model, SupportsMultiModalWithRawInput) + + @runtime_checkable class SupportsScoreTemplate(Protocol): """The interface required for all models that support score template.""" diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index a36f24bc80e..26f5e594f30 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -25,14 +25,15 @@ from vllm.config import VllmConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (IsAttentionFree, - SupportsMultiModal, - SupportsV0Only) +from vllm.model_executor.models.interfaces import ( + IsAttentionFree, SupportsMultiModalWithRawInput) from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs) + MultiModalFieldElem, MultiModalInputs, + MultiModalKwargs, MultiModalKwargsItem, + MultiModalSharedField, PlaceholderRange) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptUpdate) @@ -62,8 +63,9 @@ def get_dummy_mm_data( # The size of pixel_values might change in the cases where we resize # the input but never exceeds the dimensions below. 
return { - "pixel_values": torch.full((1, 6, 512, 512), 1.0), - "location_coords": torch.full((1, 2), 1.0), + "pixel_values": torch.full((6, 512, 512), 1.0, + dtype=torch.float16), + "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), } @@ -75,8 +77,10 @@ def _get_mm_fields_config( hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return dict( - pixel_values=MultiModalFieldConfig.batched("image"), - location_coords=MultiModalFieldConfig.batched("image"), + pixel_values=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), + location_coords=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), ) def _get_prompt_updates( @@ -99,14 +103,25 @@ def apply( for k, v in mm_data.items(): mm_kwargs[k] = v + mm_place_holders = {"image": [PlaceholderRange(offset=0, length=0)]} + + multimodal_kwargs_items = [ + MultiModalKwargsItem.from_elems([ + MultiModalFieldElem(modality="image", + key=key, + data=data, + field=MultiModalSharedField(1)) + for key, data in mm_kwargs.items() + ]) + ] return MultiModalInputs( type="multimodal", prompt=prompt, prompt_token_ids=[1], - mm_kwargs=MultiModalKwargs(mm_kwargs), + mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items), mm_hashes=None, - mm_placeholders={}, + mm_placeholders=mm_place_holders, ) @@ -114,8 +129,8 @@ def apply( PrithviGeoSpatialMAEMultiModalProcessor, info=PrithviGeoSpatialMAEProcessingInfo, dummy_inputs=PrithviGeoSpatialMAEInputBuilder) -class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal, - SupportsV0Only): +class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, + SupportsMultiModalWithRawInput): """ Prithvi Masked Autoencoder""" @classmethod @@ -169,7 +184,6 @@ def _parse_and_validate_multimodal_data( if not isinstance(pixel_values, torch.Tensor): raise ValueError(f"Incorrect type of pixel_values. " f"Got type: {type(pixel_values)}") - pixel_values = torch.unbind(pixel_values, dim=0)[0] location_coords = kwargs.pop("location_coords", None) if not isinstance(location_coords, torch.Tensor): @@ -181,6 +195,13 @@ def _parse_and_validate_multimodal_data( return pixel_values, location_coords + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + # We do not really use any input tokens and therefore no embeddings + # to be calculated. However, due to the mandatory token ids in + # the input prompt we pass one token and the size of the dummy + # embedding tensors must reflect that. 
+ return torch.empty((input_ids.shape[0], 0)) + def forward( self, input_ids: Optional[torch.Tensor], @@ -202,7 +223,10 @@ def pooler( hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Optional[PoolerOutput]: - return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)]) + return PoolerOutput([ + PoolingSequenceGroupOutput(hidden_state) + for hidden_state in hidden_states + ]) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b7f9638d322..6e706a6eb0a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -22,8 +22,8 @@ from .interfaces import (has_inner_state, has_noops, is_attention_free, is_hybrid, supports_cross_encoding, - supports_multimodal, supports_pp, - supports_transcription, supports_v0_only) + supports_multimodal, supports_multimodal_raw_input, + supports_pp, supports_transcription, supports_v0_only) from .interfaces_base import is_text_generation_model logger = init_logger(__name__) @@ -281,6 +281,7 @@ class _ModelInfo: is_pooling_model: bool supports_cross_encoding: bool supports_multimodal: bool + supports_multimodal_raw_input: bool supports_pp: bool has_inner_state: bool is_attention_free: bool @@ -298,6 +299,7 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo": is_pooling_model=True, # Can convert any model into a pooling model supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), + supports_multimodal_raw_input=supports_multimodal_raw_input(model), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), @@ -537,6 +539,13 @@ def is_multimodal_model( model_cls, _ = self.inspect_model_cls(architectures) return model_cls.supports_multimodal + def supports_multimodal_raw_input( + self, + architectures: Union[str, list[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_multimodal_raw_input + def is_pp_supported_model( self, architectures: Union[str, list[str]], diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 27aaa661c35..c44fcacd246 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -266,7 +266,7 @@ def create_processor( if not model_config.is_multimodal_model: raise ValueError(f"{model_config.model} is not a multimodal model") - if tokenizer is None: + if tokenizer is None and not model_config.skip_tokenizer_init: tokenizer = cached_tokenizer_from_config(model_config) if disable_cache is None: mm_config = model_config.get_multimodal_config() diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 3754570dfaa..af7bd7f7ce6 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -102,11 +102,14 @@ def __init__( custom_stat_loggers=stat_loggers, ) - # Tokenizer (+ ensure liveness if running in another process). - self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) + if not self.model_config.skip_tokenizer_init: + # Tokenizer (+ ensure liveness if running in another process). 
+ self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + lora_config=vllm_config.lora_config) + else: + self.tokenizer = None # Processor (converts Inputs --> EngineCoreRequests). self.processor = Processor( diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index a2328c37ba0..61675876ed2 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -82,11 +82,14 @@ def __init__( self.dp_group = None self.should_execute_dummy_batch = False - # Tokenizer (+ ensure liveness if running in another process). - self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) + if not self.model_config.skip_tokenizer_init: + # Tokenizer (+ ensure liveness if running in another process). + self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + lora_config=vllm_config.lora_config) + else: + self.tokenizer = None # Processor (convert Inputs --> EngineCoreRequests) self.processor = Processor(vllm_config=vllm_config, diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 2bcd61d1f0a..3be6c482121 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -327,14 +327,16 @@ def add_request( if request_id in self.request_states: raise ValueError(f"Request id {request_id} already running.") - req_state = RequestState.from_new_request( - tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request), - request=request, - prompt=prompt, - parent_req=parent_req, - request_index=request_index, - queue=queue, - log_stats=self.log_stats) + tokenizer = None if not self.tokenizer else \ + self.tokenizer.get_lora_tokenizer(request.lora_request) + + req_state = RequestState.from_new_request(tokenizer=tokenizer, + request=request, + prompt=prompt, + parent_req=parent_req, + request_index=request_index, + queue=queue, + log_stats=self.log_stats) self.request_states[request_id] = req_state self.lora_states.add_request(req_state) if parent_req: diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 7af4ed54a22..72c00796cb0 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -380,7 +380,9 @@ def _validate_model_input( prompt_type: Literal["encoder", "decoder"], ): model_config = self.model_config - tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) + + tokenizer = (None if model_config.skip_tokenizer_init else + self.tokenizer.get_lora_tokenizer(lora_request)) prompt_ids = prompt_inputs["prompt_token_ids"] if not prompt_ids: @@ -389,9 +391,11 @@ def _validate_model_input( else: raise ValueError(f"The {prompt_type} prompt cannot be empty") - max_input_id = max(prompt_ids, default=0) - if max_input_id > tokenizer.max_token_id: - raise ValueError(f"Token id {max_input_id} is out of vocabulary") + if tokenizer: + max_input_id = max(prompt_ids, default=0) + if max_input_id > tokenizer.max_token_id: + raise ValueError( + f"Token id {max_input_id} is out of vocabulary") max_prompt_len = self.model_config.max_model_len if len(prompt_ids) > max_prompt_len: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index af216539c90..a6e65c28d17 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -124,6 +124,8 @@ def __init__( 
self.is_multimodal_model = model_config.is_multimodal_model self.is_pooling_model = model_config.pooler_config is not None + self.model_supports_multimodal_raw_input = ( + model_config.model_supports_multimodal_raw_input) self.max_model_len = model_config.max_model_len self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -326,6 +328,9 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: Args: scheduler_output: The scheduler output. """ + if self.model_config.is_attention_free: + return + self.attn_metadata_builders[0].reorder_batch(self.input_batch, scheduler_output) @@ -554,6 +559,38 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() + def _maybe_add_multimodal_kwargs( + self, + model_kwargs: dict[str, Any], + scheduler_output: Optional["SchedulerOutput"] = None, + num_reqs: int = -1, + ): + + if not self.model_supports_multimodal_raw_input: + return + + # Multi-modal data. + if scheduler_output: + multi_modal_kwargs_list = [] + for req in scheduler_output.scheduled_new_reqs: + req_mm_inputs = req.mm_inputs + if not isinstance(req_mm_inputs, list): + req_mm_inputs = list(req_mm_inputs) + multi_modal_kwargs_list.extend(req_mm_inputs) + multi_modal_kwargs = MultiModalKwargs.batch( + multi_modal_kwargs_list) + else: + # The only case where SchedulerOtput is None is for a dummy run, + # let's get some dummy data. + dummy_data = [ + self.mm_registry.get_decoder_dummy_data( + model_config=self.model_config, seq_len=1).multi_modal_data + for i in range(num_reqs) + ] + multi_modal_kwargs = MultiModalKwargs.batch(dummy_data) + + model_kwargs.update(multi_modal_kwargs) + def _get_cumsum_and_arange( self, num_tokens: np.ndarray, @@ -1319,11 +1356,14 @@ def execute_model( else: mm_embeds = [] + model_kwargs: dict[str, Any] = {} if self.is_multimodal_model and get_pp_group().is_first_rank: # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. input_ids = self.input_ids[:num_scheduled_tokens] + self._maybe_add_multimodal_kwargs( + model_kwargs=model_kwargs, scheduler_output=scheduler_output) if mm_embeds: inputs_embeds = self.model.get_input_embeddings( input_ids, mm_embeds) @@ -1372,6 +1412,10 @@ def execute_model( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_kwargs, + device=self.device, + ), ) self.maybe_wait_for_kv_save() @@ -1998,7 +2042,10 @@ def _dummy_run( with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): model = self.model + model_kwargs: dict[str, Any] = {} if self.is_multimodal_model: + self._maybe_add_multimodal_kwargs(model_kwargs=model_kwargs, + num_reqs=num_reqs) input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] else: @@ -2032,7 +2079,12 @@ def _dummy_run( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_kwargs, + device=self.device, + ), ) + if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs else:
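
The patch above wires raw (non-embedding) multi-modal inputs end to end: the Prithvi model now advertises `SupportsMultiModalWithRawInput`, the V1 engine tolerates `skip_tokenizer_init`, and the GPU model runner batches `MultiModalKwargs` and forwards them straight into `forward()`. Below is a minimal usage sketch of that path, assembled from the new example script and `tests/models/multimodal/pooling/test_prithvi_mae.py` in this diff. The model id and tensor shapes are taken from the diff itself; availability of the checkpoint on the Hugging Face Hub is assumed. This is an illustration for reviewers, not part of the patch.

```python
# Minimal sketch (not part of the patch): drive the raw-input multi-modal
# path added above. Model id and input shapes come from the new example
# script and test; checkpoint availability on the HF Hub is assumed.
import torch

from vllm import LLM

llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    skip_tokenizer_init=True,  # the model consumes no text tokens
    dtype="float16",
    enforce_eager=True,
)

# Raw tensors are passed through unchanged (no embedding lookup); shapes
# mirror tests/models/multimodal/pooling/test_prithvi_mae.py.
mm_data = {
    "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
    "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
}

# A single placeholder token id keeps the prompt non-empty; the model's
# get_input_embeddings() returns an empty tensor, so the id itself is unused.
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}

outputs = llm.encode(prompt, use_tqdm=False)
print(outputs[0].outputs.data.shape)  # raw pooled output tensor
```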