[Core][Model] PrithviMAE Enablement on vLLM v1 engine (superseded by PR 20577) #20072
Changes from all commits
953f66a
f174bbf
f3ab1fb
ff86cb0
769d8dd
9cc76b1
f8226da
c708266
0dba4cd
e0ae311
aab8956
3b7729e
New test file (all 50 lines added):

@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from ....conftest import VllmRunner
+
+
+def generate_test_mm_data():
+    mm_data = {
+        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
+        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
+    }
+    return mm_data
+
+
+def _run_test(
+    vllm_runner: type[VllmRunner],
+    model: str,
+) -> None:
+
+    mm_data = generate_test_mm_data()
+    prompt = {
+        # This model deals with no text input
+        "prompt_token_ids": [1],
+        "multi_modal_data": mm_data
+    }
+    with vllm_runner(model,
+                     task="embed",
+                     dtype=torch.float16,
+                     enforce_eager=True,
+                     skip_tokenizer_init=True) as vllm_model:
+        vllm_model.encode(prompt)
+
+
+MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+def test_models_image(
+    hf_runner,
+    vllm_runner,
+    image_assets,
+    model: str,
+) -> None:
+    _run_test(
+        vllm_runner,
+        model,
+    )
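For reference, the flow exercised by this test maps onto the offline entrypoint roughly as in the sketch below. This is an illustrative sketch rather than part of the PR: it assumes the standard vllm.LLM constructor accepts the same options the vllm_runner fixture forwards (task, dtype, enforce_eager, skip_tokenizer_init), and the tensor shapes simply mirror generate_test_mm_data above.

import torch
from vllm import LLM

# Load the Prithvi MAE checkpoint as an embedding ("pooling") model.
llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    task="embed",
    dtype="float16",
    enforce_eager=True,        # skip CUDA graph capture, as in the test
    skip_tokenizer_init=True,  # the model consumes no text
)

prompt = {
    # A single dummy token id is still required even though no text is used.
    "prompt_token_ids": [1],
    "multi_modal_data": {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    },
}

# encode() returns one pooling output per prompt.
outputs = llm.encode(prompt)
print(outputs[0])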
Changes to the PrithviGeoSpatialMAE model implementation:

@@ -25,14 +25,15 @@

 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import (IsAttentionFree,
-                                                    SupportsMultiModal,
-                                                    SupportsV0Only)
+from vllm.model_executor.models.interfaces import (
+    IsAttentionFree, SupportsMultiModalWithRawInput)
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs)
+                                    MultiModalFieldElem, MultiModalInputs,
+                                    MultiModalKwargs, MultiModalKwargsItem,
+                                    MultiModalSharedField, PlaceholderRange)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptUpdate)
@@ -62,8 +63,9 @@ def get_dummy_mm_data(
         # The size of pixel_values might change in the cases where we resize
         # the input but never exceeds the dimensions below.
         return {
-            "pixel_values": torch.full((1, 6, 512, 512), 1.0),
-            "location_coords": torch.full((1, 2), 1.0),
+            "pixel_values": torch.full((6, 512, 512), 1.0,
+                                       dtype=torch.float16),
+            "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
         }

@@ -75,8 +77,10 @@ def _get_mm_fields_config(
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            location_coords=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.shared(batch_size=1,
+                                                      modality="image"),
+            location_coords=MultiModalFieldConfig.shared(batch_size=1,
+                                                         modality="image"),
         )

     def _get_prompt_updates(
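The switch from MultiModalFieldConfig.batched to MultiModalFieldConfig.shared changes how the processor slices the tensors into per-item batch elements: a batched field is indexed along its first dimension, while a shared field hands the whole tensor to every element of the (here size-1) batch. The following is a rough conceptual sketch of the two splitting rules, not the actual vLLM implementation; the helper names are made up for illustration.

import torch

def split_batched(tensor: torch.Tensor) -> list[torch.Tensor]:
    # "batched": dim 0 indexes items, so a (6, 512, 512) pixel_values tensor
    # would be split into six separate "images" -- not what Prithvi expects.
    return list(tensor.unbind(dim=0))

def split_shared(tensor: torch.Tensor, batch_size: int = 1) -> list[torch.Tensor]:
    # "shared": the whole tensor belongs to every item in the batch, so the
    # raw (6, 512, 512) input stays intact for the single request.
    return [tensor] * batch_size

pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
print(len(split_batched(pixel_values)))  # 6 slices of shape (512, 512)
print(len(split_shared(pixel_values)))   # 1 tensor of shape (6, 512, 512)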
@@ -99,23 +103,34 @@ def apply(

         for k, v in mm_data.items():
             mm_kwargs[k] = v
+        mm_place_holders = {"image": [PlaceholderRange(offset=0, length=0)]}
+
+        multimodal_kwargs_items = [
+            MultiModalKwargsItem.from_elems([
+                MultiModalFieldElem(modality="image",
+                                    key=key,
+                                    data=data,
+                                    field=MultiModalSharedField(1))
+                for key, data in mm_kwargs.items()
+            ])
+        ]

         return MultiModalInputs(
             type="multimodal",
             prompt=prompt,
             prompt_token_ids=[1],
-            mm_kwargs=MultiModalKwargs(mm_kwargs),
+            mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
             mm_hashes=None,
-            mm_placeholders={},
+            mm_placeholders=mm_place_holders,
         )


 @MULTIMODAL_REGISTRY.register_processor(
     PrithviGeoSpatialMAEMultiModalProcessor,
     info=PrithviGeoSpatialMAEProcessingInfo,
     dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
-                           SupportsV0Only):
+class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
+                           SupportsMultiModalWithRawInput):
     """ Prithvi Masked Autoencoder"""

     @classmethod
@@ -169,7 +184,7 @@ def _parse_and_validate_multimodal_data(
         if not isinstance(pixel_values, torch.Tensor):
             raise ValueError(f"Incorrect type of pixel_values. "
                              f"Got type: {type(pixel_values)}")
-        pixel_values = torch.unbind(pixel_values, dim=0)[0]
+        # pixel_values = torch.unbind(pixel_values, dim=0)[0]

         location_coords = kwargs.pop("location_coords", None)
         if not isinstance(location_coords, torch.Tensor):
@@ -181,6 +196,13 @@ def _parse_and_validate_multimodal_data(

         return pixel_values, location_coords

+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        # We do not really use any input tokens and therefore no embeddings
+        # to be calculated. However, due to the mandatory token ids in
+        # the input prompt we pass one token and the size of the dummy
+        # embedding tensors must reflect that.
+        return torch.empty(input_ids.shape)
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],

Review comments:

> I don't see the

> I believe that memory profiling is only done while initializing the KVCache, which is skipped because the model is attention free. It should happen in

> Yes, but if we're adding a new programming model for MM inputs I think it would be nice to handle this. If not a full solution, perhaps a test to disable cuda graphs and memory profiling for this special case. But I'm thinking that maybe the raw MM inputs will be hard to generalize. Perhaps a hacky solution for this particular model would be to store the raw MM inputs in the
@@ -202,7 +224,10 @@ def pooler(
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Optional[PoolerOutput]:
-        return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
+        return PoolerOutput([
+            PoolingSequenceGroupOutput(hidden_state)
+            for hidden_state in hidden_states
+        ])

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
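The pooler change matters once several pooling requests are batched: hidden_states then carries one row per request, and the pooler should emit one PoolingSequenceGroupOutput per row instead of wrapping the whole stacked tensor in a single output. Below is a toy illustration with plain tensors, outside vLLM; the shapes are assumed for the example only.

import torch

# Pretend three embedding requests were batched; each pooled row has 1024 dims.
hidden_states = torch.randn(3, 1024)

# Old behaviour: one output object wrapping the entire batch.
old_style = [hidden_states]                # len == 1, single (3, 1024) tensor

# New behaviour: one output per request, as in the updated pooler().
new_style = list(hidden_states.unbind(0))  # len == 3, each of shape (1024,)

assert len(new_style) == hidden_states.shape[0]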