
Enable V1 for Hybrid SSM/Attention Models #20016

Merged: 32 commits merged on Jul 4, 2025
Changes from 9 commits
Commits (32)
de4e3a2
working change
tdoublep Jun 23, 2025
617cd26
working changes
tdoublep Jun 24, 2025
9378c54
Merge branch 'main' into tpa-bamba-v1
tdoublep Jun 24, 2025
300d25f
Working version
tdoublep Jun 24, 2025
0822308
Add support + test for Zamba2
tdoublep Jun 24, 2025
a9fc73f
Fix memory layout for KV cache tensors in mamba case
tdoublep Jun 26, 2025
0e5b6de
kv_cache_interface.py: use utils.round_up
tdoublep Jun 26, 2025
89f504a
Fix unrelated CI test failing
tdoublep Jun 26, 2025
0b7783b
Enable bamba-9b in CI
tdoublep Jun 26, 2025
ded4833
Fix unrelated CI issue
tdoublep Jun 26, 2025
31db869
Add support for Nemotron-H
tdoublep Jun 27, 2025
c45e7e5
Enable Granite 4.0 (HybridMoE)
tdoublep Jun 27, 2025
0f20e11
Merge branch 'main' into tpa-bamba-v1
tdoublep Jun 27, 2025
e2c14ba
add support for Falcon H1
tdoublep Jun 27, 2025
c5a25eb
Fix overflow issue in mamba_ssm kernel
tdoublep Jun 30, 2025
cfc38c0
Resolve merge conflicts
tdoublep Jun 30, 2025
aaa6f0e
Add check for transformers min version
tdoublep Jun 30, 2025
d187bfd
Don't fail test if model is not in HF_EXAMPLE_MODELS
tdoublep Jun 30, 2025
fde28dc
Fix test_batching in same way
tdoublep Jun 30, 2025
1777fd1
Update tests/models/language/generation/test_hybrid.py
tdoublep Jul 1, 2025
58e66c9
Update vllm/model_executor/models/granitemoehybrid.py
tdoublep Jul 1, 2025
c2da03e
page_size -> num_element_per_page
tdoublep Jul 1, 2025
e0404c9
Clean up page size padding logic
tdoublep Jul 1, 2025
c74698d
gpu_model_runner.py: add TODO about batch reordering
tdoublep Jul 1, 2025
b72b729
Fix linting issue
tdoublep Jul 1, 2025
105737c
Validate memory layout for hybrid models against attention backends
tdoublep Jul 1, 2025
d8ff3b9
Adjust comment
tdoublep Jul 1, 2025
c857ec3
Merge branch 'main' into tpa-bamba-v1
tdoublep Jul 4, 2025
ea8cf32
Move memory layout check into separate function
tdoublep Jul 4, 2025
b38d3fb
Move logic to pad mamba page size into separate function
tdoublep Jul 4, 2025
e6b0015
Add extra todo
tdoublep Jul 4, 2025
14fd006
test_oracle.py: hybrid models now supported
tdoublep Jul 4, 2025
25 changes: 14 additions & 11 deletions tests/models/language/generation/test_gemma.py
@@ -7,14 +7,17 @@


@pytest.mark.parametrize("model", MODELS)
def test_dummy_loader(vllm_runner, model: str) -> None:
with vllm_runner(
model,
load_format="dummy",
) as llm:
normalizers = llm.collective_rpc(lambda self: self.worker.model_runner.
model.model.normalizer.cpu().item())
assert np.allclose(
normalizers,
llm.llm_engine.model_config.hf_config.hidden_size**0.5,
rtol=1e-3)
def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None:
with monkeypatch.context() as m:
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
with vllm_runner(
model,
load_format="dummy",
) as llm:
normalizers = llm.model.collective_rpc(
lambda self: self.model_runner.model.model.normalizer.cpu(
).item())
assert np.allclose(
normalizers,
llm.model.llm_engine.model_config.hf_config.hidden_size**0.5,
rtol=1e-3)
19 changes: 18 additions & 1 deletion tests/models/language/generation/test_hybrid.py
@@ -34,12 +34,20 @@
"pfnet/plamo-2-1b",
"Zyphra/Zamba2-1.2B-instruct",
"hmellor/tiny-random-BambaForCausalLM",
"ibm-ai-platform/Bamba-9B-v1",
]

V1_SUPPORTED_MODELS = [
"mistralai/Mamba-Codestral-7B-v0.1",
"ibm-ai-platform/Bamba-9B-v1",
"Zyphra/Zamba2-1.2B-instruct",
]

ATTN_BLOCK_SIZES = {
"ibm-ai-platform/Bamba-9B-v1": 528,
"Zyphra/Zamba2-1.2B-instruct": 80,
}

# Avoid OOM
MAX_NUM_SEQS = 4

@@ -68,12 +76,21 @@ def test_models(
example_prompts, max_tokens, num_logprobs)

if model in V1_SUPPORTED_MODELS:
if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES:
block_size = ATTN_BLOCK_SIZES[model]
else:
block_size = 16

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
if model in HYBRID_MODELS:
# required due to reorder_batch behaviour
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enforce_eager=True,
enable_prefix_caching=False) as vllm_model:
enable_prefix_caching=False,
block_size=block_size) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
else:
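Note on the block-size handling above: the per-model overrides (528 for Bamba-9B, 80 for Zamba2-1.2B) appear to be chosen so that the attention page size can accommodate the per-request mamba state page, with 16 as the usual default. A minimal standalone sketch of the selection logic; MODEL_BLOCK_SIZES, DEFAULT_BLOCK_SIZE, and the function names are illustrative, not vLLM APIs.

import os

# Illustrative only: mirrors the selection logic in test_models above.
MODEL_BLOCK_SIZES = {
    "ibm-ai-platform/Bamba-9B-v1": 528,
    "Zyphra/Zamba2-1.2B-instruct": 80,
}
DEFAULT_BLOCK_SIZE = 16

def pick_block_size(model: str) -> int:
    # Per-model attention block size, falling back to the usual default of 16.
    return MODEL_BLOCK_SIZES.get(model, DEFAULT_BLOCK_SIZE)

def v1_test_env() -> dict:
    # Environment the test sets before constructing the runner: enable the V1
    # engine and pin an attention backend whose batch reordering works with
    # the mamba layers.
    return {"VLLM_USE_V1": "1", "VLLM_ATTENTION_BACKEND": "FLASHINFER"}

if __name__ == "__main__":
    os.environ.update(v1_test_env())
    print(pick_block_size("ibm-ai-platform/Bamba-9B-v1"))  # 528
    print(pick_block_size("some-org/other-model"))         # 16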
45 changes: 30 additions & 15 deletions vllm/model_executor/models/bamba.py
@@ -9,6 +9,7 @@
from torch import nn
from transformers import BambaConfig

from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -36,7 +37,7 @@
from vllm.utils import LayerBlockType

from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsQuant, SupportsV0Only)
SupportsQuant)
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -97,7 +98,9 @@ def __init__(self,
head_dim=config.mamba_d_head,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.mamba_chunk_size)

self.feed_forward = BambaMLP(config, quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
@@ -313,10 +316,14 @@ def forward(

attn_metadata = get_forward_context().attn_metadata

mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None

if get_pp_group().is_first_rank:
if inputs_embeds is not None:
@@ -337,7 +344,8 @@
num_attn += 1

layer_mamba_cache_params = None
if isinstance(layer, BambaMixerDecoderLayer):
if isinstance(layer,
BambaMixerDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_attn)

@@ -411,7 +419,7 @@ def load_weights(self, weights: Iterable[tuple[str,


class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid, SupportsV0Only, SupportsQuant):
IsHybrid, SupportsQuant):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -475,15 +483,22 @@ def forward(self,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:

num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = \
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba
)

self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())

mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)

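A recurring pattern in this diff: under V0 the model lazily builds a MambaCacheManager, while under V1 (envs.VLLM_USE_V1) it passes mamba_cache_params=None because the mamba state is managed through the V1 KV-cache path and the mixer layers read their metadata from the forward context. A condensed sketch of that branching; the class and method names here are illustrative, only the shape of the logic mirrors the diff.

from typing import Any, Optional


class HybridModelSketch:
    def __init__(self) -> None:
        self.mamba_cache: Optional[Any] = None

    def _make_cache_manager(self) -> Any:
        # In the real model this builds a MambaCacheManager from the
        # vllm_config, lm_head dtype, and number of mamba layers.
        return object()

    def mamba_cache_params(self, use_v1: bool) -> Optional[Any]:
        # V0 keeps an explicit per-request cache manager; V1 returns None
        # because the mamba state lives in the unified KV cache and the
        # mixer layers pull their metadata from the forward context.
        if use_v1:
            return None
        if self.mamba_cache is None:
            self.mamba_cache = self._make_cache_manager()
        # In the real code: self.mamba_cache.current_run_tensors(**kwargs)
        return self.mamba_cache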
101 changes: 64 additions & 37 deletions vllm/model_executor/models/zamba2.py
@@ -15,6 +15,7 @@
from torch import nn
from transformers import Zamba2Config

from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -41,7 +42,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
from .interfaces import HasInnerState, IsHybrid
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix


@@ -58,6 +59,7 @@ def __init__(
rank: int,
output_dim: Union[int, list[int]],
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
"""Initialize the attention layer.

@@ -283,6 +285,7 @@ def __init__(
bare_block_idx: int,
num_hybrid_layers: dict[int, int],
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
"""Initialize the MLP layer.

@@ -471,11 +474,10 @@ class Zamba2MambaDecoderLayer(nn.Module):
computation depending on configuration.
"""

def __init__(
self,
config: Zamba2Config,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
def __init__(self,
config: Zamba2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
"""Initialize the Mamba decoder layer.

Args:
@@ -486,20 +488,21 @@ def __init__(

# Initialize Mamba mixer with expanded intermediate size
intermediate_size = config.mamba_expand * config.hidden_size
self.mamba = MambaMixer2(
hidden_size=config.hidden_size,
ssm_state_size=config.mamba_d_state,
conv_kernel_size=config.mamba_d_conv,
intermediate_size=intermediate_size,
use_conv_bias=config.use_conv_bias,
use_bias=config.add_bias_linear,
n_groups=config.mamba_ngroups,
num_heads=config.n_mamba_heads,
head_dim=intermediate_size // config.n_mamba_heads,
rms_norm_eps=config.rms_norm_eps,
activation="silu",
quant_config=quant_config,
)
self.mamba = MambaMixer2(hidden_size=config.hidden_size,
ssm_state_size=config.mamba_d_state,
conv_kernel_size=config.mamba_d_conv,
intermediate_size=intermediate_size,
use_conv_bias=config.use_conv_bias,
use_bias=config.add_bias_linear,
n_groups=config.mamba_ngroups,
num_heads=config.n_mamba_heads,
head_dim=intermediate_size //
config.n_mamba_heads,
rms_norm_eps=config.rms_norm_eps,
activation="silu",
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.chunk_size)

# Input normalization
self.input_layernorm = RMSNorm(config.hidden_size,
@@ -573,6 +576,7 @@ def __init__(
config: Zamba2Config,
block_idx: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
"""Initialize the hybrid layer.

@@ -589,7 +593,8 @@
bias=False,
quant_config=quant_config)
self.mamba_decoder = Zamba2MambaDecoderLayer(config,
quant_config=quant_config)
quant_config=quant_config,
prefix=prefix)

def forward(
self,
@@ -699,14 +704,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
# Initialize layers according to block type configuration
layers = []
for layer_idx, layer_type in enumerate(config.layers_block_type):
# tpa: avoid layers getting same index
# somewhat hacky but correct (I think)
prefix = str(len(layer2block_map) + layer_idx)
if layer_type == "hybrid":
block = next(blocks)
block_idx = layer2block_map[layer_idx]
layers.append(
Zamba2HybridLayer(block, config, block_idx, quant_config))
Zamba2HybridLayer(block,
config,
block_idx,
quant_config,
prefix=prefix))
else:
layers.append(
Zamba2MambaDecoderLayer(config, quant_config=quant_config))
Zamba2MambaDecoderLayer(config,
quant_config=quant_config,
prefix=prefix))
self.layers = nn.ModuleList(layers)

# Final layer normalization
@@ -751,19 +765,30 @@ def forward(

attn_metadata = get_forward_context().attn_metadata

mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None

# Process through layers
original_hidden_states = torch.clone(hidden_states)
for layer_idx, layer in enumerate(self.layers):

layer_mamba_cache_params = None
if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer))
and mamba_cache_params):
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
layer_idx)

layer_outputs = layer(
hidden_states,
original_hidden_states=original_hidden_states,
positions=positions,
mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx),
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
hidden_states = layer_outputs
@@ -803,7 +828,7 @@ def load_weights(self, weights: Iterable[tuple[str,
return loaded_params


class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
"""Zamba2 model with causal language modeling head.

This class wraps the core Zamba2 model and adds:
@@ -897,14 +922,16 @@ def forward(self,
Output hidden states
"""
# Initialize Mamba cache if needed
if self.mamba_cache is None:
num_mamba_layers = self.config.num_hidden_layers
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())

# Get cache parameters for current run
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = self.config.num_hidden_layers
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())

# Get cache parameters for current run
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

# Forward pass through model
hidden_states = self.model(
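One Zamba2-specific detail worth noting: because hybrid blocks are shared across layers, each layer now gets a prefix derived from a running index so that every MambaMixer2 registers a distinct KV-cache name under V1. A small sketch of that indexing with made-up config values; only the prefix expression mirrors the diff.

# layers_block_type / layer2block_map stand in for the real config fields.
layers_block_type = ["mamba", "hybrid", "mamba", "hybrid", "mamba"]
layer2block_map = {1: 0, 3: 1}  # hybrid layer index -> shared block index

prefixes = []
for layer_idx, layer_type in enumerate(layers_block_type):
    # Same expression as the diff: offset by the number of hybrid layers in
    # the map so no two layers end up with the same index-derived prefix.
    prefix = str(len(layer2block_map) + layer_idx)
    prefixes.append((layer_type, prefix))

print(prefixes)
# [('mamba', '2'), ('hybrid', '3'), ('mamba', '4'), ('hybrid', '5'), ('mamba', '6')]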
11 changes: 8 additions & 3 deletions vllm/v1/core/kv_cache_coordinator.py
@@ -26,6 +26,7 @@ def __init__(
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
self.enable_caching = enable_caching

self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
enable_kv_cache_events)
@@ -267,9 +268,13 @@ def verify_and_split_kv_cache_groups(self) -> None:

self.full_attention_block_size = self.full_attention_spec.block_size
self.other_block_size = self.other_spec.block_size
assert self.other_block_size % self.full_attention_block_size == 0, (
"KVCacheCoordinator assumes the block_size of full attention "
"layers is divisible by other layers now.")

if self.enable_caching:
# this requirement is only needed for the prefix caching logic
divisible = self.other_block_size % self.full_attention_block_size
assert divisible == 0, (
"KVCacheCoordinator assumes the block_size of full "
"attention layers is divisible by other layers now.")

if max(self.full_attention_group_ids) < min(self.other_group_ids):
self.full_attn_first = True
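The coordinator change above scopes the block-size divisibility assertion to prefix caching only, since that is the only code path relying on it (the hybrid-model tests run with enable_prefix_caching=False). A hedged, standalone sketch of the check; the function name is illustrative, not part of vLLM.

def check_block_size_compat(full_attention_block_size: int,
                            other_block_size: int,
                            enable_caching: bool) -> None:
    # The invariant "full-attention block size divides the other layers'
    # block size" is only needed by the prefix-caching logic, so it is
    # skipped entirely when caching is disabled.
    if not enable_caching:
        return
    if other_block_size % full_attention_block_size != 0:
        raise AssertionError(
            "KVCacheCoordinator assumes the block_size of full attention "
            "layers divides the block_size of the other (mamba) layers.")


check_block_size_compat(16, 528, enable_caching=True)   # ok: 528 = 33 * 16
check_block_size_compat(16, 80, enable_caching=True)    # ok: 80 = 5 * 16
check_block_size_compat(16, 50, enable_caching=False)   # skipped entirely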