 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import (IsAttentionFree,
-                                                    SupportsMultiModal,
-                                                    SupportsV0Only)
+                                                    SupportsMultiModalWithRawInput)
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalFieldElem,
+                                    MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem,
+                                    MultiModalSharedField, PlaceholderRange)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptUpdate)
@@ -75,8 +75,8 @@ def _get_mm_fields_config(
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
-            location_coords=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.shared(batch_size=1, modality="image"),
+            location_coords=MultiModalFieldConfig.shared(batch_size=1, modality="image"),
         )
 
     def _get_prompt_updates(
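Note on the hunk above: a minimal sketch of the two field-config constructors involved in this change, using only calls that already appear in the diff. The comments give our reading of their intent rather than an authoritative description of the API.

```python
# Sketch only: both constructors below are used in this diff; the comments are
# an interpretation of their intent, not a definitive API contract.
from vllm.multimodal.inputs import MultiModalFieldConfig

# Previous config: treat dim 0 of each tensor as a batch of independent image
# items, one slice per item.
batched_cfg = MultiModalFieldConfig.batched("image")

# New config: keep each raw tensor whole and attach it to a single item
# (batch_size=1), so the tensor reaches the model unsliced.
shared_cfg = MultiModalFieldConfig.shared(batch_size=1, modality="image")
```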
@@ -98,23 +98,32 @@ def apply(
 
         for k, v in mm_data.items():
             mm_kwargs[k] = v
+        mm_place_holders = {
+            "image": [PlaceholderRange(offset=0, length=0)]
+        }
+
+        multimodal_kwargs_items = [
+            MultiModalKwargsItem.from_elems(
+                [MultiModalFieldElem(modality="image", key=key, data=data, field=MultiModalSharedField(1))
+                 for key, data in mm_kwargs.items()]
+            )
+        ]
 
         return MultiModalInputs(
             type="multimodal",
             prompt=prompt,
             prompt_token_ids=[1],
-            mm_kwargs=MultiModalKwargs(mm_kwargs),
+            mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
             mm_hashes=None,
-            mm_placeholders={},
+            mm_placeholders=mm_place_holders,
         )
 
 
 @MULTIMODAL_REGISTRY.register_processor(
     PrithviGeoSpatialMAEMultiModalProcessor,
     info=PrithviGeoSpatialMAEProcessingInfo,
     dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
-                           SupportsV0Only):
+class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModalWithRawInput):
     """Prithvi Masked Autoencoder"""
 
     def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
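For reference, a standalone sketch of the packaging step introduced in the `apply()` hunk above, built only from classes imported in this diff; the tensor shapes are illustrative placeholders, not the real Prithvi input shapes.

```python
import torch

from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
                                    MultiModalKwargsItem,
                                    MultiModalSharedField, PlaceholderRange)

# Illustrative raw inputs; the real shapes come from the Prithvi preprocessing.
mm_kwargs = {
    "pixel_values": torch.zeros(1, 6, 512, 512),
    "location_coords": torch.zeros(1, 2),
}

# Wrap each raw tensor as a field element of a single "image" item.
# MultiModalSharedField(1) marks the tensor as shared by a batch of size 1,
# mirroring the MultiModalFieldConfig.shared(batch_size=1) config above.
item = MultiModalKwargsItem.from_elems([
    MultiModalFieldElem(modality="image",
                        key=key,
                        data=data,
                        field=MultiModalSharedField(1))
    for key, data in mm_kwargs.items()
])

# This is what the revised apply() hands to MultiModalInputs via mm_kwargs.
packed = MultiModalKwargs.from_items([item])

# The placeholder range is effectively empty: the dummy prompt holds a single
# token that does not correspond to any real multimodal positions.
placeholders = {"image": [PlaceholderRange(offset=0, length=0)]}
```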
@@ -172,7 +181,13 @@ def _parse_and_validate_multimodal_data(
         location_coords = None
 
         return pixel_values, location_coords
-
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        # We do not actually use any input tokens, so there are no embeddings
+        # to compute. However, the prompt must carry at least one token id, so
+        # we pass a single dummy token and the returned tensor matches that shape.
+        return torch.empty(input_ids.shape)
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -194,7 +209,7 @@ def pooler(
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Optional[PoolerOutput]:
-        return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])
+        return PoolerOutput([PoolingSequenceGroupOutput(hidden_states[0])])
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
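Finally, a tiny sketch of the dummy-token plumbing: `apply()` emits `prompt_token_ids=[1]`, so `get_input_embeddings()` only has to return a tensor matching that single-token shape; the values are never read because the encoder consumes the raw pixel data.

```python
import torch

# The dummy prompt produced by apply() carries exactly one token id.
input_ids = torch.tensor([1])

# get_input_embeddings() mirrors the shape only; the contents are unused since
# the Prithvi encoder works directly on pixel_values / location_coords.
dummy_embeddings = torch.empty(input_ids.shape)
assert dummy_embeddings.shape == (1,)
```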