Skip to content

Commit 0dba4cd

Browse files
Few more changes to solve some other pre-commit hooks failures
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
1 parent c708266 commit 0dba4cd

File tree

4 files changed

+89
-19
lines changed

4 files changed

+89
-19
lines changed

tests/models/multimodal/pooling/test_prithvi_mae.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,35 @@
66

77
from ....conftest import VllmRunner
88

9+
910
def generate_test_mm_data():
1011
mm_data = {
1112
"pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16),
1213
"location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
1314
}
1415
return mm_data
15-
16+
17+
1618
def _run_test(
1719
vllm_runner: type[VllmRunner],
1820
model: str,
19-
) -> None:
21+
) -> None:
2022

2123
mm_data = generate_test_mm_data()
2224
prompt = {
2325
# This model deals with no text input
2426
"prompt_token_ids": [1],
2527
"multi_modal_data": mm_data
2628
}
27-
with vllm_runner(model, task="embed",
28-
dtype=torch.float16,
29-
enforce_eager=True,
30-
skip_tokenizer_init=True) as vllm_model:
31-
output = vllm_model.encode(prompt)
32-
33-
MODELS=["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
29+
with vllm_runner(model,
30+
task="embed",
31+
dtype=torch.float16,
32+
enforce_eager=True,
33+
skip_tokenizer_init=True) as vllm_model:
34+
vllm_model.encode(prompt)
35+
36+
MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
37+
3438
@pytest.mark.parametrize("model", MODELS)
3539
def test_models_image(
3640
hf_runner,
@@ -41,4 +45,4 @@ def test_models_image(
4145
_run_test(
4246
vllm_runner,
4347
model,
44-
)
48+
)

vllm/v1/core/kv_cache_manager.py

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

4+
from abc import ABC, abstractmethod
45
from collections import defaultdict
56
from dataclasses import dataclass
67
from typing import Optional
@@ -64,7 +65,72 @@ def new_empty(self) -> "KVCacheBlocks":
6465
return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))))
6566

6667

67-
class DummyKVCacheManager:
68+
class KVCacheManagerInterface(ABC):
69+
70+
@abstractmethod
71+
def usage(self) -> float:
72+
raise NotImplementedError
73+
74+
@abstractmethod
75+
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
76+
raise NotImplementedError
77+
78+
@abstractmethod
79+
def get_computed_blocks(self,
80+
request: Request) -> tuple[KVCacheBlocks, int]:
81+
raise NotImplementedError
82+
83+
@abstractmethod
84+
def allocate_slots(
85+
self,
86+
request: Request,
87+
num_new_tokens: int,
88+
num_new_computed_tokens: int = 0,
89+
new_computed_blocks: Optional[KVCacheBlocks] = None,
90+
num_draft_tokens: int = 0,
91+
num_lookahead_tokens: int = 0,
92+
delay_cache_blocks: bool = False,
93+
) -> Optional[KVCacheBlocks]:
94+
raise NotImplementedError
95+
96+
@abstractmethod
97+
def free(self, request: Request) -> None:
98+
raise NotImplementedError
99+
100+
@abstractmethod
101+
def reset_prefix_cache(self) -> bool:
102+
raise NotImplementedError
103+
104+
@abstractmethod
105+
def get_num_common_prefix_blocks(
106+
self,
107+
request: Request,
108+
num_running_requests: int,
109+
) -> list[int]:
110+
raise NotImplementedError
111+
112+
@abstractmethod
113+
def free_block_hashes(self, request: Request) -> None:
114+
raise NotImplementedError
115+
116+
@abstractmethod
117+
def take_events(self) -> list[KVCacheEvent]:
118+
raise NotImplementedError
119+
120+
@abstractmethod
121+
def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
122+
raise NotImplementedError
123+
124+
@abstractmethod
125+
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
126+
raise NotImplementedError
127+
128+
@abstractmethod
129+
def create_empty_block_list(self) -> KVCacheBlocks:
130+
raise NotImplementedError
131+
132+
133+
class DummyKVCacheManager(KVCacheManagerInterface):
68134

69135
@property
70136
def usage(self) -> float:
@@ -88,7 +154,7 @@ def allocate_slots(
88154
delay_cache_blocks: bool = False,
89155
) -> Optional[KVCacheBlocks]:
90156
#if we do not return a KV cache block requests are unschedulable
91-
return KVCacheBlocks([KVCacheBlock(block_id=0)])
157+
return KVCacheBlocks(tuple([KVCacheBlock(block_id=0)]))
92158

93159
def free(self, request: Request) -> None:
94160
pass
@@ -109,20 +175,20 @@ def free_block_hashes(self, request: Request) -> None:
109175
def take_events(self) -> list[KVCacheEvent]:
110176
return []
111177

112-
def get_block_ids(self, request_id: str) -> list[list[int]]:
178+
def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
113179
"""Get the block ids of a request."""
114-
return []
180+
return tuple([])
115181

116182
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
117183
"""Cache the blocks for the request, if enabled."""
118184
pass
119185

120186
def create_empty_block_list(self) -> KVCacheBlocks:
121187
"""Creates a new KVCacheBlocks instance with no blocks."""
122-
return (KVCacheBlocks([]), 0)
188+
return KVCacheBlocks(tuple([]))
123189

124190

125-
class KVCacheManager:
191+
class KVCacheManager(KVCacheManagerInterface):
126192

127193
def __init__(
128194
self,

vllm/v1/core/sched/scheduler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -493,8 +493,8 @@ def schedule(self) -> SchedulerOutput:
493493

494494
if self.lora_config and request.lora_request:
495495
scheduled_loras.add(request.lora_request.lora_int_id)
496-
req_to_new_block_ids[request.request_id] = (
497-
self.kv_cache_manager.get_block_ids(request.request_id))
496+
req_to_new_block_ids[request.request_id] = \
497+
self.kv_cache_manager.get_block_ids(request.request_id)
498498
num_scheduled_tokens[request.request_id] = num_new_tokens
499499
token_budget -= num_new_tokens
500500
request.status = RequestStatus.RUNNING

vllm/v1/engine/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def _initialize_kv_caches(
140140
kv_cache_specs = []
141141
kv_cache_configs = [
142142
KVCacheConfig(num_blocks=0,
143-
kv_cache_tensors={},
143+
kv_cache_tensors=[],
144144
kv_cache_groups=[])
145145
]
146146
else:

0 commit comments

Comments
 (0)