Skip to content

Commit ce227da

Browse files
Support for attention-free models in V1
1 parent ffd803c commit ce227da

File tree

1 file changed

+27
-12
lines changed

1 file changed

+27
-12
lines changed

vllm/model_executor/models/prithvi_geospatial_mae.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@
2626
from vllm.config import VllmConfig
2727
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
2828
from vllm.model_executor.models.interfaces import (IsAttentionFree,
29-
SupportsMultiModal,
30-
SupportsV0Only)
29+
SupportsMultiModalWithRawInput)
3130
from vllm.model_executor.models.utils import AutoWeightsLoader
3231
from vllm.model_executor.pooling_metadata import PoolingMetadata
3332
from vllm.multimodal import MULTIMODAL_REGISTRY
34-
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
35-
MultiModalInputs, MultiModalKwargs)
33+
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalFieldElem,
34+
MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem,
35+
MultiModalSharedField, PlaceholderRange)
3636
from vllm.multimodal.parse import MultiModalDataItems
3737
from vllm.multimodal.processing import (BaseMultiModalProcessor,
3838
BaseProcessingInfo, PromptUpdate)
@@ -75,8 +75,8 @@ def _get_mm_fields_config(
7575
hf_processor_mm_kwargs: Mapping[str, object],
7676
) -> Mapping[str, MultiModalFieldConfig]:
7777
return dict(
78-
pixel_values=MultiModalFieldConfig.batched("image"),
79-
location_coords=MultiModalFieldConfig.batched("image"),
78+
pixel_values=MultiModalFieldConfig.shared(batch_size=1, modality="image"),
79+
location_coords=MultiModalFieldConfig.shared(batch_size=1, modality="image"),
8080
)
8181

8282
def _get_prompt_updates(
@@ -98,23 +98,32 @@ def apply(
9898

9999
for k, v in mm_data.items():
100100
mm_kwargs[k] = v
101+
mm_place_holders = {
102+
"image": [PlaceholderRange(offset=0, length=0)]
103+
}
104+
105+
multimodal_kwargs_items = [
106+
MultiModalKwargsItem.from_elems(
107+
[MultiModalFieldElem(modality="image", key=key, data=data, field=MultiModalSharedField(1))
108+
for key, data in mm_kwargs.items()]
109+
)
110+
]
101111

102112
return MultiModalInputs(
103113
type="multimodal",
104114
prompt=prompt,
105115
prompt_token_ids=[1],
106-
mm_kwargs=MultiModalKwargs(mm_kwargs),
116+
mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
107117
mm_hashes=None,
108-
mm_placeholders={},
118+
mm_placeholders=mm_place_holders,
109119
)
110120

111121

112122
@MULTIMODAL_REGISTRY.register_processor(
113123
PrithviGeoSpatialMAEMultiModalProcessor,
114124
info=PrithviGeoSpatialMAEProcessingInfo,
115125
dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
116-
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
117-
SupportsV0Only):
126+
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModalWithRawInput):
118127
""" Prithvi Masked Autoencoder"""
119128

120129
def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
@@ -172,7 +181,13 @@ def _parse_and_validate_multimodal_data(
172181
location_coords = None
173182

174183
return pixel_values, location_coords
175-
184+
185+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
186+
# We do not actually use any input tokens, so there are no embeddings to calculate.
187+
# However, due to the mandatory token IDs in the input prompt, we pass one token, and the
188+
# size of the dummy embedding tensors must reflect that.
189+
return torch.empty(input_ids.shape)
190+
176191
def forward(
177192
self,
178193
input_ids: Optional[torch.Tensor],
@@ -194,7 +209,7 @@ def pooler(
194209
hidden_states: torch.Tensor,
195210
pooling_metadata: PoolingMetadata,
196211
) -> Optional[PoolerOutput]:
197-
return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
212+
return PoolerOutput([PoolingSequenceGroupOutput(hidden_states[0])])
198213

199214
def load_weights(self, weights: Iterable[tuple[str,
200215
torch.Tensor]]) -> set[str]:

0 commit comments

Comments
 (0)