Commit 7392c45

Last few changes after rebasing to latest branch version
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
1 parent 146b3b2 commit 7392c45

File tree

2 files changed: 29 additions & 9 deletions

vllm/model_executor/models/prithvi_geospatial_mae.py

Lines changed: 2 additions & 2 deletions
@@ -62,8 +62,8 @@ def get_dummy_mm_data(
         # The size of pixel_values might change in the cases where we resize
         # the input but never exceeds the dimensions below.
         return {
-            "pixel_values": torch.full((1, 6, 512, 512), 1.0),
-            "location_coords": torch.full((1, 2), 1.0),
+            "pixel_values": torch.full((1, 6, 512, 512), 1.0, dtype=torch.float16),
+            "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
         }
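
A minimal sketch of what the dtype change buys, assuming the Prithvi geospatial MAE weights are loaded in float16 (the dtype is an assumption here, not stated in the diff): the dummy multimodal tensors used for memory profiling now carry the model dtype, so the profiling forward pass is not fed float32 inputs.

import torch

# Hedged sketch: dummy profiling inputs built in an assumed float16 model dtype.
dummy_mm_data = {
    "pixel_values": torch.full((1, 6, 512, 512), 1.0, dtype=torch.float16),
    "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
}
assert all(t.dtype == torch.float16 for t in dummy_mm_data.values())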

vllm/v1/worker/gpu_model_runner.py

Lines changed: 27 additions & 7 deletions
@@ -123,7 +123,7 @@ def __init__(
             cache_config.cache_dtype]

         self.is_multimodal_model = model_config.is_multimodal_model
-        self.is_pooling_model = model_config.pooler_config is not None
+        self.model_supports_multimodal_raw_input = model_config.model_supports_multimodal_raw_input
         self.max_model_len = model_config.max_model_len
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs
@@ -326,6 +326,11 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         Args:
             scheduler_output: The scheduler output.
         """
+
+        # nothing to be reordered when the model is attention free
+        if self.model_config.is_attention_free:
+            return False
+
         self.attn_metadata_builders[0].reorder_batch(self.input_batch,
                                                      scheduler_output)
@@ -1019,13 +1024,14 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
             curr_group_outputs = self.model.get_multimodal_embeddings(
                 **batched_mm_inputs)

-            sanity_check_mm_encoder_outputs(
-                curr_group_outputs,
-                expected_num_items=len(grouped_mm_inputs),
-            )
+            if curr_group_outputs:
+                sanity_check_mm_encoder_outputs(
+                    curr_group_outputs,
+                    expected_num_items=len(grouped_mm_inputs),
+                )

-            for output in curr_group_outputs:
-                encoder_outputs.append(output)
+                for output in curr_group_outputs:
+                    encoder_outputs.append(output)

         # Cache the encoder outputs.
         for (req_id, input_id, pos_info), output in zip(
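
A hedged illustration of the guard above: a model that consumes raw multimodal input (such as the Prithvi geospatial MAE) may return nothing from get_multimodal_embeddings, and in that case both the sanity check and the append loop are skipped instead of failing on an empty result.

# Sketch only: an encoder that returns no embeddings for the current group.
curr_group_outputs: list = []
encoder_outputs: list = []
if curr_group_outputs:  # empty list is falsy, so nothing below runs
    # sanity_check_mm_encoder_outputs(curr_group_outputs, expected_num_items=...)
    encoder_outputs.extend(curr_group_outputs)
assert encoder_outputs == []
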
@@ -1324,6 +1330,9 @@ def execute_model(
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
             input_ids = self.input_ids[:num_scheduled_tokens]
+            self._maybe_add_model_args(num_scheduled_tokens,
+                                       model_kwargs, scheduler_output)
+
             if mm_embeds:
                 inputs_embeds = self.model.get_input_embeddings(
                     input_ids, mm_embeds)
@@ -1339,6 +1348,7 @@ def execute_model(
             # multimodal models, it is not desirable for performance since
             # then the embedding layer is not included in the CUDA graph.
             input_ids = self.input_ids[:num_input_tokens]
+            self._maybe_add_model_args(num_input_tokens, model_kwargs, scheduler_output)
             inputs_embeds = None
         if self.uses_mrope:
             positions = self.mrope_positions[:, :num_input_tokens]
@@ -1372,6 +1382,10 @@ def execute_model(
                 positions=positions,
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
+                **MultiModalKwargs.as_kwargs(
+                    model_kwargs,
+                    device=self.device,
+                )
             )

         self.maybe_wait_for_kv_save()
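
For context, MultiModalKwargs.as_kwargs (from vllm.multimodal) is what moves the extra per-request tensors in model_kwargs onto the runner's device before they are splatted into the forward call. A hedged sketch of the call pattern, with a made-up tensor name and device:

import torch
from vllm.multimodal import MultiModalKwargs

# Hypothetical extra input for a model that accepts raw multimodal data.
model_kwargs = {"pixel_values": torch.zeros(1, 6, 512, 512)}

# Assumed behaviour: returns the same mapping with its tensors on the target device.
forward_kwargs = MultiModalKwargs.as_kwargs(model_kwargs, device="cuda:0")
# model_output = self.model(input_ids=..., positions=..., inputs_embeds=...,
#                           **forward_kwargs)
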
@@ -1998,6 +2012,8 @@ def _dummy_run(
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
             model = self.model
+            model_kwargs: dict[str, Any] = {}
+            self._maybe_add_model_args(num_tokens, model_kwargs)
             if self.is_multimodal_model:
                 input_ids = None
                 inputs_embeds = self.inputs_embeds[:num_tokens]
@@ -2032,7 +2048,11 @@ def _dummy_run(
                 positions=positions,
                 intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
+                **MultiModalKwargs.as_kwargs(
+                    model_kwargs,
+                    device=self.device)
             )
+
             if self.use_aux_hidden_state_outputs:
                 hidden_states, _ = outputs
             else:
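
Taken together with the execute_model hunks, the dummy-run path follows the same pattern: start from an empty model_kwargs dict, let _maybe_add_model_args (introduced elsewhere on this branch; its body is not part of this commit) populate it only for models that take raw multimodal input, then splat the device-moved result into the forward call. A condensed, hedged sketch with a stand-in helper:

from typing import Any
import torch

def maybe_add_model_args_stub(model_kwargs: dict[str, Any],
                              supports_raw_input: bool) -> None:
    # Stand-in for the real helper: populate the dict only when the model
    # consumes raw multimodal input; otherwise leave it empty.
    if supports_raw_input:
        model_kwargs["pixel_values"] = torch.zeros(1, 6, 512, 512)

model_kwargs: dict[str, Any] = {}
maybe_add_model_args_stub(model_kwargs, supports_raw_input=True)
# outputs = model(input_ids=input_ids, positions=positions,
#                 inputs_embeds=inputs_embeds,
#                 **MultiModalKwargs.as_kwargs(model_kwargs, device=device))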
