from vllm.config import VllmConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (IsAttentionFree,
-                                                   SupportsMultiModal,
-                                                   SupportsV0Only)
+                                                   SupportsMultiModalWithRawInput)
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalFieldElem,
+                                    MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem,
+                                    MultiModalSharedField, PlaceholderRange)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptUpdate)
@@ -75,8 +75,8 @@ def _get_mm_fields_config(
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            location_coords=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.shared(batch_size=1, modality="image"),
+            location_coords=MultiModalFieldConfig.shared(batch_size=1, modality="image"),
        )

    def _get_prompt_updates(
@@ -99,23 +99,32 @@ def apply(

        for k, v in mm_data.items():
            mm_kwargs[k] = v
+        mm_place_holders = {
+            "image": [PlaceholderRange(offset=0, length=0)]
+        }
+
+        multimodal_kwargs_items = [
+            MultiModalKwargsItem.from_elems(
+                [MultiModalFieldElem(modality="image", key=key, data=data, field=MultiModalSharedField(1))
+                 for key, data in mm_kwargs.items()]
+            )
+        ]

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=[1],
-            mm_kwargs=MultiModalKwargs(mm_kwargs),
+            mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
            mm_hashes=None,
-            mm_placeholders={},
+            mm_placeholders=mm_place_holders,
        )


@MULTIMODAL_REGISTRY.register_processor(
    PrithviGeoSpatialMAEMultiModalProcessor,
    info=PrithviGeoSpatialMAEProcessingInfo,
    dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
-                           SupportsV0Only):
+class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModalWithRawInput):
    """Prithvi Masked Autoencoder"""

    @classmethod
@@ -180,7 +189,13 @@ def _parse_and_validate_multimodal_data(
            location_coords = None

        return pixel_values, location_coords
-
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        # We do not really use any input tokens and therefore no embeddings need to be calculated.
+        # However, due to the mandatory token ids in the input prompt we pass one token, and the
+        # size of the dummy embedding tensor must reflect that.
+        return torch.empty(input_ids.shape)
+
    def forward(
        self,
        input_ids: Optional[torch.Tensor],
@@ -202,7 +217,7 @@ def pooler(
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
-        return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
+        return PoolerOutput([PoolingSequenceGroupOutput(hidden_states[0])])

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
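
For context, a minimal offline sketch of how the reworked Prithvi pooling path can be driven through vLLM's LLM.encode API. This is not part of the commit: the checkpoint name, tensor shapes, and dtype below are illustrative assumptions loosely based on vLLM's Prithvi example script, and the exact input layout depends on the checkpoint used.

import torch
from vllm import LLM

# Assumed checkpoint id and raw-input layout; adjust to the actual Prithvi
# geospatial checkpoint being served.
llm = LLM(
    model="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
    skip_tokenizer_init=True,  # the model consumes raw tensors, not text
    dtype="float32",
)

prompt = {
    # One mandatory token id, mirroring prompt_token_ids=[1] set by the
    # multimodal processor in the diff above.
    "prompt_token_ids": [1],
    "multi_modal_data": {
        "pixel_values": torch.randn(1, 6, 512, 512),  # placeholder image stack
        "location_coords": torch.zeros(1, 2),         # placeholder lat/lon
    },
}

# encode() runs the pooling path; pooler() above now returns the first (and
# only) sequence's hidden states for the request.
outputs = llm.encode(prompt)
print(outputs[0])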