
Commit 9555f1b

Some reformatting to make the pre-commit hooks succeed
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
1 parent c9dea50 commit 9555f1b

File tree

11 files changed: +153 -84 lines changed


examples/offline_inference/prithvi_geospatial_mae.py

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ def __init__(self):
             model=os.path.join(os.path.dirname(__file__), "./model"),
             skip_tokenizer_init=True,
             dtype="float16",
-            enforce_eager=True
+            enforce_eager=True,
         )
 
     def run(self, input_data, location_coords):
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from ....conftest import VllmRunner
+
+def generate_test_mm_data():
+    mm_data = {
+        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
+        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
+    }
+    return mm_data
+
+def _run_test(
+    vllm_runner: type[VllmRunner],
+    model: str,
+) -> None:
+
+    mm_data = generate_test_mm_data()
+    prompt = {
+        # This model deals with no text input
+        "prompt_token_ids": [1],
+        "multi_modal_data": mm_data
+    }
+    with vllm_runner(model, task="embed",
+                     dtype=torch.float16,
+                     enforce_eager=True,
+                     skip_tokenizer_init=True) as vllm_model:
+        output = vllm_model.encode(prompt)
+
+MODELS=["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
+@pytest.mark.parametrize("model", MODELS)
+def test_models_image(
+    hf_runner,
+    vllm_runner,
+    image_assets,
+    model: str,
+) -> None:
+    _run_test(
+        vllm_runner,
+        model,
+    )
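
For context, here is a rough standalone sketch of what the new test exercises, using vLLM's offline LLM API instead of the pytest fixtures. The model name, prompt layout, and engine arguments mirror the test above and the example script touched in this commit; the exact LLM constructor keywords are an assumption, not something verified here.

import torch

from vllm import LLM

# Engine arguments mirror the vllm_runner call in the test above.
llm = LLM(model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
          task="embed",
          dtype="float16",
          enforce_eager=True,
          skip_tokenizer_init=True)

prompt = {
    # The model consumes no real text, so a single dummy token id is passed.
    "prompt_token_ids": [1],
    "multi_modal_data": {
        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
    },
}

outputs = llm.encode(prompt)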

vllm/config.py

Lines changed: 4 additions & 2 deletions
@@ -612,8 +612,10 @@ def __post_init__(self) -> None:
         self.served_model_name = get_served_model_name(self.model,
                                                        self.served_model_name)
         self.multimodal_config = self._init_multimodal_config()
-        self.is_pooling_model = self.registry.is_pooling_model(self.architectures)
-        self.model_supports_multimodal_raw_input = self._init_model_supports_multimodal_raw_input()
+        self.is_pooling_model = self.registry.is_pooling_model(
+            self.architectures)
+        self.model_supports_multimodal_raw_input = (
+            self._init_model_supports_multimodal_raw_input())
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
 
vllm/model_executor/models/interfaces.py

Lines changed: 9 additions & 3 deletions
@@ -129,6 +129,7 @@ def supports_multimodal(
 
     return isinstance(model, SupportsMultiModal)
 
+
 @runtime_checkable
 class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
     """The interface required for all multi-modal models."""
@@ -143,29 +144,34 @@ class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
     MRO of your model class.
     """
 
+
 @runtime_checkable
 class _SupportsMultiModalWithRawInput(Protocol):
     supports_multimodal_raw_input: ClassVar[Literal[True]]
 
 
 @overload
-def supports_multimodal_raw_input(model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
+def supports_multimodal_raw_input(
+        model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
     ...
 
 
 @overload
-def supports_multimodal_raw_input(model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
+def supports_multimodal_raw_input(
+        model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
     ...
 
 
 def supports_multimodal_raw_input(
     model: Union[type[object], object]
-) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], TypeIs[SupportsMultiModalWithRawInput]]:
+) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
+           TypeIs[SupportsMultiModalWithRawInput]]:
     if isinstance(model, type):
         return isinstance(model, _SupportsMultiModalWithRawInput)
 
     return isinstance(model, SupportsMultiModalWithRawInput)
 
+
 @runtime_checkable
 class SupportsLoRA(Protocol):
     """The interface required for all models that support LoRA."""

vllm/model_executor/models/prithvi_geospatial_mae.py

Lines changed: 31 additions & 21 deletions
@@ -25,13 +25,14 @@
 
 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import (IsAttentionFree,
-                                                   SupportsMultiModalWithRawInput)
+from vllm.model_executor.models.interfaces import (
+    IsAttentionFree, SupportsMultiModalWithRawInput)
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalFieldElem,
-                                    MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem,
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalFieldElem, MultiModalInputs,
+                                    MultiModalKwargs, MultiModalKwargsItem,
                                     MultiModalSharedField, PlaceholderRange)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -62,7 +63,8 @@ def get_dummy_mm_data(
         # The size of pixel_values might change in the cases where we resize
         # the input but never exceeds the dimensions below.
         return {
-            "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
+            "pixel_values": torch.full((6, 512, 512), 1.0,
+                                       dtype=torch.float16),
             "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
         }
 
@@ -75,8 +77,10 @@ def _get_mm_fields_config(
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
-            pixel_values=MultiModalFieldConfig.shared(batch_size=1, modality="image"),
-            location_coords=MultiModalFieldConfig.shared(batch_size=1, modality="image"),
+            pixel_values=MultiModalFieldConfig.shared(batch_size=1,
+                                                      modality="image"),
+            location_coords=MultiModalFieldConfig.shared(batch_size=1,
+                                                         modality="image"),
         )
 
     def _get_prompt_updates(
@@ -99,15 +103,16 @@ def apply(
 
         for k, v in mm_data.items():
             mm_kwargs[k] = v
-        mm_place_holders = {
-            "image": [PlaceholderRange(offset=0, length=0)]
-        }
+        mm_place_holders = {"image": [PlaceholderRange(offset=0, length=0)]}
 
         multimodal_kwargs_items = [
-            MultiModalKwargsItem.from_elems(
-                [MultiModalFieldElem(modality="image", key=key, data=data, field=MultiModalSharedField(1))
-                 for key, data in mm_kwargs.items()]
-            )
+            MultiModalKwargsItem.from_elems([
+                MultiModalFieldElem(modality="image",
+                                    key=key,
+                                    data=data,
+                                    field=MultiModalSharedField(1))
+                for key, data in mm_kwargs.items()
+            ])
         ]
 
         return MultiModalInputs(
@@ -124,7 +129,8 @@ def apply(
     PrithviGeoSpatialMAEMultiModalProcessor,
     info=PrithviGeoSpatialMAEProcessingInfo,
     dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModalWithRawInput):
+class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
+                           SupportsMultiModalWithRawInput):
     """ Prithvi Masked Autoencoder"""
 
     @classmethod
@@ -189,13 +195,14 @@ def _parse_and_validate_multimodal_data(
             location_coords = None
 
         return pixel_values, location_coords
-
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        # We do not really use any input tokens and therefore no embeddings to be calculated
-        # However, due to the mandatory token ids in the input prompt we pass one token and the
-        # size of the dummy embedding tensors must reflect that.
+        # We do not really use any input tokens and therefore no embeddings
+        # to be calculated. However, due to the mandatory token ids in
+        # the input prompt we pass one token and the size of the dummy
+        # embedding tensors must reflect that.
         return torch.empty(input_ids.shape)
-
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -217,7 +224,10 @@ def pooler(
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Optional[PoolerOutput]:
-        return PoolerOutput([PoolingSequenceGroupOutput(hidden_state) for hidden_state in hidden_states])
+        return PoolerOutput([
+            PoolingSequenceGroupOutput(hidden_state)
+            for hidden_state in hidden_states
+        ])
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:

vllm/model_executor/models/registry.py

Lines changed: 2 additions & 3 deletions
@@ -24,8 +24,7 @@
 from .interfaces import (has_inner_state, has_noops, is_attention_free,
                          is_hybrid, supports_cross_encoding,
                          supports_multimodal, supports_multimodal_raw_input,
-                         supports_pp, supports_transcription,
-                         supports_v0_only)
+                         supports_pp, supports_transcription, supports_v0_only)
 from .interfaces_base import is_text_generation_model
 
 logger = init_logger(__name__)
@@ -530,7 +529,7 @@ def is_multimodal_model(
     ) -> bool:
         model_cls, _ = self.inspect_model_cls(architectures)
         return model_cls.supports_multimodal
-
+
     def supports_multimodal_raw_input(
         self,
         architectures: Union[str, list[str]],

vllm/v1/engine/core.py

Lines changed: 17 additions & 14 deletions
@@ -140,24 +140,27 @@ def _initialize_kv_caches(
             # is attention free.
             kv_cache_specs = []
             kv_cache_configs = [
-            KVCacheConfig(num_blocks=0, kv_cache_tensors={}, kv_cache_groups=[])
-        ]
+                KVCacheConfig(num_blocks=0,
+                              kv_cache_tensors={},
+                              kv_cache_groups=[])
+            ]
         else:
             # Get all kv cache needed by the model
             kv_cache_specs = self.model_executor.get_kv_cache_specs()
 
-        # Profiles the peak memory usage of the model to determine how much
-        # memory can be allocated for kv cache.
-        available_gpu_memory = self.model_executor.determine_available_memory()
-
-        assert len(kv_cache_specs) == len(available_gpu_memory)
-        # Get the kv cache tensor size
-        kv_cache_configs = [
-            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
-                                available_gpu_memory_one_worker)
-            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
-            zip(kv_cache_specs, available_gpu_memory)
-        ]
+            # Profiles the peak memory usage of the model to determine how much
+            # memory can be allocated for kv cache.
+            available_gpu_memory = (
+                self.model_executor.determine_available_memory())
+
+            assert len(kv_cache_specs) == len(available_gpu_memory)
+            # Get the kv cache tensor size
+            kv_cache_configs = [
+                get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
+                                    available_gpu_memory_one_worker)
+                for kv_cache_spec_one_worker, available_gpu_memory_one_worker
+                in zip(kv_cache_specs, available_gpu_memory)
+            ]
 
         # Since we use a shared centralized controller, we need the
         # `kv_cache_config` to be consistent across all workers to make sure

vllm/v1/engine/llm_engine.py

Lines changed: 1 addition & 2 deletions
@@ -82,8 +82,7 @@ def __init__(
         self.dp_group = None
         self.should_execute_dummy_batch = False
 
-
-        if not self.vllm_config.model_config.skip_tokenizer_init:
+        if not self.vllm_config.model_config.skip_tokenizer_init:
             # Tokenizer (+ ensure liveness if running in another process).
             self.tokenizer = init_tokenizer_from_configs(
                 model_config=vllm_config.model_config,

vllm/v1/engine/output_processor.py

Lines changed: 7 additions & 8 deletions
@@ -330,14 +330,13 @@ def add_request(
         tokenizer = None if not self.tokenizer else \
             self.tokenizer.get_lora_tokenizer(request.lora_request)
 
-        req_state = RequestState.from_new_request(
-            tokenizer=tokenizer,
-            request=request,
-            prompt=prompt,
-            parent_req=parent_req,
-            request_index=request_index,
-            queue=queue,
-            log_stats=self.log_stats)
+        req_state = RequestState.from_new_request(tokenizer=tokenizer,
+                                                  request=request,
+                                                  prompt=prompt,
+                                                  parent_req=parent_req,
+                                                  request_index=request_index,
+                                                  queue=queue,
+                                                  log_stats=self.log_stats)
         self.request_states[request_id] = req_state
         self.lora_states.add_request(req_state)
         if parent_req:

vllm/v1/engine/processor.py

Lines changed: 2 additions & 1 deletion
@@ -390,7 +390,8 @@ def _validate_model_input(
         if tokenizer:
             max_input_id = max(prompt_ids, default=0)
             if max_input_id > tokenizer.max_token_id:
-                raise ValueError(f"Token id {max_input_id} is out of vocabulary")
+                raise ValueError(
+                    f"Token id {max_input_id} is out of vocabulary")
 
         max_prompt_len = self.model_config.max_model_len
         if len(prompt_ids) > max_prompt_len:
