[Core][Model] PrithviMAE Enablement on vLLM v1 engine (with zero kv_cache_groups) #20577

Open · wants to merge 11 commits into base: main

3 changes: 2 additions & 1 deletion examples/offline_inference/prithvi_geospatial_mae.py
@@ -143,7 +143,8 @@ def __init__(self):
        self.model = LLM(
            model=os.path.join(os.path.dirname(__file__), "./model"),
            skip_tokenizer_init=True,
-            dtype="float32",
+            dtype="float16",
+            enforce_eager=True,

Comment on lines +146 to +147
Contributor [medium]:
Consider using torch.float16 instead of the string literal for specifying the dtype. This approach is more type-safe and avoids potential errors due to typos or inconsistencies.

Suggested change:
-            dtype="float16",
-            enforce_eager=True,
+            dtype=torch.float16,
+            enforce_eager=True,

        )

    def run(self, input_data, location_coords):
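
For readers who want to try the updated example end to end, the following is a minimal offline-inference sketch that mirrors the test added in this PR. The Hugging Face model id, the task="embed" setting, and the constant dummy tensors are assumptions carried over from that test; real use would load and preprocess actual GeoTIFF bands instead.

# Minimal sketch, not part of this PR's diff.
import torch

from vllm import LLM

llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    task="embed",
    skip_tokenizer_init=True,
    dtype="float16",
    enforce_eager=True,
)

prompt = {
    # The model consumes no text, so a single placeholder token id is passed.
    "prompt_token_ids": [1],
    "multi_modal_data": {
        # Raw pixel values and location coordinates are handed to the model
        # as-is, without being converted into embeddings first.
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    },
}

outputs = llm.encode(prompt)  # one pooling output per prompt
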
50 changes: 50 additions & 0 deletions tests/models/multimodal/pooling/test_prithvi_mae.py
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from ....conftest import VllmRunner


def generate_test_mm_data():
    mm_data = {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    }
    return mm_data


def _run_test(
    vllm_runner: type[VllmRunner],
    model: str,
) -> None:

    mm_data = generate_test_mm_data()
    prompt = {
        # This model deals with no text input
        "prompt_token_ids": [1],
        "multi_modal_data": mm_data
    }
    with vllm_runner(model,
                     task="embed",
                     dtype=torch.float16,
                     enforce_eager=True,
                     skip_tokenizer_init=True) as vllm_model:
        vllm_model.encode(prompt)


MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]


@pytest.mark.parametrize("model", MODELS)
def test_models_image(
    hf_runner,
    vllm_runner,
    image_assets,
    model: str,
) -> None:
    _run_test(
        vllm_runner,
        model,
    )
13 changes: 11 additions & 2 deletions vllm/config.py
@@ -612,6 +612,8 @@ def __post_init__(self) -> None:
        self.served_model_name = get_served_model_name(self.model,
                                                       self.served_model_name)
        self.multimodal_config = self._init_multimodal_config()
+        self.model_supports_multimodal_raw_input = (
+            self._init_model_supports_multimodal_raw_input())
        if not self.skip_tokenizer_init:
            self._verify_tokenizer_mode()

@@ -715,6 +717,9 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:

        return None

+    def _init_model_supports_multimodal_raw_input(self):
+        return self.registry.supports_multimodal_raw_input(self.architectures)

    def _get_encoder_config(self):
        return get_sentence_transformer_tokenizer_config(
            self.model, self.revision)
@@ -1120,10 +1125,10 @@ def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
        return self.get_hf_config_sliding_window()

    def get_vocab_size(self) -> int:
-        return self.hf_text_config.vocab_size
+        return getattr(self.hf_text_config, "vocab_size", 0)

    def get_hidden_size(self) -> int:
-        return self.hf_text_config.hidden_size
+        return getattr(self.hf_text_config, "hidden_size", 0)

    @property
    def is_deepseek_mla(self) -> bool:

Comment on get_vocab_size
Contributor [medium]:
Consider providing a default value when using getattr to avoid potential AttributeError if hf_text_config does not have the vocab_size attribute. This makes the code more robust.

Suggested change:
    def get_vocab_size(self) -> int:
        return getattr(self.hf_text_config, "vocab_size", 0)

Comment on lines 1130 to +1131
Contributor [medium]:
Consider providing a default value when using getattr to avoid potential AttributeError if hf_text_config does not have the hidden_size attribute. This makes the code more robust.

Suggested change:
    def get_hidden_size(self) -> int:
        return getattr(self.hf_text_config, "hidden_size", 0)

@@ -1417,6 +1422,10 @@ def uses_mrope(self) -> bool:
    @property
    def is_multimodal_model(self) -> bool:
        return self.multimodal_config is not None

+    @property
+    def is_pooling_model(self) -> bool:
+        return self.registry.is_pooling_model(self.architectures)

    @property
    def is_cross_encoder(self) -> bool:
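
For context, a short hypothetical sketch of how the new ModelConfig fields can be queried downstream; the call sites are not part of this diff, and needs_raw_multimodal_path is an illustrative helper name only.

# Hypothetical helper, not part of this PR; `model_config` is assumed to be
# an already-constructed vllm.config.ModelConfig instance.
def needs_raw_multimodal_path(model_config) -> bool:
    # True for models such as PrithviGeoSpatialMAE that pool over raw
    # multi-modal tensors instead of token embeddings.
    return (model_config.is_pooling_model
            and model_config.model_supports_multimodal_raw_input)
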
42 changes: 42 additions & 0 deletions vllm/model_executor/models/interfaces.py
@@ -130,6 +130,48 @@ def supports_multimodal(
    return isinstance(model, SupportsMultiModal)


@runtime_checkable
class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
    """The interface required for all multi-modal models."""

    supports_multimodal_raw_input: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs and processes
    them in their raw form and not embeddings.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """


@runtime_checkable
class _SupportsMultiModalWithRawInput(Protocol):
    supports_multimodal_raw_input: ClassVar[Literal[True]]


@overload
def supports_multimodal_raw_input(
        model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
    ...


@overload
def supports_multimodal_raw_input(
        model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
    ...


def supports_multimodal_raw_input(
    model: Union[type[object], object]
) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
           TypeIs[SupportsMultiModalWithRawInput]]:
    if isinstance(model, type):
        return isinstance(model, _SupportsMultiModalWithRawInput)

    return isinstance(model, SupportsMultiModalWithRawInput)


@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA."""
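
To illustrate how the new interface is consumed, here is a minimal sketch. MyRawInputModel is a hypothetical class used only for this example; real models (such as PrithviGeoSpatialMAE below) opt in by inheriting from SupportsMultiModalWithRawInput, which provides the same class attribute.

# Sketch only, not part of this PR's diff.
from typing import ClassVar, Literal

import torch
import torch.nn as nn

from vllm.model_executor.models.interfaces import supports_multimodal_raw_input


class MyRawInputModel(nn.Module):
    # The runtime check only looks for this class-level flag.
    supports_multimodal_raw_input: ClassVar[Literal[True]] = True

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return pixel_values


# The overloaded helper accepts a class (shown here) or an instance.
assert supports_multimodal_raw_input(MyRawInputModel)
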
53 changes: 39 additions & 14 deletions vllm/model_executor/models/prithvi_geospatial_mae.py
@@ -25,14 +25,15 @@

from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import (IsAttentionFree,
-                                                    SupportsMultiModal,
-                                                    SupportsV0Only)
+from vllm.model_executor.models.interfaces import (
+    IsAttentionFree, SupportsMultiModalWithRawInput)
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs)
+                                    MultiModalFieldElem, MultiModalInputs,
+                                    MultiModalKwargs, MultiModalKwargsItem,
+                                    MultiModalSharedField, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptUpdate)
@@ -62,8 +63,9 @@ def get_dummy_mm_data(
        # The size of pixel_values might change in the cases where we resize
        # the input but never exceeds the dimensions below.
        return {
-            "pixel_values": torch.full((1, 6, 512, 512), 1.0),
-            "location_coords": torch.full((1, 2), 1.0),
+            "pixel_values": torch.full((6, 512, 512), 1.0,
+                                       dtype=torch.float16),
+            "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
        }


@@ -75,8 +77,10 @@ def _get_mm_fields_config(
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            location_coords=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.shared(batch_size=1,
+                                                      modality="image"),
+            location_coords=MultiModalFieldConfig.shared(batch_size=1,
+                                                         modality="image"),
        )

    def _get_prompt_updates(
@@ -99,23 +103,34 @@ def apply(

        for k, v in mm_data.items():
            mm_kwargs[k] = v
+        mm_place_holders = {"image": [PlaceholderRange(offset=0, length=0)]}
+
+        multimodal_kwargs_items = [
+            MultiModalKwargsItem.from_elems([
+                MultiModalFieldElem(modality="image",
+                                    key=key,
+                                    data=data,
+                                    field=MultiModalSharedField(1))
+                for key, data in mm_kwargs.items()
+            ])
+        ]

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=[1],
-            mm_kwargs=MultiModalKwargs(mm_kwargs),
+            mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
            mm_hashes=None,
-            mm_placeholders={},
+            mm_placeholders=mm_place_holders,
        )


@MULTIMODAL_REGISTRY.register_processor(
    PrithviGeoSpatialMAEMultiModalProcessor,
    info=PrithviGeoSpatialMAEProcessingInfo,
    dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
-                           SupportsV0Only):
+class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
+                           SupportsMultiModalWithRawInput):
    """ Prithvi Masked Autoencoder"""

    @classmethod
@@ -169,7 +184,7 @@ def _parse_and_validate_multimodal_data(
        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError(f"Incorrect type of pixel_values. "
                             f"Got type: {type(pixel_values)}")
-        pixel_values = torch.unbind(pixel_values, dim=0)[0]
+        # pixel_values = torch.unbind(pixel_values, dim=0)[0]

Contributor [medium]:
This line is commented out. If it's no longer needed, consider removing it to reduce code clutter and improve readability. If it's temporarily disabled for debugging, add a comment explaining why and when it should be re-enabled.

        location_coords = kwargs.pop("location_coords", None)
        if not isinstance(location_coords, torch.Tensor):
@@ -181,6 +196,13 @@

        return pixel_values, location_coords

+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        # We do not really use any input tokens and therefore no embeddings
+        # to be calculated. However, due to the mandatory token ids in
+        # the input prompt we pass one token and the size of the dummy
+        # embedding tensors must reflect that.
+        return torch.empty(input_ids.shape)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
@@ -202,7 +224,10 @@ def pooler(
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
-        return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
+        return PoolerOutput([
+            PoolingSequenceGroupOutput(hidden_state)
+            for hidden_state in hidden_states
+        ])

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
13 changes: 11 additions & 2 deletions vllm/model_executor/models/registry.py
@@ -23,8 +23,8 @@

from .interfaces import (has_inner_state, has_noops, is_attention_free,
                         is_hybrid, supports_cross_encoding,
-                         supports_multimodal, supports_pp,
-                         supports_transcription, supports_v0_only)
+                         supports_multimodal, supports_multimodal_raw_input,
+                         supports_pp, supports_transcription, supports_v0_only)
from .interfaces_base import is_text_generation_model

logger = init_logger(__name__)
@@ -275,6 +275,7 @@ class _ModelInfo:
    is_pooling_model: bool
    supports_cross_encoding: bool
    supports_multimodal: bool
+    supports_multimodal_raw_input: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
@@ -291,6 +292,7 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo":
            is_pooling_model=True,  # Can convert any model into a pooling model
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
+            supports_multimodal_raw_input=supports_multimodal_raw_input(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
@@ -528,6 +530,13 @@ def is_multimodal_model(
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_multimodal

+    def supports_multimodal_raw_input(
+        self,
+        architectures: Union[str, list[str]],
+    ) -> bool:
+        model_cls, _ = self.inspect_model_cls(architectures)
+        return model_cls.supports_multimodal_raw_input

    def is_pp_supported_model(
        self,
        architectures: Union[str, list[str]],
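
As a quick illustration of the new registry hook (a sketch; it assumes "PrithviGeoSpatialMAE" is the architecture name under which the model is registered and that ModelRegistry is the global registry instance exported by vllm):

# Sketch only, not part of this PR's diff.
from vllm import ModelRegistry

if ModelRegistry.supports_multimodal_raw_input(["PrithviGeoSpatialMAE"]):
    print("raw multi-modal inputs are forwarded to the model unchanged")
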
2 changes: 1 addition & 1 deletion vllm/multimodal/registry.py
@@ -266,7 +266,7 @@ def create_processor(
        if not model_config.is_multimodal_model:
            raise ValueError(f"{model_config.model} is not a multimodal model")

-        if tokenizer is None:
+        if tokenizer is None and not model_config.skip_tokenizer_init:
            tokenizer = cached_tokenizer_from_config(model_config)
        if disable_cache is None:
            mm_config = model_config.get_multimodal_config()
4 changes: 3 additions & 1 deletion vllm/v1/core/kv_cache_coordinator.py
@@ -219,7 +219,9 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
        super().__init__(kv_cache_config, max_model_len, use_eagle,
                         enable_caching, caching_hash_fn,
                         enable_kv_cache_events)
-        self.verify_and_split_kv_cache_groups()
+        # attention free models are initialized with 0 kv_cache_groups
+        if len(self.kv_cache_config.kv_cache_groups) > 0:
+            self.verify_and_split_kv_cache_groups()

Collaborator:
I'm comfortable with adding another coordinator for 0 kv cache groups and re-implementing find_longest_cache_hit for it.

Collaborator:
As I've observed more and more cases that the current find_longest_cache_hit can't handle, I'm suggesting a new KVCacheCoordinatorNoPrefixCache and using it when prefix caching is disabled. Can you sync with the author of #20661 to avoid duplication of work?

Contributor:
@christian-pinto I have introduced a KVCacheCoordinatorNoPrefixCache in this PR (#20661). I think it should handle your case as well. Could you give it a try?

Contributor Author:
Hey @nopperl, thanks for that. Your approach solves my issue too.

    def verify_and_split_kv_cache_groups(self) -> None:
        """
19 changes: 12 additions & 7 deletions vllm/v1/core/kv_cache_manager.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional
@@ -65,7 +66,6 @@ def new_empty(self) -> "KVCacheBlocks":


class KVCacheManager:
-
    def __init__(
        self,
        kv_cache_config: KVCacheConfig,
Expand All @@ -84,12 +84,17 @@ def __init__(
self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
assert len(
set(g.kv_cache_spec.block_size
for g in kv_cache_config.kv_cache_groups)
) == 1, "Only one block size is supported for now"
self.block_size = kv_cache_config.kv_cache_groups[
0].kv_cache_spec.block_size

if len(kv_cache_config.kv_cache_groups) == 0:
#This is an attention free model that is started with 0 KVCache groups.
self.block_size = 0
else:
assert len(
set(g.kv_cache_spec.block_size
for g in kv_cache_config.kv_cache_groups)
) == 1, "Only one block size is supported for now"
self.block_size = kv_cache_config.kv_cache_groups[
0].kv_cache_spec.block_size

self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,