
Commit e59d7dc

Rebased to master
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
1 parent 137ec29 commit e59d7dc

10 files changed, +215 -283 lines changed


examples/offline_inference/prithvi_geospatial_mae.py

Lines changed: 189 additions & 248 deletions
Large diffs are not rendered by default.

tests/models/multimodal/pooling/test_prithvi_mae.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ def _run_test(
 
 MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
 
-
+@pytest.mark.core_model
 @pytest.mark.parametrize("model", MODELS)
 def test_models_image(
     hf_runner,

vllm/config.py

Lines changed: 0 additions & 4 deletions
@@ -1514,10 +1514,6 @@ def uses_mrope(self) -> bool:
     @property
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None
-
-    @property
-    def is_pooling_model(self) -> bool:
-        return self.registry.is_pooling_model(self.architectures)
 
     @property
     def is_cross_encoder(self) -> bool:
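
This hunk drops the is_pooling_model property (previously a registry lookup on the model architectures); the gpu_model_runner.py change further down replaces its remaining call site with a direct pooler_config check. A minimal sketch of that check, using stand-in config classes rather than the real vLLM ones:

# Hedged sketch only: illustrates the pooler_config-based check that replaces
# the removed ModelConfig.is_pooling_model property. FakeModelConfig and
# FakePoolerConfig are stand-ins, not the real vLLM classes.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakePoolerConfig:
    pooling_type: str = "CLS"


@dataclass
class FakeModelConfig:
    pooler_config: Optional[FakePoolerConfig] = None


def is_pooling_model(model_config: FakeModelConfig) -> bool:
    # Same test the runner now performs inline:
    # model_config.pooler_config is not None
    return model_config.pooler_config is not None


assert is_pooling_model(FakeModelConfig(FakePoolerConfig()))
assert not is_pooling_model(FakeModelConfig())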

vllm/model_executor/models/prithvi_geospatial_mae.py

Lines changed: 1 addition & 2 deletions
@@ -184,7 +184,6 @@ def _parse_and_validate_multimodal_data(
         if not isinstance(pixel_values, torch.Tensor):
             raise ValueError(f"Incorrect type of pixel_values. "
                              f"Got type: {type(pixel_values)}")
-        # pixel_values = torch.unbind(pixel_values, dim=0)[0]
 
         location_coords = kwargs.pop("location_coords", None)
         if not isinstance(location_coords, torch.Tensor):
@@ -201,7 +200,7 @@ def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         # to be calculated. However, due to the mandatory token ids in
         # the input prompt we pass one token and the size of the dummy
         # embedding tensors must reflect that.
-        return torch.empty(input_ids.shape)
+        return torch.empty((input_ids.shape[0], 0))
 
     def forward(
         self,
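
The comment in the second hunk explains that only the mandatory token ids are passed through the prompt, so the dummy embedding returned here only needs to match the token dimension. Switching from torch.empty(input_ids.shape) to torch.empty((input_ids.shape[0], 0)) keeps that dimension while giving the placeholder zero width. A standalone sketch of the shape difference (not using the vLLM model class; the example input_ids values are illustrative):

# Standalone sketch of the shape change in get_input_embeddings.
import torch

# One mandatory token id per request (assumed values for illustration).
input_ids = torch.tensor([101, 101])

old_dummy = torch.empty(input_ids.shape)          # shape (2,): one scalar per token
new_dummy = torch.empty((input_ids.shape[0], 0))  # shape (2, 0): per-token rows of width 0

print(old_dummy.shape, old_dummy.numel())  # torch.Size([2]) 2
print(new_dummy.shape, new_dummy.numel())  # torch.Size([2, 0]) 0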

vllm/v1/core/kv_cache_manager.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import Optional
@@ -66,6 +65,7 @@ def new_empty(self) -> "KVCacheBlocks":
 
 
 class KVCacheManager:
+
     def __init__(
         self,
         kv_cache_config: KVCacheConfig,

vllm/v1/core/sched/scheduler.py

Lines changed: 2 additions & 2 deletions
@@ -488,8 +488,8 @@ def schedule(self) -> SchedulerOutput:
 
             if self.lora_config and request.lora_request:
                 scheduled_loras.add(request.lora_request.lora_int_id)
-            req_to_new_block_ids[request.request_id] = \
-                self.kv_cache_manager.get_block_ids(request.request_id)
+            req_to_new_block_ids[request.request_id] = (
+                self.kv_cache_manager.get_block_ids(request.request_id))
             num_scheduled_tokens[request.request_id] = num_new_tokens
             token_budget -= num_new_tokens
             request.status = RequestStatus.RUNNING

vllm/v1/engine/core.py

Lines changed: 2 additions & 2 deletions
@@ -152,8 +152,8 @@ def _initialize_kv_caches(
         kv_cache_configs = [
             get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                 available_gpu_memory_one_worker)
-            for kv_cache_spec_one_worker, available_gpu_memory_one_worker
-            in zip(kv_cache_specs, available_gpu_memory)
+            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
+            zip(kv_cache_specs, available_gpu_memory)
         ]
 
         # Since we use a shared centralized controller, we need the

vllm/v1/engine/output_processor.py

Lines changed: 8 additions & 7 deletions
@@ -330,13 +330,14 @@ def add_request(
         tokenizer = None if not self.tokenizer else \
             self.tokenizer.get_lora_tokenizer(request.lora_request)
 
-        req_state = RequestState.from_new_request(tokenizer=tokenizer,
-                                                  request=request,
-                                                  prompt=prompt,
-                                                  parent_req=parent_req,
-                                                  request_index=request_index,
-                                                  queue=queue,
-                                                  log_stats=self.log_stats)
+        req_state = RequestState.from_new_request(
+            tokenizer=tokenizer,
+            request=request,
+            prompt=prompt,
+            parent_req=parent_req,
+            request_index=request_index,
+            queue=queue,
+            log_stats=self.log_stats)
         self.request_states[request_id] = req_state
         self.lora_states.add_request(req_state)
         if parent_req:

vllm/v1/engine/processor.py

Lines changed: 4 additions & 6 deletions
@@ -384,6 +384,10 @@ def _validate_model_input(
             tokenizer = None
         else:
             tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
+            max_input_id = max(prompt_ids, default=0)
+            if max_input_id > tokenizer.max_token_id:
+                raise ValueError(
+                    f"Token id {max_input_id} is out of vocabulary")
 
         prompt_ids = prompt_inputs["prompt_token_ids"]
         if not prompt_ids:
@@ -392,12 +396,6 @@
         else:
             raise ValueError(f"The {prompt_type} prompt cannot be empty")
 
-        if tokenizer:
-            max_input_id = max(prompt_ids, default=0)
-            if max_input_id > tokenizer.max_token_id:
-                raise ValueError(
-                    f"Token id {max_input_id} is out of vocabulary")
-
        max_prompt_len = self.model_config.max_model_len
        if len(prompt_ids) > max_prompt_len:
            if prompt_type == "encoder" and model_config.is_multimodal_model:
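
The out-of-vocabulary check now runs right where the LoRA tokenizer is fetched instead of later under if tokenizer:. The validation itself is unchanged: the largest prompt token id must not exceed the tokenizer's max_token_id. A self-contained sketch of the same check with a stub tokenizer (not the real vLLM tokenizer classes):

# Sketch of the prompt-token validation; StubTokenizer is illustrative only.
from dataclasses import dataclass
from typing import List


@dataclass
class StubTokenizer:
    max_token_id: int


def validate_prompt_ids(prompt_ids: List[int], tokenizer: StubTokenizer) -> None:
    # max(..., default=0) keeps an empty prompt from raising here;
    # emptiness is rejected by a separate check.
    max_input_id = max(prompt_ids, default=0)
    if max_input_id > tokenizer.max_token_id:
        raise ValueError(f"Token id {max_input_id} is out of vocabulary")


tok = StubTokenizer(max_token_id=32_000)
validate_prompt_ids([1, 5, 31_999], tok)   # passes
try:
    validate_prompt_ids([1, 40_000], tok)  # raises ValueError
except ValueError as e:
    print(e)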

vllm/v1/worker/gpu_model_runner.py

Lines changed: 7 additions & 10 deletions
@@ -123,7 +123,7 @@ def __init__(
             cache_config.cache_dtype]
 
         self.is_multimodal_model = model_config.is_multimodal_model
-        self.is_pooling_model = model_config.is_pooling_model
+        self.is_pooling_model = model_config.pooler_config is not None
         self.model_supports_multimodal_raw_input = (
             model_config.model_supports_multimodal_raw_input)
         self.max_model_len = model_config.max_model_len
@@ -328,8 +328,6 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         Args:
             scheduler_output: The scheduler output.
         """
-
-        # nothing to be reordered when the mdoel is attention free
         if self.model_config.is_attention_free:
             return False
 
@@ -1059,14 +1057,13 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
            curr_group_outputs = self.model.get_multimodal_embeddings(
                **batched_mm_inputs)
 
-            if curr_group_outputs:
-                sanity_check_mm_encoder_outputs(
-                    curr_group_outputs,
-                    expected_num_items=len(grouped_mm_inputs),
-                )
+            sanity_check_mm_encoder_outputs(
+                curr_group_outputs,
+                expected_num_items=len(grouped_mm_inputs),
+            )
 
-                for output in curr_group_outputs:
-                    encoder_outputs.append(output)
+            for output in curr_group_outputs:
+                encoder_outputs.append(output)
 
         # Cache the encoder outputs.
         for (req_id, input_id, pos_info), output in zip(
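
With the if curr_group_outputs: guard removed, the sanity check runs even when the encoder returns nothing, so a mismatch against len(grouped_mm_inputs) is caught rather than silently skipped. A toy stand-in for that count check (the real sanity_check_mm_encoder_outputs lives in vLLM and may do more than this):

# Toy stand-in for the count check; not the real vLLM helper.
from typing import Sequence

import torch


def sanity_check_outputs(outputs: Sequence[torch.Tensor],
                         expected_num_items: int) -> None:
    if len(outputs) != expected_num_items:
        raise AssertionError(
            f"Expected {expected_num_items} multimodal encoder outputs, "
            f"got {len(outputs)}")


grouped_mm_inputs = [object(), object()]          # two scheduled mm items (assumption)
outputs = [torch.zeros(4, 8), torch.zeros(4, 8)]  # one embedding per item
sanity_check_outputs(outputs, expected_num_items=len(grouped_mm_inputs))

# An empty output list is no longer skipped: with two expected items it now fails.
try:
    sanity_check_outputs([], expected_num_items=len(grouped_mm_inputs))
except AssertionError as e:
    print(e)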
