
Commit c7f80c0

Some reformatting to make the pre-commit hooks succeed
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
1 parent b53a07d commit c7f80c0

File tree

14 files changed: +159 −80 lines changed


examples/offline_inference/prithvi_geospatial_mae.py

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ def __init__(self):
             model=os.path.join(os.path.dirname(__file__), "./model"),
             skip_tokenizer_init=True,
             dtype="float16",
-            enforce_eager=True
+            enforce_eager=True,
         )

     def run(self, input_data, location_coords):
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from ....conftest import VllmRunner
+
+def generate_test_mm_data():
+    mm_data = {
+        "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
+        "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
+    }
+    return mm_data
+
+def _run_test(
+    vllm_runner: type[VllmRunner],
+    model: str,
+) -> None:
+
+    mm_data = generate_test_mm_data()
+    prompt = {
+        # This model deals with no text input
+        "prompt_token_ids": [1],
+        "multi_modal_data": mm_data
+    }
+    with vllm_runner(model, task="embed",
+                     dtype=torch.float16,
+                     enforce_eager=True,
+                     skip_tokenizer_init=True) as vllm_model:
+        output = vllm_model.encode(prompt)
+
+MODELS=["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
+@pytest.mark.parametrize("model", MODELS)
+def test_models_image(
+    hf_runner,
+    vllm_runner,
+    image_assets,
+    model: str,
+) -> None:
+    _run_test(
+        vllm_runner,
+        model,
+    )
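
For orientation, the call pattern this new test exercises can also be reproduced outside the test harness through the public vllm.LLM API. The following is a minimal sketch, not part of the commit, assuming a vLLM build that accepts the task="embed", enforce_eager, and skip_tokenizer_init arguments used in the test above.

import torch
from vllm import LLM

# Dummy raw inputs matching the shapes used by generate_test_mm_data() above.
mm_data = {
    "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
    "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
}

llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    task="embed",
    dtype="float16",
    enforce_eager=True,
    skip_tokenizer_init=True,
)

# The model ignores text input, so a single dummy token id is passed.
outputs = llm.encode({
    "prompt_token_ids": [1],
    "multi_modal_data": mm_data,
})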

vllm/config.py

Lines changed: 4 additions & 2 deletions
@@ -614,8 +614,10 @@ def __post_init__(self) -> None:
         self.served_model_name = get_served_model_name(self.model,
                                                        self.served_model_name)
         self.multimodal_config = self._init_multimodal_config()
-        self.is_pooling_model = self.registry.is_pooling_model(self.architectures)
-        self.model_supports_multimodal_raw_input = self._init_model_supports_multimodal_raw_input()
+        self.is_pooling_model = self.registry.is_pooling_model(
+            self.architectures)
+        self.model_supports_multimodal_raw_input = (
+            self._init_model_supports_multimodal_raw_input())
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()

vllm/model_executor/models/interfaces.py

Lines changed: 9 additions & 3 deletions
@@ -120,6 +120,7 @@ def supports_multimodal(

     return isinstance(model, SupportsMultiModal)

+
 @runtime_checkable
 class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
     """The interface required for all multi-modal models."""

@@ -134,29 +135,34 @@ class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
     MRO of your model class.
     """

+
 @runtime_checkable
 class _SupportsMultiModalWithRawInput(Protocol):
     supports_multimodal_raw_input: ClassVar[Literal[True]]


 @overload
-def supports_multimodal_raw_input(model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
+def supports_multimodal_raw_input(
+        model: object) -> TypeIs[SupportsMultiModalWithRawInput]:
     ...


 @overload
-def supports_multimodal_raw_input(model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
+def supports_multimodal_raw_input(
+        model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]:
     ...


 def supports_multimodal_raw_input(
     model: Union[type[object], object]
-) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], TypeIs[SupportsMultiModalWithRawInput]]:
+) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]],
+           TypeIs[SupportsMultiModalWithRawInput]]:
     if isinstance(model, type):
         return isinstance(model, _SupportsMultiModalWithRawInput)

     return isinstance(model, SupportsMultiModalWithRawInput)

+
 @runtime_checkable
 class SupportsLoRA(Protocol):
     """The interface required for all models that support LoRA."""

vllm/model_executor/models/prithvi_geospatial_mae.py

Lines changed: 31 additions & 21 deletions
@@ -25,13 +25,14 @@

 from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import (IsAttentionFree,
-                                                   SupportsMultiModalWithRawInput)
+from vllm.model_executor.models.interfaces import (
+    IsAttentionFree, SupportsMultiModalWithRawInput)
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalFieldElem,
-                                    MultiModalInputs, MultiModalKwargs, MultiModalKwargsItem,
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalFieldElem, MultiModalInputs,
+                                    MultiModalKwargs, MultiModalKwargsItem,
                                     MultiModalSharedField, PlaceholderRange)
 from vllm.multimodal.parse import MultiModalDataItems
 from vllm.multimodal.processing import (BaseMultiModalProcessor,

@@ -62,7 +63,8 @@ def get_dummy_mm_data(
         # The size of pixel_values might change in the cases where we resize
         # the input but never exceeds the dimensions below.
         return {
-            "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
+            "pixel_values": torch.full((6, 512, 512), 1.0,
+                                       dtype=torch.float16),
             "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
         }

@@ -75,8 +77,10 @@ def _get_mm_fields_config(
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
-            pixel_values=MultiModalFieldConfig.shared(batch_size=1, modality="image"),
-            location_coords=MultiModalFieldConfig.shared(batch_size=1, modality="image"),
+            pixel_values=MultiModalFieldConfig.shared(batch_size=1,
+                                                      modality="image"),
+            location_coords=MultiModalFieldConfig.shared(batch_size=1,
+                                                         modality="image"),
         )

     def _get_prompt_updates(

@@ -98,15 +102,16 @@ def apply(

         for k, v in mm_data.items():
             mm_kwargs[k] = v
-        mm_place_holders = {
-            "image": [PlaceholderRange(offset=0, length=0)]
-        }
+        mm_place_holders = {"image": [PlaceholderRange(offset=0, length=0)]}

         multimodal_kwargs_items = [
-            MultiModalKwargsItem.from_elems(
-                [MultiModalFieldElem(modality="image", key=key, data=data, field=MultiModalSharedField(1))
-                 for key, data in mm_kwargs.items()]
-            )
+            MultiModalKwargsItem.from_elems([
+                MultiModalFieldElem(modality="image",
+                                    key=key,
+                                    data=data,
+                                    field=MultiModalSharedField(1))
+                for key, data in mm_kwargs.items()
+            ])
         ]

         return MultiModalInputs(

@@ -123,7 +128,8 @@ def apply(
     PrithviGeoSpatialMAEMultiModalProcessor,
     info=PrithviGeoSpatialMAEProcessingInfo,
     dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModalWithRawInput):
+class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
+                           SupportsMultiModalWithRawInput):
     """ Prithvi Masked Autoencoder"""

     def _instantiate_model(self, config: dict) -> Optional[nn.Module]:

@@ -181,13 +187,14 @@ def _parse_and_validate_multimodal_data(
             location_coords = None

         return pixel_values, location_coords
-
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        # We do not really use any input tokens and therefore no embeddings to be calculated
-        # However, due to the mandatory token ids in the input prompt we pass one token and the
-        # size of the dummy embedding tensors must reflect that.
+        # We do not really use any input tokens and therefore no embeddings
+        # to be calculated. However, due to the mandatory token ids in
+        # the input prompt we pass one token and the size of the dummy
+        # embedding tensors must reflect that.
         return torch.empty(input_ids.shape)
-
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],

@@ -209,7 +216,10 @@ def pooler(
         hidden_states: torch.Tensor,
         pooling_metadata: PoolingMetadata,
     ) -> Optional[PoolerOutput]:
-        return PoolerOutput([PoolingSequenceGroupOutput(hidden_state) for hidden_state in hidden_states])
+        return PoolerOutput([
+            PoolingSequenceGroupOutput(hidden_state)
+            for hidden_state in hidden_states
+        ])

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
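
To make the reformatted from_elems construction easier to read in isolation, here is a small sketch with toy tensors in place of real model inputs. It assumes the same vLLM version as this commit, since these multimodal input helpers are internal and may change.

import torch

from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
                                    MultiModalSharedField)

# Toy stand-ins for the raw inputs the processor receives.
mm_kwargs = {
    "pixel_values": torch.zeros(6, 512, 512, dtype=torch.float16),
    "location_coords": torch.zeros(1, 2, dtype=torch.float16),
}

# One kwargs item whose fields are shared across a batch of size 1,
# mirroring the list comprehension in apply() above.
item = MultiModalKwargsItem.from_elems([
    MultiModalFieldElem(modality="image",
                        key=key,
                        data=data,
                        field=MultiModalSharedField(1))
    for key, data in mm_kwargs.items()
])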

vllm/model_executor/models/registry.py

Lines changed: 2 additions & 3 deletions
@@ -24,8 +24,7 @@
 from .interfaces import (has_inner_state, has_noops, is_attention_free,
                          is_hybrid, supports_cross_encoding,
                          supports_multimodal, supports_multimodal_raw_input,
-                         supports_pp, supports_transcription,
-                         supports_v0_only)
+                         supports_pp, supports_transcription, supports_v0_only)
 from .interfaces_base import is_text_generation_model

 logger = init_logger(__name__)

@@ -522,7 +521,7 @@ def is_multimodal_model(
     ) -> bool:
         model_cls, _ = self.inspect_model_cls(architectures)
         return model_cls.supports_multimodal
-
+
     def supports_multimodal_raw_input(
         self,
         architectures: Union[str, list[str]],

vllm/v1/core/kv_cache_manager.py

Lines changed: 12 additions & 1 deletion
@@ -63,7 +63,9 @@ def new_empty(self) -> "KVCacheBlocks":
         """Creates a new KVCacheBlocks instance with no blocks."""
         return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))))

+
 class DummyKVCacheManager:
+
     @property
     def usage(self) -> float:
         return 0.0

@@ -73,7 +75,7 @@ def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:

     def get_computed_blocks(self,
                             request: Request) -> tuple[KVCacheBlocks, int]:
-        return(KVCacheBlocks([]), 0)
+        return (KVCacheBlocks([]), 0)

     def allocate_slots(
         self,

@@ -111,6 +113,15 @@ def get_block_ids(self, request_id: str) -> list[list[int]]:
         """Get the block ids of a request."""
         return []

+    def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
+        """Cache the blocks for the request, if enabled."""
+        pass
+
+    def create_empty_block_list(self) -> KVCacheBlocks:
+        """Creates a new KVCacheBlocks instance with no blocks."""
+        return (KVCacheBlocks([]), 0)
+
+
 class KVCacheManager:

     def __init__(

vllm/v1/core/sched/scheduler.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                 compute_encoder_budget)
-from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager, DummyKVCacheManager
+from vllm.v1.core.kv_cache_manager import DummyKVCacheManager, KVCacheManager
 from vllm.v1.core.sched.interface import SchedulerInterface
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)

vllm/v1/engine/core.py

Lines changed: 8 additions & 6 deletions
@@ -139,23 +139,26 @@ def _initialize_kv_caches(
             # is attention free.
             kv_cache_specs = []
             kv_cache_configs = [
-                KVCacheConfig(num_blocks=0, kv_cache_tensors={}, kv_cache_groups=[])
-            ]
+                KVCacheConfig(num_blocks=0,
+                              kv_cache_tensors={},
+                              kv_cache_groups=[])
+            ]
         else:
             # Get all kv cache needed by the model
             kv_cache_specs = self.model_executor.get_kv_cache_specs()

             # Profiles the peak memory usage of the model to determine how much
             # memory can be allocated for kv cache.
-            available_gpu_memory = self.model_executor.determine_available_memory()
+            available_gpu_memory = (
+                self.model_executor.determine_available_memory())

             assert len(kv_cache_specs) == len(available_gpu_memory)
             # Get the kv cache tensor size
             kv_cache_configs = [
                 get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                                     available_gpu_memory_one_worker)
-                for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
-                zip(kv_cache_specs, available_gpu_memory)
+                for kv_cache_spec_one_worker, available_gpu_memory_one_worker
+                in zip(kv_cache_specs, available_gpu_memory)
             ]

             # Since we use a shared centralized controller, we need the

@@ -194,7 +197,6 @@ def add_request(self, request: EngineCoreRequest):
             request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                 request.mm_inputs, request.mm_hashes)

-
             req = Request.from_engine_core_request(request)
             if req.use_structured_output:
                 # Start grammar compilation asynchronously

vllm/v1/engine/llm_engine.py

Lines changed: 1 addition & 2 deletions
@@ -82,8 +82,7 @@ def __init__(
         self.dp_group = None
         self.should_execute_dummy_batch = False

-
-        if not self.vllm_config.model_config.skip_tokenizer_init:
+        if not self.vllm_config.model_config.skip_tokenizer_init:
             # Tokenizer (+ ensure liveness if running in another process).
             self.tokenizer = init_tokenizer_from_configs(
                 model_config=vllm_config.model_config,
