[Core][Model] PrithviMAE Enablement on vLLM v1 engine (with zero kv_cache_groups) #20577
base: main
Changes from all commits:
4c68b8f, d018203, dbdb7db, ad667ce, ceedf19, 8e3945b, b8f6189, 3ea3ce0, c727744, eda0697, 645d061
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from ....conftest import VllmRunner
+
+
+def generate_test_mm_data():
+    mm_data = {
+        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
+        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
+    }
+    return mm_data
+
+
+def _run_test(
+    vllm_runner: type[VllmRunner],
+    model: str,
+) -> None:
+
+    mm_data = generate_test_mm_data()
+    prompt = {
+        # This model deals with no text input
+        "prompt_token_ids": [1],
+        "multi_modal_data": mm_data
+    }
+    with vllm_runner(model,
+                     task="embed",
+                     dtype=torch.float16,
+                     enforce_eager=True,
+                     skip_tokenizer_init=True) as vllm_model:
+        vllm_model.encode(prompt)
+
+
+MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+def test_models_image(
+    hf_runner,
+    vllm_runner,
+    image_assets,
+    model: str,
+) -> None:
+    _run_test(
+        vllm_runner,
+        model,
+    )
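For reference, a standalone sketch of what the new test exercises via vLLM's offline `LLM` API. This is an illustration, not part of the diff: the model name comes from the test above, the constructor arguments mirror the runner kwargs, and the exact output field names can vary across vLLM versions.

```python
import torch
from vllm import LLM

# Mirrors the test's vllm_runner(...) arguments.
llm = LLM(model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
          task="embed",
          dtype="float16",
          enforce_eager=True,
          skip_tokenizer_init=True)

prompt = {
    # The model consumes no text; a single dummy token id is still required.
    "prompt_token_ids": [1],
    "multi_modal_data": {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    },
}

# encode() runs the pooling path; one pooling output per prompt.
outputs = llm.encode(prompt)
print(outputs[0].outputs)
```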
@@ -612,6 +612,8 @@ def __post_init__(self) -> None:
         self.served_model_name = get_served_model_name(self.model,
                                                        self.served_model_name)
         self.multimodal_config = self._init_multimodal_config()
+        self.model_supports_multimodal_raw_input = (
+            self._init_model_supports_multimodal_raw_input())
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()

@@ -715,6 +717,9 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:

         return None

+    def _init_model_supports_multimodal_raw_input(self):
+        return self.registry.supports_multimodal_raw_input(self.architectures)
+
     def _get_encoder_config(self):
         return get_sentence_transformer_tokenizer_config(
             self.model, self.revision)

@@ -1120,10 +1125,10 @@ def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
         return self.get_hf_config_sliding_window()

     def get_vocab_size(self) -> int:
-        return self.hf_text_config.vocab_size
+        return getattr(self.hf_text_config, "vocab_size", 0)

     def get_hidden_size(self) -> int:
-        return self.hf_text_config.hidden_size
+        return getattr(self.hf_text_config, "hidden_size", 0)

Review comment on lines 1130 to 1131: Consider providing a default value when using `getattr`.

     @property
     def is_deepseek_mla(self) -> bool:

@@ -1417,6 +1422,10 @@ def uses_mrope(self) -> bool:
     @property
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None

+    @property
+    def is_pooling_model(self) -> bool:
+        return self.registry.is_pooling_model(self.architectures)
+
     @property
     def is_cross_encoder(self) -> bool:
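To see why the `getattr` defaults matter here, a small self-contained sketch; the config class below is a hypothetical stand-in for the `hf_text_config` of a text-less geospatial model, not a vLLM class:

```python
class FakeGeoSpatialTextConfig:
    """Hypothetical HF config for a model with no text vocabulary."""
    hidden_size = 1024   # present on most configs
    # note: no vocab_size attribute at all

cfg = FakeGeoSpatialTextConfig()

# Direct attribute access would raise for such models:
#   cfg.vocab_size  ->  AttributeError
# getattr with a default keeps the ModelConfig accessors total:
assert getattr(cfg, "vocab_size", 0) == 0
assert getattr(cfg, "hidden_size", 0) == 1024
```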
@@ -25,14 +25,15 @@

 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import (IsAttentionFree,
-                                                   SupportsMultiModal,
-                                                   SupportsV0Only)
+from vllm.model_executor.models.interfaces import (
+    IsAttentionFree, SupportsMultiModalWithRawInput)
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs)
+                                    MultiModalFieldElem, MultiModalInputs,
+                                    MultiModalKwargs, MultiModalKwargsItem,
+                                    MultiModalSharedField, PlaceholderRange)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptUpdate)

@@ -62,8 +63,9 @@ def get_dummy_mm_data(
         # The size of pixel_values might change in the cases where we resize
         # the input but never exceeds the dimensions below.
         return {
-            "pixel_values": torch.full((1, 6, 512, 512), 1.0),
-            "location_coords": torch.full((1, 2), 1.0),
+            "pixel_values": torch.full((6, 512, 512), 1.0,
+                                       dtype=torch.float16),
+            "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
         }

@@ -75,8 +77,10 @@ def _get_mm_fields_config(
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            location_coords=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.shared(batch_size=1,
+                                                      modality="image"),
+            location_coords=MultiModalFieldConfig.shared(batch_size=1,
+                                                         modality="image"),
         )

     def _get_prompt_updates(
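The switch from `batched` to `shared` fields reflects the data layout: a batched field is sliced per item along dim 0, while a shared field reuses one tensor for every item in the batch. A rough shape-level illustration with plain tensors (not vLLM field objects):

```python
import torch

# batched("image"): dim 0 indexes items, so item i gets pixel_values[i].
batched_pixel_values = torch.zeros(4, 6, 512, 512)   # 4 items

# shared(batch_size=1, modality="image"): a single tensor serves the whole
# batch, so no leading item dimension is stored.
shared_pixel_values = torch.zeros(6, 512, 512)
```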
@@ -99,23 +103,34 @@ def apply(

         for k, v in mm_data.items():
             mm_kwargs[k] = v
+        mm_place_holders = {"image": [PlaceholderRange(offset=0, length=0)]}
+
+        multimodal_kwargs_items = [
+            MultiModalKwargsItem.from_elems([
+                MultiModalFieldElem(modality="image",
+                                    key=key,
+                                    data=data,
+                                    field=MultiModalSharedField(1))
+                for key, data in mm_kwargs.items()
+            ])
+        ]

         return MultiModalInputs(
             type="multimodal",
             prompt=prompt,
             prompt_token_ids=[1],
-            mm_kwargs=MultiModalKwargs(mm_kwargs),
+            mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
             mm_hashes=None,
-            mm_placeholders={},
+            mm_placeholders=mm_place_holders,
         )


 @MULTIMODAL_REGISTRY.register_processor(
     PrithviGeoSpatialMAEMultiModalProcessor,
     info=PrithviGeoSpatialMAEProcessingInfo,
     dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
-                           SupportsV0Only):
+class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
+                           SupportsMultiModalWithRawInput):
     """ Prithvi Masked Autoencoder"""

     @classmethod

@@ -169,7 +184,7 @@ def _parse_and_validate_multimodal_data(
         if not isinstance(pixel_values, torch.Tensor):
             raise ValueError(f"Incorrect type of pixel_values. "
                              f"Got type: {type(pixel_values)}")
-        pixel_values = torch.unbind(pixel_values, dim=0)[0]
+        # pixel_values = torch.unbind(pixel_values, dim=0)[0]

         location_coords = kwargs.pop("location_coords", None)
         if not isinstance(location_coords, torch.Tensor):

@@ -181,6 +196,13 @@ def _parse_and_validate_multimodal_data(

         return pixel_values, location_coords

+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        # We do not really use any input tokens and therefore no embeddings
+        # to be calculated. However, due to the mandatory token ids in
+        # the input prompt we pass one token and the size of the dummy
+        # embedding tensors must reflect that.
+        return torch.empty(input_ids.shape)
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],

@@ -202,7 +224,10 @@ def pooler(
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Optional[PoolerOutput]:
-        return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
+        return PoolerOutput([
+            PoolingSequenceGroupOutput(hidden_state)
+            for hidden_state in hidden_states
+        ])

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
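The pooler change emits one `PoolingSequenceGroupOutput` per sequence instead of wrapping the whole batch tensor in a single output. Shape-wise, the iteration behaves like this (illustrative tensors only):

```python
import torch

batch_hidden = torch.randn(3, 1024)  # 3 pooled sequences in a batch

# Old behaviour: one output holding the full (3, 1024) tensor.
# New behaviour: three outputs, one (1024,) tensor per sequence.
per_sequence = [hs for hs in batch_hidden]
assert len(per_sequence) == 3
assert per_sequence[0].shape == (1024,)
```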
@@ -219,7 +219,9 @@ def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
         super().__init__(kv_cache_config, max_model_len, use_eagle,
                          enable_caching, caching_hash_fn,
                          enable_kv_cache_events)
-        self.verify_and_split_kv_cache_groups()
+        # attention free models are initialized with 0 kv_cache_groups
+        if len(self.kv_cache_config.kv_cache_groups) > 0:
+            self.verify_and_split_kv_cache_groups()

Review discussion:
- "I'm comfortable with adding another coordinator for 0 kv cache groups and re-implementing find_longest_cache_hit for it."
- "As I've observed more and more cases where the current [...]"
- "@christian-pinto I have introduced a [...]"
- "Hey @nopperl thanks for that. Your approach solves my issue too."

     def verify_and_split_kv_cache_groups(self) -> None:
         """
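To make the reviewers' alternative concrete, a purely hypothetical sketch of a dedicated coordinator for zero KV cache groups; the class name, constructor shape, and return convention are illustrative assumptions, not vLLM's actual implementation:

```python
class AttentionFreeKVCacheCoordinator:
    """Hypothetical coordinator for models with zero KV cache groups."""

    def __init__(self, kv_cache_config):
        assert len(kv_cache_config.kv_cache_groups) == 0
        self.kv_cache_config = kv_cache_config

    def find_longest_cache_hit(self, block_hashes, max_length):
        # With no KV cache there is never anything to reuse, so every
        # lookup is trivially a miss: empty block list, hit length 0.
        return [], 0
```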
Review comment: Consider using `torch.float16` instead of the string literal for specifying the dtype. This approach is more type-safe and avoids potential errors due to typos or inconsistencies.
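A brief illustration of the type-safety point, in plain PyTorch:

```python
import torch

# A typo in a string dtype is just another string; it only fails later,
# when downstream code tries to parse it:
dtype_str = "flaot16"  # silently accepted at assignment time

# A typo on the dtype object fails immediately at attribute lookup:
#   torch.flaot16  ->  AttributeError
dtype = torch.float16  # checked as soon as this line runs
tensor = torch.full((1, 2), 1.0, dtype=dtype)
```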