
Commit 3534c39

[V1] [Hybrid] Refactor mamba state shape calculation; enable V1 via cli (#20840)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent c586b55 commit 3534c39

14 files changed: +441 -353 lines changed

docs/usage/v1_guide.md

Lines changed: 1 addition & 2 deletions
@@ -112,8 +112,7 @@ enforcing eager mode and disabling prefix caching in V1.
 Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
 `Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that
 these models currently require enforcing eager mode, disabling prefix caching, and using the FlashInfer attention
-backend in V1. It is also necessary to pass a non-standard block size for attention layers (this is not possible
-using the `vllm serve` CLI yet).
+backend in V1.
 
 #### Encoder-Decoder Models

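With the block-size requirement gone, the remaining constraints from the guide can be expressed directly when constructing an engine. A minimal sketch, assuming the `ibm-ai-platform/Bamba-9B-v1` checkpoint that also appears in the test file below; the environment variables and constructor arguments are standard vLLM settings rather than additions from this commit:

```python
# Minimal sketch (not part of this diff): running a hybrid Mamba/attention model
# on V1 under the constraints the guide describes. Model name and prompt are
# illustrative.
import os

os.environ["VLLM_USE_V1"] = "1"                      # opt in to the V1 engine
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # backend required for the attention layers

from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-ai-platform/Bamba-9B-v1",
    enforce_eager=True,            # eager mode is still required
    enable_prefix_caching=False,   # prefix caching is still disabled
    # note: no non-standard attention block_size override is needed anymore
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=16))[0].outputs[0].text)
```
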
tests/models/language/generation/test_hybrid.py

Lines changed: 1 addition & 15 deletions
@@ -61,14 +61,6 @@
     "tiiuae/Falcon-H1-0.5B-Base",
 ]
 
-ATTN_BLOCK_SIZES = {
-    "ibm-ai-platform/Bamba-9B-v1": 528,
-    "Zyphra/Zamba2-1.2B-instruct": 80,
-    "nvidia/Nemotron-H-8B-Base-8K": 528,
-    "ibm-granite/granite-4.0-tiny-preview": 400,
-    "tiiuae/Falcon-H1-0.5B-Base": 800,
-}
-
 # Avoid OOM
 MAX_NUM_SEQS = 4
 
@@ -105,11 +97,6 @@ def test_models(
             example_prompts, max_tokens, num_logprobs)
 
     if model in V1_SUPPORTED_MODELS:
-        if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES:
-            block_size = ATTN_BLOCK_SIZES[model]
-        else:
-            block_size = 16
-
         with monkeypatch.context() as m:
             m.setenv("VLLM_USE_V1", "1")
             if model in HYBRID_MODELS:
@@ -118,8 +105,7 @@ def test_models(
             with vllm_runner(model,
                              max_num_seqs=MAX_NUM_SEQS,
                              enforce_eager=True,
-                             enable_prefix_caching=False,
-                             block_size=block_size) as vllm_model:
+                             enable_prefix_caching=False) as vllm_model:
                 vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                     example_prompts, max_tokens, num_logprobs)
     else:

vllm/config.py

Lines changed: 8 additions & 1 deletion
@@ -1630,6 +1630,9 @@ class CacheConfig:
     checkpoint if available. Otherwise, the scales will default to 1.0."""
     cpu_kvcache_space_bytes: Optional[int] = None
     """(CPU backend only) CPU key-value cache space."""
+    mamba_page_size_padded: Optional[int] = None
+    """ Optional override for mamba page size; used by hybrid mamba/attention
+    models to ensure exact alignment with attention page size."""
 
     # Will be set after profiling.
     num_gpu_blocks: Optional[int] = field(default=None, init=False)
@@ -4882,11 +4885,15 @@ def try_verify_and_update_config(self):
         if architecture is None:
             return
 
-        from vllm.model_executor.models.config import MODELS_CONFIG_MAP
+        from vllm.model_executor.models.config import (
+            MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig)
         cls = MODELS_CONFIG_MAP.get(architecture, None)
         if cls is not None:
             cls.verify_and_update_config(self)
 
+        if self.model_config.is_hybrid:
+            HybridAttentionMambaModelConfig.verify_and_update_config(self)
+
         if self.model_config.task == "classify":
            # Maybe convert ForCausalLM into ForSequenceClassification model.
            from vllm.model_executor.models.adapters import (

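The new `mamba_page_size_padded` field and the hybrid-model hook give `HybridAttentionMambaModelConfig.verify_and_update_config` (defined in one of the commit's files not reproduced above) a place to record a mamba page size that exactly matches the attention page size. A rough sketch of the alignment idea implied by the field's docstring; every name and number below is illustrative, not taken from the commit:

```python
# Illustrative sketch only: choose an attention block size (in tokens) whose page
# is at least as large as one mamba state page, then pad the mamba page up to
# that exact byte count so both cache types can share a single page size.
import math


def align_page_sizes(mamba_page_bytes: int,
                     attn_bytes_per_token: int) -> tuple[int, int]:
    """Return (attention_block_size_tokens, mamba_page_size_padded)."""
    block_size = math.ceil(mamba_page_bytes / attn_bytes_per_token)
    padded = block_size * attn_bytes_per_token
    return block_size, padded


# Example: a mamba state page of 530,000 bytes with 1 KiB of attention KV per
# token -> attention block size of 518 tokens, mamba page padded to 530,432 bytes.
print(align_page_sizes(530_000, 1024))
```
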
vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 10 additions & 38 deletions
@@ -20,6 +20,8 @@
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    extra_groups_for_head_shards, get_mamba_state_shape)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@@ -146,18 +148,6 @@ def forward_cuda(
         return out
 
 
-def extra_groups_for_head_shards(ngroups: int, tp_size: int):
-    """Compute the increase in group numbers to account for
-    replication in order to accompany the head shards."""
-
-    # in the case ngoups % tp_size == 0, this will be zero
-    if ngroups % tp_size == 0:
-        return 0
-
-    # for n_groups == 1, this is exactly tp_size - n_groups
-    return tp_size - ngroups
-
-
 def mamba_v2_sharded_weight_loader(
     shard_spec: list[tuple[int, int, float]],
     tp_size: int,
@@ -707,30 +697,12 @@ def forward_cuda(
         return out
 
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
-        world_size = get_tensor_model_parallel_world_size()
-
-        conv_state_shape, temporal_state_shape = None, None
-
-        # if n_groups is not divisible by world_size, need to extend the shards
-        # to ensure all groups needed by a head is sharded along with it
-        n_groups = (self.n_groups +
-                    extra_groups_for_head_shards(self.n_groups, world_size))
-
-        # - heads and n_groups are TP-ed
-        conv_dim = (self.intermediate_size +
-                    2 * n_groups * self.ssm_state_size)
-        # contiguous along 'dim' axis
-        conv_state_shape = (
-            self.conv_kernel_size - 1,
-            divide(conv_dim, world_size),
-        )
-
-        # These are not TP-ed as they depend on A, dt_bias, D
-        # - they are typically small
-        # e.g., (h_heads, d_head, d_state) = (128, 64, 128)
-        temporal_state_shape = (
-            divide(self.num_heads, world_size),
-            self.head_dim,
-            self.ssm_state_size,
+        return get_mamba_state_shape(
+            intermediate_size=self.intermediate_size,
+            tp_world_size=get_tensor_model_parallel_world_size(),
+            n_groups=self.n_groups,
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            state_size=self.ssm_state_size,
+            conv_kernel=self.conv_kernel_size,
         )
-        return conv_state_shape, temporal_state_shape
vllm/model_executor/layers/mamba/mamba_utils.py

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.distributed import divide
+
+
+def extra_groups_for_head_shards(ngroups: int, tp_size: int):
+    """Compute the increase in group numbers to account for
+    replication in order to accompany the head shards."""
+
+    # in the case ngoups % tp_size == 0, this will be zero
+    if ngroups % tp_size == 0:
+        return 0
+
+    # for n_groups == 1, this is exactly tp_size - n_groups
+    return tp_size - ngroups
+
+
+def get_mamba_state_shape(
+    intermediate_size: int,
+    tp_world_size: int,
+    n_groups: int,
+    num_heads: int,
+    head_dim: int,
+    state_size: int,
+    conv_kernel: int,
+    use_v1: bool = True,
+) -> tuple[tuple[int, int], tuple[int, int, int]]:
+    """ Get the shape of mamba state."""
+
+    # if n_groups is not divisible by world_size, need to extend the shards
+    # to ensure all groups needed by a head is sharded along with it
+    n_groups = (n_groups +
+                extra_groups_for_head_shards(n_groups, tp_world_size))
+
+    # - heads and n_groups are TP-ed
+    conv_dim = (intermediate_size + 2 * n_groups * state_size)
+    # contiguous along 'dim' axis
+    conv_state_shape = (
+        conv_kernel - 1,
+        divide(conv_dim, tp_world_size),
+    )
+
+    if not use_v1:
+        conv_state_shape = (conv_state_shape[1], conv_state_shape[0])
+
+    # These are not TP-ed as they depend on A, dt_bias, D
+    # - they are typically small
+    # e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
+    temporal_state_shape = (
+        divide(num_heads, tp_world_size),
+        head_dim,
+        state_size,
+    )
+
+    return conv_state_shape, temporal_state_shape

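For a concrete sense of what the new helper returns, a small usage sketch follows; the parameter values are illustrative (a Bamba-like configuration with tensor parallelism disabled) and are not taken from this commit:

```python
# Usage sketch for the new helper; the numbers below are illustrative.
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape

conv_shape, ssm_shape = get_mamba_state_shape(
    intermediate_size=8192,  # e.g. mamba_expand * hidden_size = 2 * 4096
    tp_world_size=1,
    n_groups=1,
    num_heads=128,
    head_dim=64,
    state_size=128,
    conv_kernel=4,
)

# conv_shape == (3, 8448): (conv_kernel - 1, intermediate_size + 2 * n_groups * state_size)
# ssm_shape  == (128, 64, 128): (num_heads, head_dim, state_size)
print(conv_shape, ssm_shape)
```
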
vllm/model_executor/models/bamba.py

Lines changed: 42 additions & 39 deletions
@@ -12,7 +12,7 @@
 from vllm import envs
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -23,8 +23,8 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
-from vllm.model_executor.layers.mamba.mamba_mixer2 import (
-    MambaMixer2, extra_groups_for_head_shards)
+from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
+from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -435,6 +435,38 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     }
     embedding_padding_modules = ["lm_head"]
 
+    @classmethod
+    def get_mamba_state_shape_from_config(
+        cls,
+        vllm_config: "VllmConfig",
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
+        """Calculate shapes for Mamba's convolutional and state caches.
+
+        Args:
+            vllm_config: vLLM config
+            use_v1: Get shapes for V1 (or V0)
+
+        Returns:
+            Tuple containing:
+            - conv_state_shape: Shape for convolutional state cache
+            - temporal_state_shape: Shape for state space model cache
+        """
+        parallel_config = vllm_config.parallel_config
+        hf_config = vllm_config.model_config.hf_config
+        intermediate_size = hf_config.mamba_expand * hf_config.hidden_size
+
+        return get_mamba_state_shape(
+            intermediate_size=intermediate_size,
+            tp_world_size=parallel_config.tensor_parallel_size,
+            n_groups=hf_config.mamba_n_groups,
+            num_heads=hf_config.mamba_n_heads,
+            head_dim=hf_config.mamba_d_head,
+            state_size=hf_config.mamba_d_state,
+            conv_kernel=hf_config.mamba_d_conv,
+            use_v1=use_v1,
+        )
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         self.vllm_config = vllm_config
@@ -491,10 +523,13 @@ def forward(self,
                 self.vllm_config.parallel_config,
                 LayerBlockType.mamba
             )
-
-            self.mamba_cache = MambaCacheManager(
-                self.vllm_config, self.lm_head.weight.dtype,
-                num_mamba_layers, *self._get_mamba_cache_shape())
+            mamba_state_shape = \
+                self.get_mamba_state_shape_from_config(
+                    self.vllm_config, use_v1=False)
+            self.mamba_cache = MambaCacheManager(self.vllm_config,
+                                                 self.lm_head.weight.dtype,
+                                                 num_mamba_layers,
+                                                 *mamba_state_shape)
 
             mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
 
@@ -510,38 +545,6 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
 
-    def _get_mamba_cache_shape(
-            self) -> tuple[tuple[int, int], tuple[int, int]]:
-        world_size = get_tensor_model_parallel_world_size()
-        hidden_size = self.config.hidden_size
-
-        conv_state_shape, temporal_state_shape = None, None
-
-        intermediate_size = self.config.mamba_expand * hidden_size
-
-        # if n_groups is not divisible by world_size, need to extend the shards
-        # to ensure all groups needed by a head is sharded along with it
-        n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards(
-            self.config.mamba_n_groups, world_size))
-
-        # - heads and n_groups are TP-ed
-        conv_dim = (intermediate_size +
-                    2 * n_groups * self.config.mamba_d_state)
-        conv_state_shape = (
-            divide(conv_dim, world_size),
-            self.config.mamba_d_conv - 1,
-        )
-
-        # These are not TP-ed as they depend on A, dt_bias, D
-        # - they are typically small
-        # e.g., (h_heads, d_head, d_state) = (128, 64, 128)
-        temporal_state_shape = (
-            divide(self.config.mamba_n_heads, world_size),
-            self.config.mamba_d_head,
-            self.config.mamba_d_state,
-        )
-        return conv_state_shape, temporal_state_shape
-
     def compute_logits(
         self,
         hidden_states: torch.Tensor,

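The `use_v1` flag matters here because the V0 path in `forward()` above passes `use_v1=False` when building the `MambaCacheManager`, and the helper then swaps the convolutional-state axes (matching the layout the removed `_get_mamba_cache_shape` used to return). A short sketch of the difference, again with illustrative parameter values:

```python
# Sketch: the same helper serves both engines; only the conv state layout differs.
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape

params = dict(intermediate_size=8192, tp_world_size=1, n_groups=1,
              num_heads=128, head_dim=64, state_size=128, conv_kernel=4)

v1_conv, v1_ssm = get_mamba_state_shape(use_v1=True, **params)
v0_conv, v0_ssm = get_mamba_state_shape(use_v1=False, **params)

assert v1_conv == (3, 8448) and v0_conv == (8448, 3)  # conv axes swapped for V0
assert v1_ssm == v0_ssm == (128, 64, 128)             # SSM state shape unchanged
```
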