
[Core][Model] PrithviMAE Enablement on vLLM v1 engine (superseded by PR 20577) #20072

Closed
3 changes: 2 additions & 1 deletion examples/offline_inference/prithvi_geospatial_mae.py
@@ -143,7 +143,8 @@ def __init__(self):
         self.model = LLM(
             model=os.path.join(os.path.dirname(__file__), "./model"),
             skip_tokenizer_init=True,
-            dtype="float32",
+            dtype="float16",
+            enforce_eager=True,
         )
 
     def run(self, input_data, location_coords):

50 changes: 50 additions & 0 deletions tests/models/multimodal/pooling/test_prithvi_mae.py
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from ....conftest import VllmRunner


def generate_test_mm_data():
    mm_data = {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    }
    return mm_data


def _run_test(
    vllm_runner: type[VllmRunner],
    model: str,
) -> None:

    mm_data = generate_test_mm_data()
    prompt = {
        # This model deals with no text input
        "prompt_token_ids": [1],
        "multi_modal_data": mm_data
    }
    with vllm_runner(model,
                     task="embed",
                     dtype=torch.float16,
                     enforce_eager=True,
                     skip_tokenizer_init=True) as vllm_model:
        vllm_model.encode(prompt)


MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]


@pytest.mark.parametrize("model", MODELS)
def test_models_image(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
) -> None:
    _run_test(
        vllm_runner,
        model,
    )
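
For reference, here is a minimal standalone sketch of what the test above exercises, written against the public LLM API instead of the test fixtures. The task="embed" flag, tensor shapes, and llm.encode() call mirror the vllm_runner arguments above; treat it as an illustration, not code from this PR.

# Hedged sketch: embed raw Prithvi geospatial tensors through vLLM's LLM API,
# using the model, dtype, and flags from the example and test in this PR.
import torch
from vllm import LLM

llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    task="embed",
    skip_tokenizer_init=True,
    dtype="float16",
    enforce_eager=True,
)

prompt = {
    # The model consumes no text; a single dummy token id is still required.
    "prompt_token_ids": [1],
    "multi_modal_data": {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    },
}

# outputs[0] holds the pooled output for this single request.
outputs = llm.encode(prompt)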
13 changes: 11 additions & 2 deletions vllm/config.py
@@ -612,6 +612,8 @@ def __post_init__(self) -> None:
         self.served_model_name = get_served_model_name(self.model,
                                                        self.served_model_name)
         self.multimodal_config = self._init_multimodal_config()
+        self.model_supports_multimodal_raw_input = (
+            self._init_model_supports_multimodal_raw_input())
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()

@@ -715,6 +717,9 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:

         return None
 
+    def _init_model_supports_multimodal_raw_input(self):
+        return self.registry.supports_multimodal_raw_input(self.architectures)
+
     def _get_encoder_config(self):
         return get_sentence_transformer_tokenizer_config(
             self.model, self.revision)
@@ -1120,10 +1125,10 @@ def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
         return self.get_hf_config_sliding_window()
 
     def get_vocab_size(self) -> int:
-        return self.hf_text_config.vocab_size
+        return getattr(self.hf_text_config, "vocab_size", 0)
 
     def get_hidden_size(self) -> int:
-        return self.hf_text_config.hidden_size
+        return getattr(self.hf_text_config, "hidden_size", 0)
 
     @property
     def is_deepseek_mla(self) -> bool:
@@ -1417,6 +1422,10 @@ def uses_mrope(self) -> bool:
     @property
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None
 
+    @property
+    def is_pooling_model(self) -> bool:
+        return self.registry.is_pooling_model(self.architectures)
+
     @property
     def is_cross_encoder(self) -> bool:
42 changes: 42 additions & 0 deletions vllm/model_executor/models/interfaces.py
@@ -130,6 +130,48 @@ def supports_multimodal(
     return isinstance(model, SupportsMultiModal)
 
 
+@runtime_checkable
+class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
+    """The interface required for multi-modal models that consume raw
+    multi-modal data rather than precomputed embeddings."""
+
+    supports_multimodal_raw_input: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports multi-modal inputs and processes
+    them in their raw form and not embeddings.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+
+@runtime_checkable
+class _SupportsMultiModalWithRawInput(Protocol):
+    supports_multimodal_raw_input: ClassVar[Literal[True]]
+
+
+@overload
+def supports_multimodal_raw_input(
+        model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
+    ...
+
+
+@overload
+def supports_multimodal_raw_input(
+        model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
+    ...
+
+
+def supports_multimodal_raw_input(
+    model: Union[type[object], object]
+) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
+           TypeIs[SupportsMultiModalWithRawInput]]:
+    if isinstance(model, type):
+        return isinstance(model, _SupportsMultiModalWithRawInput)
+
+    return isinstance(model, SupportsMultiModalWithRawInput)
+
+
 @runtime_checkable
 class SupportsLoRA(Protocol):
     """The interface required for all models that support LoRA."""
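
To illustrate how the new interface is meant to be consumed, here is a hedged sketch. The helper and protocol are the ones added above; MyRawInputModel is a hypothetical class, and the class-level check is the one the model registry performs when building its metadata.

# Hedged sketch: a model opts in by putting SupportsMultiModalWithRawInput in
# its MRO; the ClassVar flag is inherited, so it does not need to be redefined.
from vllm.model_executor.models.interfaces import (
    SupportsMultiModalWithRawInput, supports_multimodal_raw_input)


class MyRawInputModel(SupportsMultiModalWithRawInput):
    # Hypothetical model class; real models also derive from nn.Module, as
    # PrithviGeoSpatialMAE does later in this PR.
    ...


# Works on the class itself, which is how the registry queries the flag.
assert supports_multimodal_raw_input(MyRawInputModel)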
53 changes: 39 additions & 14 deletions vllm/model_executor/models/prithvi_geospatial_mae.py
@@ -25,14 +25,15 @@

 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import (IsAttentionFree,
-                                                   SupportsMultiModal,
-                                                   SupportsV0Only)
+from vllm.model_executor.models.interfaces import (
+    IsAttentionFree, SupportsMultiModalWithRawInput)
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs)
+                                    MultiModalFieldElem, MultiModalInputs,
+                                    MultiModalKwargs, MultiModalKwargsItem,
+                                    MultiModalSharedField, PlaceholderRange)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                          BaseProcessingInfo, PromptUpdate)
@@ -62,8 +63,9 @@ def get_dummy_mm_data(
         # The size of pixel_values might change in the cases where we resize
         # the input but never exceeds the dimensions below.
         return {
-            "pixel_values": torch.full((1, 6, 512, 512), 1.0),
-            "location_coords": torch.full((1, 2), 1.0),
+            "pixel_values": torch.full((6, 512, 512), 1.0,
+                                       dtype=torch.float16),
+            "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
         }
 
 
@@ -75,8 +77,10 @@ def _get_mm_fields_config(
         self,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            location_coords=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.shared(batch_size=1,
+                                                      modality="image"),
+            location_coords=MultiModalFieldConfig.shared(batch_size=1,
+                                                         modality="image"),
         )
 
     def _get_prompt_updates(
@@ -99,23 +103,34 @@ def apply(

         for k, v in mm_data.items():
             mm_kwargs[k] = v
+        mm_place_holders = {"image": [PlaceholderRange(offset=0, length=0)]}
+
+        multimodal_kwargs_items = [
+            MultiModalKwargsItem.from_elems([
+                MultiModalFieldElem(modality="image",
+                                    key=key,
+                                    data=data,
+                                    field=MultiModalSharedField(1))
+                for key, data in mm_kwargs.items()
+            ])
+        ]
 
         return MultiModalInputs(
             type="multimodal",
             prompt=prompt,
             prompt_token_ids=[1],
-            mm_kwargs=MultiModalKwargs(mm_kwargs),
+            mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
             mm_hashes=None,
-            mm_placeholders={},
+            mm_placeholders=mm_place_holders,
         )
 
 
 @MULTIMODAL_REGISTRY.register_processor(
     PrithviGeoSpatialMAEMultiModalProcessor,
     info=PrithviGeoSpatialMAEProcessingInfo,
     dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
-                           SupportsV0Only):
+class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
+                           SupportsMultiModalWithRawInput):
     """ Prithvi Masked Autoencoder"""

     @classmethod
@@ -169,7 +184,7 @@ def _parse_and_validate_multimodal_data(
         if not isinstance(pixel_values, torch.Tensor):
             raise ValueError(f"Incorrect type of pixel_values. "
                              f"Got type: {type(pixel_values)}")
-        pixel_values = torch.unbind(pixel_values, dim=0)[0]
+        # pixel_values = torch.unbind(pixel_values, dim=0)[0]
 
         location_coords = kwargs.pop("location_coords", None)
         if not isinstance(location_coords, torch.Tensor):
@@ -181,6 +196,13 @@

         return pixel_values, location_coords
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        # We do not really use any input tokens and therefore no embeddings
+        # to be calculated. However, due to the mandatory token ids in
+        # the input prompt we pass one token and the size of the dummy
+        # embedding tensors must reflect that.
+        return torch.empty(input_ids.shape)

Review comment (Contributor):
I don't see the get_multimodal_embeddings() function here. How is this not failing in GPUModelRunner.profile_run()?

Reply (Contributor, Author):
I believe that memory profiling is only done while initializing the KV cache, which is skipped because the model is attention-free. It should happen in Worker.determine_available_memory() if I am not mistaken.

Reply (Contributor):
Yes, but if we're adding a new programming model for MM inputs, I think it would be nice to handle this. If not a full solution, perhaps a test to disable CUDA graphs and memory profiling for this special case.

But I'm thinking that maybe the raw MM inputs will be hard to generalize. Perhaps a hacky solution for this particular model would be to store the raw MM inputs in the get_multimodal_embeddings() method and then use them when forward() is called.
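
For concreteness, a hypothetical sketch of the caching approach suggested above, shown on a stand-in class rather than the actual model (names such as _cached_mm_kwargs are illustrative and not part of this PR):

import torch


class _RawInputCachingSketch:
    # Hypothetical: stash the raw multi-modal kwargs when the runner asks for
    # "embeddings" (e.g. during profiling), since this model has none to give.
    def get_multimodal_embeddings(self, **kwargs) -> torch.Tensor:
        self._cached_mm_kwargs = kwargs
        return torch.empty(0)

    # Hypothetical: fall back to the cached raw inputs when forward() receives
    # no multi-modal kwargs for the current step.
    def forward(self, input_ids, positions, **kwargs):
        mm_kwargs = kwargs or getattr(self, "_cached_mm_kwargs", {})
        pixel_values = mm_kwargs.get("pixel_values")
        location_coords = mm_kwargs.get("location_coords")
        ...  # run the Prithvi backbone on the raw tensors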

     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -202,7 +224,10 @@ def pooler(
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Optional[PoolerOutput]:
-        return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
+        return PoolerOutput([
+            PoolingSequenceGroupOutput(hidden_state)
+            for hidden_state in hidden_states
+        ])
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:

13 changes: 11 additions & 2 deletions vllm/model_executor/models/registry.py
@@ -23,8 +23,8 @@

 from .interfaces import (has_inner_state, has_noops, is_attention_free,
                          is_hybrid, supports_cross_encoding,
-                         supports_multimodal, supports_pp,
-                         supports_transcription, supports_v0_only)
+                         supports_multimodal, supports_multimodal_raw_input,
+                         supports_pp, supports_transcription, supports_v0_only)
 from .interfaces_base import is_text_generation_model
 
 logger = init_logger(__name__)
@@ -275,6 +275,7 @@ class _ModelInfo:
     is_pooling_model: bool
     supports_cross_encoding: bool
     supports_multimodal: bool
+    supports_multimodal_raw_input: bool
     supports_pp: bool
     has_inner_state: bool
     is_attention_free: bool
@@ -291,6 +292,7 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
             is_pooling_model=True,  # Can convert any model into a pooling model
             supports_cross_encoding=supports_cross_encoding(model),
             supports_multimodal=supports_multimodal(model),
+            supports_multimodal_raw_input=supports_multimodal_raw_input(model),
             supports_pp=supports_pp(model),
             has_inner_state=has_inner_state(model),
             is_attention_free=is_attention_free(model),
@@ -528,6 +530,13 @@ def is_multimodal_model(
         model_cls, _ = self.inspect_model_cls(architectures)
         return model_cls.supports_multimodal
 
+    def supports_multimodal_raw_input(
+        self,
+        architectures: Union[str, list[str]],
+    ) -> bool:
+        model_cls, _ = self.inspect_model_cls(architectures)
+        return model_cls.supports_multimodal_raw_input
+
     def is_pp_supported_model(
         self,
         architectures: Union[str, list[str]],
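
Taken together with the vllm/config.py hunk above, the new flag can be queried end to end roughly as follows (a hedged sketch; it assumes the model is registered under the PrithviGeoSpatialMAE architecture name):

# Hedged sketch: the registry resolves the architecture to the model class and
# reads the _ModelInfo flag; ModelConfig.__post_init__ stores the same value
# as model_supports_multimodal_raw_input.
from vllm.model_executor.models.registry import ModelRegistry

assert ModelRegistry.supports_multimodal_raw_input(["PrithviGeoSpatialMAE"])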
2 changes: 1 addition & 1 deletion vllm/multimodal/registry.py
@@ -266,7 +266,7 @@ def create_processor(
         if not model_config.is_multimodal_model:
             raise ValueError(f"{model_config.model} is not a multimodal model")
 
-        if tokenizer is None:
+        if tokenizer is None and not model_config.skip_tokenizer_init:
             tokenizer = cached_tokenizer_from_config(model_config)
         if disable_cache is None:
             mm_config = model_config.get_multimodal_config()