Commit e0404c9

Clean up page size padding logic
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent c2da03e commit e0404c9

File tree

2 files changed: +32 -26 lines changed

vllm/v1/kv_cache_interface.py
vllm/v1/worker/gpu_model_runner.py


vllm/v1/kv_cache_interface.py

Lines changed: 5 additions & 4 deletions

@@ -11,7 +11,7 @@
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.utils import cdiv, get_dtype_size, round_up
+from vllm.utils import cdiv, get_dtype_size
 
 logger = init_logger(__name__)
 
@@ -159,7 +159,7 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
 class MambaSpec(KVCacheSpec):
     shapes: tuple[tuple[int, ...], ...]
     dtype: torch.dtype
-    multiple_of: Optional[int]
+    page_size_padded: Optional[int] = None
 
     def __post_init__(self):
         self.num_elements = sum(prod(shape) for shape in self.shapes)
@@ -171,8 +171,9 @@ def type_id(self) -> str:
     @property
     def page_size_bytes(self) -> int:
         page_size = self.num_elements * get_dtype_size(self.dtype)
-        if self.multiple_of is not None:
-            page_size = round_up(page_size, self.multiple_of)
+        if self.page_size_padded is not None:
+            assert self.page_size_padded >= page_size
+            return self.page_size_padded
         return page_size
 
     def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
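
The net effect of this file's change is that MambaSpec no longer rounds its page size up to a multiple of multiple_of; the model runner now passes an explicit page_size_padded, and page_size_bytes simply returns that value after checking it is at least the natural page size. Below is a minimal standalone sketch of the new behaviour; it mirrors the logic rather than importing the vLLM class, and the state shapes, element size, and padded value are made up for illustration.

```python
# Standalone sketch of the new MambaSpec.page_size_bytes behaviour.
# Shapes, element size, and padded value are illustrative only.
from dataclasses import dataclass
from math import prod
from typing import Optional


@dataclass
class MambaSpecSketch:
    shapes: tuple[tuple[int, ...], ...]
    dtype_size: int                        # bytes per element, e.g. 2 for fp16
    page_size_padded: Optional[int] = None

    @property
    def page_size_bytes(self) -> int:
        # Natural (unpadded) size of one mamba page in bytes.
        page_size = sum(prod(shape) for shape in self.shapes) * self.dtype_size
        if self.page_size_padded is not None:
            # The caller supplies the padded size; it must not be smaller
            # than the natural page size.
            assert self.page_size_padded >= page_size
            return self.page_size_padded
        return page_size


unpadded = MambaSpecSketch(shapes=((128, 64), (16, 256)), dtype_size=2)
print(unpadded.page_size_bytes)   # 24576 -> natural size, no padding

padded = MambaSpecSketch(shapes=((128, 64), (16, 256)), dtype_size=2,
                         page_size_padded=32768)
print(padded.page_size_bytes)     # 32768 -> padded up to the supplied size
```

In the second file of this commit, page_size_padded is set to the attention layers' page size whenever attention layers are present, so mamba and attention pages end up identical in size.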

vllm/v1/worker/gpu_model_runner.py

Lines changed: 27 additions & 22 deletions

@@ -2575,15 +2575,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         mamba_layers = get_layers_from_vllm_config(self.vllm_config,
                                                    MambaMixer2)
         if len(mamba_layers) > 0:
-            if len(attn_layers) > 0:
-                # Mamba state must be padded to an integer number of
-                # 16th tokens worth of attention pages
-                attn_layer_name = next(iter(attn_layers))
-                attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes
-                multiple_of = 16 * attn_page_size // block_size
-            else:
-                multiple_of = None
-
             if self.vllm_config.speculative_config is not None:
                 raise NotImplementedError(
                     "Mamba with speculative decoding is not supported yet.")
@@ -2594,25 +2585,39 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                 raise NotImplementedError(
                     "Prefix caching is not supported for Mamba yet.")
             max_model_len = self.vllm_config.model_config.max_model_len
+
+            if len(attn_layers) > 0:
+                attn_layer_name = next(iter(attn_layers))
+                attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes
+                mamba_layer_name = next(iter(mamba_layers))
+                mamba_page_size = MambaSpec(
+                    shapes=mamba_layers[mamba_layer_name].get_state_shape(),
+                    dtype=self.kv_cache_dtype,
+                    block_size=max_model_len).page_size_bytes
+                if attn_page_size < mamba_page_size:
+                    # attention page size (for 16 tokens)
+                    attn_page_size_16 = 16 * attn_page_size // block_size
+                    # some attention backends (e.g. FA) only support setting
+                    # block size to multiple of 16, so let's suggest a value
+                    # that would work (note: FA is currently not compatible
+                    # with mamba layers, use FlashInfer instead).
+                    suggest_attn_block_size = 16 * cdiv(
+                        mamba_page_size, attn_page_size_16)
+                    raise ValueError(
+                        "Attention block size should be increased to at least "
+                        f"{suggest_attn_block_size} in order to match "
+                        "the mamba page size")
+                page_size_padded = attn_page_size
+            else:
+                page_size_padded = None
+
             # Set block_size to max_model_len, so that mamba model will always
             # have only one block in the KV cache.
             for layer_name, mamba_module in mamba_layers.items():
                 kv_cache_spec[layer_name] = MambaSpec(
                     shapes=mamba_module.get_state_shape(),
                     dtype=self.kv_cache_dtype,
                     block_size=max_model_len,
-                    multiple_of=multiple_of)
-
-            if len(attn_layers) > 0:
-                mamba_layer_name = next(iter(mamba_layers))
-                mamba_page_size = kv_cache_spec[
-                    mamba_layer_name].page_size_bytes
-                if attn_page_size < mamba_page_size:
-                    required_attn_block_size = cdiv(mamba_page_size,
-                                                    multiple_of) * 16
-                    raise ValueError(
-                        "Attention block size must be increased to "
-                        f"{required_attn_block_size} in order to match "
-                        "the mamba page size")
+                    page_size_padded=page_size_padded)
 
         return kv_cache_spec
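
The suggested value in the new ValueError is derived from the amount of attention page consumed per 16 tokens at the currently configured block size. Below is a small worked example of that arithmetic with made-up page sizes (real values depend on the attention geometry and the mamba state shapes); cdiv here is plain ceiling division, as in vllm.utils.

```python
# Worked example of the block-size suggestion arithmetic; numbers are made up.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division


block_size = 16            # attention block size currently configured (tokens)
attn_page_size = 36_864    # bytes of one attention page (holds block_size tokens)
mamba_page_size = 800_000  # bytes needed for one request's mamba state

# Attention page bytes per 16 tokens at the current block size.
attn_page_size_16 = 16 * attn_page_size // block_size       # 36_864

if attn_page_size < mamba_page_size:
    # Smallest multiple of 16 tokens whose attention page covers the mamba
    # page; this is the block size the error message asks the user to set.
    suggest_attn_block_size = 16 * cdiv(mamba_page_size, attn_page_size_16)
    print(suggest_attn_block_size)                           # 352
```

With these numbers, an attention block size of 352 gives a page of 22 * 36,864 = 811,008 bytes, which covers the 800,000-byte mamba page; that attention page size is then used as page_size_padded for the mamba layers.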
