Skip to content

Commit f3ab1fb

Browse files
Support for attention-free models in V1
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
1 parent f174bbf commit f3ab1fb

File tree

1 file changed

+27
-12
lines changed

1 file changed

+27
-12
lines changed

vllm/model_executor/models/prithvi_geospatial_mae.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@
2626
from vllm.config import VllmConfig
2727
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
2828
from vllm.model_executor.models.interfaces import (IsAttentionFree,
29-
SupportsMultiModal,
30-
SupportsV0Only)
29+
SupportsMultiModalWithRawInput)
3130
from vllm.model_executor.models.utils import AutoWeightsLoader
3231
from vllm.model_executor.pooling_metadata import PoolingMetadata
3332
from vllm.multimodal import MULTIMODAL_REGISTRY
34-
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
35-
MultiModalInputs, MultiModalKwargs)
33+
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalFieldElem,
34+
MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem,
35+
MultiModalSharedField, PlaceholderRange)
3636
from vllm.multimodal.parse import MultiModalDataItems
3737
from vllm.multimodal.processing import (BaseMultiModalProcessor,
3838
BaseProcessingInfo, PromptUpdate)
@@ -75,8 +75,8 @@ def _get_mm_fields_config(
7575
hf_processor_mm_kwargs: Mapping[str, object],
7676
) -> Mapping[str, MultiModalFieldConfig]:
7777
return dict(
78-
pixel_values=MultiModalFieldConfig.batched("image"),
79-
location_coords=MultiModalFieldConfig.batched("image"),
78+
pixel_values=MultiModalFieldConfig.shared(batch_size=1, modality="image"),
79+
location_coords=MultiModalFieldConfig.shared(batch_size=1, modality="image"),
8080
)
8181

8282
def _get_prompt_updates(
@@ -99,23 +99,32 @@ def apply(
9999

100100
for k, v in mm_data.items():
101101
mm_kwargs[k] = v
102+
mm_place_holders = {
103+
"image": [PlaceholderRange(offset=0, length=0)]
104+
}
105+
106+
multimodal_kwargs_items = [
107+
MultiModalKwargsItem.from_elems(
108+
[MultiModalFieldElem(modality="image", key=key, data=data, field=MultiModalSharedField(1))
109+
for key, data in mm_kwargs.items()]
110+
)
111+
]
102112

103113
return MultiModalInputs(
104114
type="multimodal",
105115
prompt=prompt,
106116
prompt_token_ids=[1],
107-
mm_kwargs=MultiModalKwargs(mm_kwargs),
117+
mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
108118
mm_hashes=None,
109-
mm_placeholders={},
119+
mm_placeholders=mm_place_holders,
110120
)
111121

112122

113123
@MULTIMODAL_REGISTRY.register_processor(
114124
PrithviGeoSpatialMAEMultiModalProcessor,
115125
info=PrithviGeoSpatialMAEProcessingInfo,
116126
dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
117-
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
118-
SupportsV0Only):
127+
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModalWithRawInput):
119128
""" Prithvi Masked Autoencoder"""
120129

121130
@classmethod
@@ -180,7 +189,13 @@ def _parse_and_validate_multimodal_data(
180189
location_coords = None
181190

182191
return pixel_values, location_coords
183-
192+
193+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
194+
# We do not really use any input tokens and therefore no embeddings to be calculated
195+
# However, due to the mandatory token ids in the input prompt we pass one token and the
196+
# size of the dummy embedding tensors must reflect that.
197+
return torch.empty(input_ids.shape)
198+
184199
def forward(
185200
self,
186201
input_ids: Optional[torch.Tensor],
@@ -202,7 +217,7 @@ def pooler(
202217
hidden_states: torch.Tensor,
203218
pooling_metadata: PoolingMetadata,
204219
) -> Optional[PoolerOutput]:
205-
return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
220+
return PoolerOutput([PoolingSequenceGroupOutput(hidden_states[0])])
206221

207222
def load_weights(self, weights: Iterable[tuple[str,
208223
torch.Tensor]]) -> set[str]:

0 commit comments

Comments (0)