
Commit bec5628

Support for attention free models in V1
1 parent 9a3b883 commit bec5628

File tree (3 files changed: +91 −31 lines changed)
    vllm/v1/core/kv_cache_manager.py
    vllm/v1/core/sched/scheduler.py
    vllm/v1/engine/core.py

vllm/v1/core/kv_cache_manager.py

Lines changed: 47 additions & 0 deletions
@@ -63,6 +63,53 @@ def new_empty(self) -> "KVCacheBlocks":
         """Creates a new KVCacheBlocks instance with no blocks."""
         return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))))
 
+class DummyKVCacheManager:
+
+    @property
+    def usage(self) -> float:
+        return 0.0
+
+    def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
+        return None
+
+    def get_computed_blocks(self,
+                            request: Request) -> tuple[KVCacheBlocks, int]:
+        return (KVCacheBlocks([]), 0)
+
+    def allocate_slots(
+        self,
+        request: Request,
+        num_new_tokens: int,
+        num_new_computed_tokens: int = 0,
+        new_computed_blocks: Optional[KVCacheBlocks] = None,
+        num_draft_tokens: int = 0,
+        num_lookahead_tokens: int = 0,
+        delay_cache_blocks: bool = False,
+    ) -> Optional[KVCacheBlocks]:
+        # If we do not return a KV cache block, requests are unschedulable.
+        return KVCacheBlocks([KVCacheBlock(block_id=0)])
+
+    def free(self, request: Request) -> None:
+        pass
+
+    def reset_prefix_cache(self) -> bool:
+        return True
+
+    def get_num_common_prefix_blocks(
+        self,
+        request: Request,
+        num_running_requests: int,
+    ) -> list[int]:
+        return []
+
+    def free_block_hashes(self, request: Request) -> None:
+        pass
+
+    def take_events(self) -> list[KVCacheEvent]:
+        return []
+
+    def get_block_ids(self, request_id: str) -> list[list[int]]:
+        """Get the block ids of a request."""
+        return []
 
 
 class KVCacheManager:

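The new DummyKVCacheManager mirrors the duck-typed surface of KVCacheManager while never tracking or allocating real blocks, which lets attention-free models flow through the regular scheduling path. Below is a minimal usage sketch, not part of the commit: it assumes this patch is applied, and _FakeRequest is a hypothetical stand-in for vllm.v1.request.Request (the dummy manager never inspects its arguments).

from vllm.v1.core.kv_cache_manager import DummyKVCacheManager


class _FakeRequest:
    # Hypothetical placeholder request; any object works here because the
    # dummy manager ignores its arguments.
    request_id = "req-0"


req = _FakeRequest()
manager = DummyKVCacheManager()

print(manager.usage)  # 0.0 -- no KV cache memory is ever used
blocks, num_computed = manager.get_computed_blocks(req)  # empty blocks, 0 tokens
new_blocks = manager.allocate_slots(req, num_new_tokens=16)
assert new_blocks is not None  # one placeholder block keeps the request schedulable
manager.free(req)  # no-op
manager.free_block_hashes(req)  # no-op
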
vllm/v1/core/sched/scheduler.py

Lines changed: 15 additions & 11 deletions
@@ -18,7 +18,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                 compute_encoder_budget)
-from vllm.v1.core.kv_cache_manager import KVCacheManager
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager, DummyKVCacheManager
 from vllm.v1.core.sched.interface import SchedulerInterface
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
@@ -90,7 +90,8 @@ def __init__(
         )
 
         num_gpu_blocks = self.cache_config.num_gpu_blocks
-        assert num_gpu_blocks is not None and num_gpu_blocks > 0
+        # num_gpu_blocks can be zero for attention-free models.
+        assert num_gpu_blocks is not None
 
         self.block_size = self.cache_config.block_size
 
@@ -155,15 +156,18 @@ def __init__(
         self.num_lookahead_tokens = self.num_spec_tokens
 
         # Create the KV cache manager.
-        self.kv_cache_manager = KVCacheManager(
-            kv_cache_config=kv_cache_config,
-            max_model_len=self.max_model_len,
-            enable_caching=self.cache_config.enable_prefix_caching,
-            caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
-            use_eagle=self.use_eagle,
-            log_stats=self.log_stats,
-            enable_kv_cache_events=self.enable_kv_cache_events,
-        )
+        if self.cache_config.is_attention_free:
+            self.kv_cache_manager = DummyKVCacheManager()
+        else:
+            self.kv_cache_manager = KVCacheManager(
+                kv_cache_config=kv_cache_config,
+                max_model_len=self.max_model_len,
+                enable_caching=self.cache_config.enable_prefix_caching,
+                caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
+                use_eagle=self.use_eagle,
+                log_stats=self.log_stats,
+                enable_kv_cache_events=self.enable_kv_cache_events,
+            )
 
     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:

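The scheduler change reduces to a single branch on cache_config.is_attention_free, paired with the relaxed assertion above: attention-free models report num_gpu_blocks == 0, which the old num_gpu_blocks > 0 check would have rejected at startup. The sketch below isolates that decision in a hypothetical make_kv_cache_manager() helper (not part of the diff); the keyword arguments are a subset of the ones shown in the constructor call above.

from vllm.v1.core.kv_cache_manager import DummyKVCacheManager, KVCacheManager


def make_kv_cache_manager(cache_config, kv_cache_config, max_model_len):
    # Attention-free models (e.g. pure SSM/Mamba stacks) never read or write
    # a KV cache, so a no-op manager is enough to keep requests schedulable.
    if cache_config.is_attention_free:
        return DummyKVCacheManager()
    return KVCacheManager(
        kv_cache_config=kv_cache_config,
        max_model_len=max_model_len,
        enable_caching=cache_config.enable_prefix_caching,
        caching_hash_algo=cache_config.prefix_caching_hash_algo,
    )
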
vllm/v1/engine/core.py

Lines changed: 29 additions & 20 deletions
@@ -134,26 +134,34 @@ def _initialize_kv_caches(
             self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
         start = time.time()
 
-        # Get all kv cache needed by the model
-        kv_cache_specs = self.model_executor.get_kv_cache_specs()
-
-        # Profiles the peak memory usage of the model to determine how much
-        # memory can be allocated for kv cache.
-        available_gpu_memory = self.model_executor.determine_available_memory()
-
-        assert len(kv_cache_specs) == len(available_gpu_memory)
-        # Get the kv cache tensor size
-        kv_cache_configs = [
-            get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
-                                available_gpu_memory_one_worker)
-            for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
-            zip(kv_cache_specs, available_gpu_memory)
-        ]
-
-        # Since we use a shared centralized controller, we need the
-        # `kv_cache_config` to be consistent across all workers to make sure
-        # all the memory operators can be applied to all workers.
-        unify_kv_cache_configs(kv_cache_configs)
+        if vllm_config.model_config.is_attention_free:
+            # No need to initialize anything related to the KV cache if the
+            # model is attention free.
+            kv_cache_specs = []
+            kv_cache_configs = [
+                KVCacheConfig(num_blocks=0, tensors={}, kv_cache_groups=[])
+            ]
+        else:
+            # Get all kv cache needed by the model
+            kv_cache_specs = self.model_executor.get_kv_cache_specs()
+
+            # Profiles the peak memory usage of the model to determine how much
+            # memory can be allocated for kv cache.
+            available_gpu_memory = self.model_executor.determine_available_memory()
+
+            assert len(kv_cache_specs) == len(available_gpu_memory)
+            # Get the kv cache tensor size
+            kv_cache_configs = [
+                get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
+                                    available_gpu_memory_one_worker)
+                for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
+                zip(kv_cache_specs, available_gpu_memory)
+            ]
+
+            # Since we use a shared centralized controller, we need the
+            # `kv_cache_config` to be consistent across all workers to make sure
+            # all the memory operators can be applied to all workers.
+            unify_kv_cache_configs(kv_cache_configs)
 
         # All workers have the same kv_cache_config except layer names, so use
         # an arbitrary one to initialize the scheduler.
@@ -186,6 +194,7 @@ def add_request(self, request: EngineCoreRequest):
         request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
             request.mm_inputs, request.mm_hashes)
 
+
         req = Request.from_engine_core_request(request)
         if req.use_structured_output:
             # Start grammar compilation asynchronously

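For attention-free models the engine now skips KV cache profiling altogether and hands the scheduler a zero-block configuration. A small sketch of what gets built, with the field values taken from the diff above; the import path for KVCacheConfig is assumed to be vllm.v1.kv_cache_interface.

from vllm.v1.kv_cache_interface import KVCacheConfig  # assumed import path

# What _initialize_kv_caches builds instead of profiling GPU memory.
kv_cache_configs = [KVCacheConfig(num_blocks=0, tensors={}, kv_cache_groups=[])]

scheduler_config = kv_cache_configs[0]  # single entry, nothing to unify across workers
print(scheduler_config.num_blocks)  # 0 -> the scheduler sees num_gpu_blocks == 0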