
Commit fda49c3

Moving work from PR vllm-project#20499
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent b4f0b5f commit fda49c3

4 files changed: +201 additions, −70 deletions
tests/models/language/generation/test_hybrid.py

Lines changed: 1 addition & 15 deletions

@@ -61,14 +61,6 @@
     "tiiuae/Falcon-H1-0.5B-Base",
 ]
 
-ATTN_BLOCK_SIZES = {
-    "ibm-ai-platform/Bamba-9B-v1": 528,
-    "Zyphra/Zamba2-1.2B-instruct": 80,
-    "nvidia/Nemotron-H-8B-Base-8K": 528,
-    "ibm-granite/granite-4.0-tiny-preview": 400,
-    "tiiuae/Falcon-H1-0.5B-Base": 800,
-}
-
 # Avoid OOM
 MAX_NUM_SEQS = 4
 
@@ -105,11 +97,6 @@ def test_models(
             example_prompts, max_tokens, num_logprobs)
 
     if model in V1_SUPPORTED_MODELS:
-        if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES:
-            block_size = ATTN_BLOCK_SIZES[model]
-        else:
-            block_size = 16
-
        with monkeypatch.context() as m:
            m.setenv("VLLM_USE_V1", "1")
            if model in HYBRID_MODELS:
@@ -118,8 +105,7 @@ def test_models(
            with vllm_runner(model,
                             max_num_seqs=MAX_NUM_SEQS,
                             enforce_eager=True,
-                             enable_prefix_caching=False,
-                             block_size=block_size) as vllm_model:
+                             enable_prefix_caching=False) as vllm_model:
                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                    example_prompts, max_tokens, num_logprobs)
    else:
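
The hard-coded ATTN_BLOCK_SIZES table deleted above is exactly what the new HybridAttentionMambaModelConfig hook (added in vllm/model_executor/models/config.py below) now computes automatically. As a rough sanity check, here is a minimal sketch, assuming Bamba-9B-like dimensions (hidden size 4096, 8 KV heads of head size 128, mamba_expand=2, n_groups=1, d_state=128, d_conv=4, 128 mamba heads of dim 64, bf16 cache) and assuming the page sizes are plain element-count times dtype-size products; it reproduces the previously hard-coded value of 528:

# Hedged sketch: reproduce the previously hard-coded attention block size for a
# Bamba-9B-like model. All dimensions below are assumptions for illustration.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)

dtype_size = 2                       # bf16
hidden_size, expand = 4096, 2
n_groups, d_state, d_conv = 1, 128, 4
n_heads, d_head = 128, 64
num_kv_heads, head_size = 8, 128

# mamba state per sequence: conv state + temporal (SSM) state, TP size 1
conv_dim = expand * hidden_size + 2 * n_groups * d_state
mamba_elems = conv_dim * (d_conv - 1) + n_heads * d_head * d_state
mamba_page_size = mamba_elems * dtype_size                 # about 2.1 MB

# attention KV cache per token: K and V for every KV head
attn_page_size_1_token = 2 * num_kv_heads * head_size * dtype_size  # 4096 B

# same formula as verify_and_update_config() below: round up to a multiple of 16
attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token)
print(attn_block_size)   # 528, matching the deleted ATTN_BLOCK_SIZES entry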

vllm/config.py

Lines changed: 3 additions & 0 deletions

@@ -1564,6 +1564,9 @@ class CacheConfig:
     checkpoint if available. Otherwise, the scales will default to 1.0."""
     cpu_kvcache_space_bytes: Optional[int] = None
     """(CPU backend only) CPU key-value cache space."""
+    mamba_page_size_padded: Optional[int] = None
+    """Optional override for mamba page size; used by hybrid mamba/attention
+    models to ensure exact alignment with attention page size."""
 
     # Will be set after profiling.
     num_gpu_blocks: Optional[int] = field(default=None, init=False)
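
This new field is produced during config verification (see vllm/model_executor/models/config.py below) and read back by the model runner, which forwards it to MambaSpec as page_size_padded. The toy sketch below is not the real MambaSpec; it only illustrates the intended contract, assuming a spec reports the padded size when an override is set and its natural state size otherwise:

# Hedged sketch of how a padded page size override is typically consumed.
from dataclasses import dataclass
from math import prod
from typing import Optional

@dataclass
class ToyMambaSpec:
    shapes: tuple                     # e.g. ((8448, 3), (128, 64, 128))
    dtype_size: int                   # bytes per element
    page_size_padded: Optional[int] = None

    @property
    def page_size_bytes(self) -> int:
        natural = sum(prod(s) for s in self.shapes) * self.dtype_size
        if self.page_size_padded is not None:
            assert self.page_size_padded >= natural   # padding only ever grows
            return self.page_size_padded
        return natural

spec = ToyMambaSpec(shapes=((8448, 3), (128, 64, 128)), dtype_size=2,
                    page_size_padded=528 * 4096)
print(spec.page_size_bytes)   # 2162688, exactly the attention page size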

vllm/model_executor/models/config.py

Lines changed: 194 additions & 0 deletions

@@ -1,11 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from copy import deepcopy
+from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
+import vllm.envs as envs
+from vllm.distributed import divide
 from vllm.logger import init_logger
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
+from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
 
 if TYPE_CHECKING:
+    from transformers.configuration_utils import PretrainedConfig
+
     from vllm.config import VllmConfig
 
 logger = init_logger(__name__)
@@ -200,11 +207,198 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
         }
 
 
+class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
+
+    @classmethod
+    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int) -> int:
+        """Compute the increase in group numbers to account for
+        replication in order to accompany the head shards."""
+
+        # in the case ngroups % tp_size == 0, this will be zero
+        if ngroups % tp_size == 0:
+            return 0
+
+        # for n_groups == 1, this is exactly tp_size - n_groups
+        return tp_size - ngroups
+
+    @dataclass
+    class MambaConfig:
+        expand: int
+        n_groups: int
+        n_heads: int
+        d_head: int
+        d_state: int
+        d_conv: int
+
+    @classmethod
+    def parse_mamba_config(cls, config: "PretrainedConfig") -> MambaConfig:
+        return cls.MambaConfig(
+            expand=config.mamba_expand,
+            n_groups=config.mamba_n_groups,
+            n_heads=config.mamba_n_heads,
+            d_head=config.mamba_d_head,
+            d_state=config.mamba_d_state,
+            d_conv=config.mamba_d_conv,
+        )
+
+    @classmethod
+    def get_mamba_cache_shape(
+            cls, vllm_config: "VllmConfig"
+    ) -> tuple[tuple[int, int], tuple[int, int]]:
+
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+        mamba_config = cls.parse_mamba_config(hf_config)
+
+        world_size = parallel_config.tensor_parallel_size
+        hidden_size = hf_config.hidden_size
+        intermediate_size = mamba_config.expand * hidden_size
+
+        # if n_groups is not divisible by world_size, need to extend the shards
+        # to ensure all groups needed by a head are sharded along with it
+        n_groups = (mamba_config.n_groups + cls.extra_groups_for_head_shards(
+            mamba_config.n_groups, world_size))
+
+        # - heads and n_groups are TP-ed
+        conv_dim = (intermediate_size + 2 * n_groups * mamba_config.d_state)
+        conv_state_shape = (
+            divide(conv_dim, world_size),
+            mamba_config.d_conv - 1,
+        )
+
+        # These are not TP-ed as they depend on A, dt_bias, D
+        # - they are typically small
+        # e.g., (n_heads, d_head, d_state) = (128, 64, 128)
+        temporal_state_shape = (
+            divide(mamba_config.n_heads, world_size),
+            mamba_config.d_head,
+            mamba_config.d_state,
+        )
+
+        return conv_state_shape, temporal_state_shape
+
+    @classmethod
+    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        """
+        Ensure that the page size of the attention layers is greater than or
+        equal to that of the mamba layers. If not, automatically set the
+        attention block size to ensure that it is. If the attention page size
+        is strictly greater than the mamba page size, we pad the mamba page
+        size to make them equal.
+
+        Args:
+            vllm_config: vLLM Config
+        """
+
+        if not envs.VLLM_USE_V1:
+            return
+
+        cache_config = vllm_config.cache_config
+        model_config = vllm_config.model_config
+        parallel_config = vllm_config.parallel_config
+
+        if cache_config.cache_dtype == "auto":
+            kv_cache_dtype = model_config.dtype
+        else:
+            kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+
+        # get attention page size (for 1 token)
+        attn_page_size_1_token = FullAttentionSpec(
+            block_size=1,
+            num_kv_heads=model_config.get_num_kv_heads(parallel_config),
+            head_size=model_config.get_head_size(),
+            dtype=kv_cache_dtype,
+            use_mla=model_config.use_mla).page_size_bytes
+
+        # get mamba page size
+        mamba_page_size = MambaSpec(
+            shapes=cls.get_mamba_cache_shape(vllm_config),
+            dtype=kv_cache_dtype,
+            block_size=model_config.max_model_len,
+        ).page_size_bytes
+
+        # some attention backends (e.g. FA) only support setting
+        # block size to a multiple of 16, so let's suggest a value
+        # that would work (note: FA is currently not compatible
+        # with mamba layers, use FlashInfer instead).
+        attn_block_size = 16 * cdiv(mamba_page_size,
+                                    16 * attn_page_size_1_token)
+
+        # override attention block size if either (a) the
+        # user has not set it or (b) the user has set it
+        # too small.
+        if (cache_config.block_size is None
+                or cache_config.block_size < attn_block_size):
+            cache_config.block_size = attn_block_size
+            logger.info(
+                "Setting attention block size to %d tokens "
+                "to ensure that attention page size is >= mamba page size.",
+                attn_block_size)
+
+        # compute new attention page size
+        attn_page_size = \
+            cache_config.block_size * attn_page_size_1_token
+
+        assert attn_page_size >= mamba_page_size
+
+        if attn_page_size == mamba_page_size:
+            # don't need to pad mamba page size
+            return
+
+        # pad mamba page size to exactly match attention
+        if (cache_config.mamba_page_size_padded is None
+                or cache_config.mamba_page_size_padded != attn_page_size):
+            cache_config.mamba_page_size_padded = attn_page_size
+            mamba_padding_pct = 100 * (attn_page_size -
+                                       mamba_page_size) / mamba_page_size
+            logger.info(
+                "Padding mamba page size by %.2f%% to ensure "
+                "that mamba page size and attention page size are "
+                "exactly equal.", mamba_padding_pct)
+
+
+class NemotronHModelConfig(HybridAttentionMambaModelConfig):
+
+    @classmethod
+    def parse_mamba_config(
+            cls, config: "PretrainedConfig"
+    ) -> HybridAttentionMambaModelConfig.MambaConfig:
+        return HybridAttentionMambaModelConfig.MambaConfig(
+            expand=config.expand,
+            n_groups=config.n_groups,
+            n_heads=config.mamba_num_heads,
+            d_head=config.mamba_head_dim,
+            d_state=config.ssm_state_size,
+            d_conv=config.conv_kernel,
+        )
+
+
+class Zamba2ModelConfig(HybridAttentionMambaModelConfig):
+
+    @classmethod
+    def parse_mamba_config(
+            cls, config: "PretrainedConfig"
+    ) -> HybridAttentionMambaModelConfig.MambaConfig:
+        return HybridAttentionMambaModelConfig.MambaConfig(
+            expand=config.mamba_expand,
+            n_groups=config.mamba_ngroups,
+            n_heads=config.n_mamba_heads,
+            d_head=config.mamba_headdim,
+            d_state=config.mamba_d_state,
+            d_conv=config.mamba_d_conv,
+        )
+
+
 MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "GteModel": SnowflakeGteNewModelConfig,
     "GteNewModel": GteNewModelConfig,
     "NomicBertModel": NomicBertModelConfig,
     "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
     "XLMRobertaModel": JinaRobertaModelConfig,
+    "FalconH1ForCausalLM": HybridAttentionMambaModelConfig,
+    "BambaForCausalLM": HybridAttentionMambaModelConfig,
+    "GraniteMoeHybridForCausalLM": HybridAttentionMambaModelConfig,
+    "NemotronHForCausalLM": NemotronHModelConfig,
+    "Zamba2ForCausalLM": Zamba2ModelConfig,
     "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
 }
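
The tensor-parallel handling in get_mamba_cache_shape is the subtle part: when n_groups is not a multiple of the TP size, extra_groups_for_head_shards replicates the groups so that every head shard carries the group state it needs. A minimal sketch of the resulting per-rank shapes, assuming the same illustrative Bamba-like dimensions as above and tensor parallel size 4:

# Hedged sketch of the per-rank mamba state shapes under tensor parallelism,
# mirroring extra_groups_for_head_shards() / get_mamba_cache_shape() above.
# Dimensions are illustrative (Bamba-like), not read from a real config.
def extra_groups_for_head_shards(ngroups: int, tp_size: int) -> int:
    return 0 if ngroups % tp_size == 0 else tp_size - ngroups

hidden_size, expand = 4096, 2
n_groups, d_state, d_conv = 1, 128, 4
n_heads, d_head = 128, 64
tp_size = 4

# n_groups=1 is replicated to tp_size=4 so each head shard owns a full group
n_groups_padded = n_groups + extra_groups_for_head_shards(n_groups, tp_size)

intermediate_size = expand * hidden_size
conv_dim = intermediate_size + 2 * n_groups_padded * d_state
conv_state_shape = (conv_dim // tp_size, d_conv - 1)           # (2304, 3)
temporal_state_shape = (n_heads // tp_size, d_head, d_state)   # (32, 64, 128)
print(conv_state_shape, temporal_state_shape)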

vllm/v1/worker/gpu_model_runner.py

Lines changed: 3 additions & 55 deletions

@@ -42,7 +42,7 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
+                        GiB_bytes, LazyLoader, async_tensor_h2d,
                         check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up)
 from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
@@ -2636,9 +2636,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                     "Prefix caching is not supported for Mamba yet.")
            max_model_len = self.vllm_config.model_config.max_model_len
 
-            page_size_padded = self._maybe_pad_mamba_page_size(
-                attn_layers, mamba_layers, kv_cache_spec, max_model_len,
-                block_size)
+            page_size_padded = (
+                self.vllm_config.cache_config.mamba_page_size_padded)
 
            # Set block_size to max_model_len, so that mamba model will always
            # have only one block in the KV cache.
@@ -2650,54 +2649,3 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                     page_size_padded=page_size_padded)
 
        return kv_cache_spec
-
-    def _maybe_pad_mamba_page_size(
-        self,
-        attn_layers: dict[str, Attention],
-        mamba_layers: dict[str, MambaBase],
-        kv_cache_spec: dict[str, KVCacheSpec],
-        max_model_len: int,
-        block_size: int,
-    ) -> Optional[int]:
-        """
-        Ensure that page size of attention KV cache groups is greater than or
-        equal to the mamba KV cache groups. If not, we suggest to the user
-        how to set the attention block size to ensure that it is.
-
-        If the attention page size is strictly greater than the mamba page size,
-        we pad the mamba page size to make them equal.
-
-        Args:
-            attn_layers: Attention layers
-            mamba_layers: Mamba layers
-            kv_cache_spec: KV cache spec (populated with attention layers)
-
-        Returns:
-            Optional[int]: Mamba page size with padding (None if no padding).
-        """
-
-        if len(attn_layers) == 0:
-            return None
-
-        attn_layer_name = next(iter(attn_layers))
-        attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes
-        mamba_layer_name = next(iter(mamba_layers))
-        mamba_page_size = MambaSpec(
-            shapes=mamba_layers[mamba_layer_name].get_state_shape(),
-            dtype=self.kv_cache_dtype,
-            block_size=max_model_len).page_size_bytes
-        if attn_page_size < mamba_page_size:
-            # attention page size (for 16 tokens)
-            attn_page_size_16 = 16 * attn_page_size // block_size
-            # some attention backends (e.g. FA) only support setting
-            # block size to multiple of 16, so let's suggest a value
-            # that would work (note: FA is currently not compatible
-            # with mamba layers, use FlashInfer instead).
-            suggest_attn_block_size = 16 * cdiv(mamba_page_size,
-                                                attn_page_size_16)
-            raise ValueError(
-                "Attention block size should be increased to at least "
-                f"{suggest_attn_block_size} in order to match "
-                "the mamba page size")
-
-        return attn_page_size
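
The runner-side check that previously raised a ValueError asking the user to raise the attention block size is gone; for hybrid mamba/attention models both the block size and the mamba padding are now settled during config verification, before the runner builds its KV cache specs. A hedged end-to-end usage sketch (the model name is taken from the test list above; assumes a V1-capable GPU environment):

# Hedged usage sketch: no explicit block_size is needed for hybrid models,
# since verify_and_update_config() picks a compatible attention block size and
# records the padded mamba page size in CacheConfig.
import os
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM

llm = LLM(model="ibm-ai-platform/Bamba-9B-v1",
          enforce_eager=True,
          enable_prefix_caching=False)   # no block_size argument required
print(llm.generate("The capital of France is")[0].outputs[0].text)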
