[V1] Enable Mamba2 layers other than MambaMixer2 in the v1 engine #20660

Merged: 4 commits, merged Jul 11, 2025
11 changes: 11 additions & 0 deletions vllm/config.py
@@ -1331,6 +1331,17 @@ def get_num_layers_by_block_type(

return sum(t == 1 for t in attn_type_list[start:end])

def get_mamba_chunk_size(self) -> Optional[int]:
"""
Returns the mamba chunk size if it exists
"""
# used by e.g. Bamba, FalconH1, Granite, PLaMo2
chunk_size = getattr(self.hf_text_config, "mamba_chunk_size", None)
if chunk_size is None:
# used by e.g. Mamba2, NemotronH, Zamba
chunk_size = getattr(self.hf_text_config, "chunk_size", None)
return chunk_size

def get_multimodal_config(self) -> "MultiModalConfig":
"""
Get the multimodal configuration of the model.
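The new `ModelConfig.get_mamba_chunk_size` helper simply tries the two attribute names used across HF configs. A minimal, self-contained sketch of that lookup, using `SimpleNamespace` stand-ins for the HF text config (the attribute names come from the model families listed in the comments above; the values are invented):

```python
from types import SimpleNamespace
from typing import Optional


def get_mamba_chunk_size(hf_text_config) -> Optional[int]:
    # Bamba, FalconH1, Granite and PLaMo2 expose `mamba_chunk_size`;
    # Mamba2, NemotronH and Zamba expose `chunk_size`.
    chunk_size = getattr(hf_text_config, "mamba_chunk_size", None)
    if chunk_size is None:
        chunk_size = getattr(hf_text_config, "chunk_size", None)
    return chunk_size


# Stand-in configs for illustration only.
print(get_mamba_chunk_size(SimpleNamespace(mamba_chunk_size=256)))  # 256
print(get_mamba_chunk_size(SimpleNamespace(chunk_size=128)))        # 128
print(get_mamba_chunk_size(SimpleNamespace()))                      # None
```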
29 changes: 29 additions & 0 deletions vllm/model_executor/layers/mamba/abstract.py
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Iterable

import torch


class MambaBase(ABC):
"""
Base class for Mamba-like layers which support the v1 engine.
Inherit from this class if you implement a custom layer.
"""

# Contains the KV cache (mamba state) for the layer
# in the shape specified by `self.get_state_shape`.
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
kv_cache: list[Iterable[torch.Tensor]]

@abstractmethod
def get_state_shape(self) -> Iterable[tuple[int, ...]]:
"""
Defines the shape of the state.
For mamba layers this is usually a (conv_state, ssm_state) tuple.
In this case, returns (conv_state_shape, ssm_state_shape).
"""
pass
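
Any custom Mamba-like layer can now opt into the v1 engine by inheriting from `MambaBase` and implementing `get_state_shape`. A hedged sketch of such a layer follows; the class name, constructor arguments, and state shapes are invented for illustration, and only the import path, the `kv_cache` attribute, and the abstract method come from this PR:

```python
import torch

from vllm.model_executor.layers.mamba.abstract import MambaBase


class MyLinearStateLayer(MambaBase):
    """Hypothetical Mamba-like layer; all sizes are illustrative."""

    def __init__(self, num_heads: int = 8, head_dim: int = 64,
                 state_size: int = 16, conv_kernel_size: int = 4):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.state_size = state_size
        self.conv_kernel_size = conv_kernel_size
        # One (conv_state, ssm_state) entry per v0 PP virtual engine,
        # mirroring the comment on `kv_cache` above.
        self.kv_cache = [(torch.tensor([]), torch.tensor([]))]

    def get_state_shape(self) -> tuple[tuple[int, ...], ...]:
        # (conv_state_shape, ssm_state_shape), as the docstring describes.
        conv_state_shape = (self.num_heads * self.head_dim,
                            self.conv_kernel_size - 1)
        ssm_state_shape = (self.num_heads, self.head_dim, self.state_size)
        return (conv_state_shape, ssm_state_shape)


layer = MyLinearStateLayer()
print(layer.get_state_shape())  # ((512, 3), (8, 64, 16))
```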
37 changes: 17 additions & 20 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -17,6 +17,7 @@
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
update_metadata)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -219,7 +220,7 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:

# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register("mamba_mixer2")
class MambaMixer2(CustomOp):
class MambaMixer2(MambaBase, CustomOp):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
@@ -231,22 +232,21 @@ class MambaMixer2(CustomOp):
"""

def __init__(
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
chunk_size: int = -1, # the chunk size used by v1
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()

@@ -428,10 +428,7 @@ def __init__(
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
assert chunk_size != -1, "chunk_size must be set for v1"

# NOTE: chunk_size may be -1 for models without v1 support
self.chunk_size = chunk_size
self.prefix = prefix

def forward_native(
3 changes: 1 addition & 2 deletions vllm/model_executor/models/bamba.py
@@ -99,8 +99,7 @@ def __init__(self,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.mamba_chunk_size)
prefix=f"{prefix}.mixer")

self.feed_forward = BambaMLP(config, quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
1 change: 0 additions & 1 deletion vllm/model_executor/models/falcon_h1.py
@@ -109,7 +109,6 @@ def __init__(
quant_config=quant_config,
use_rms_norm=config.mamba_rms_norm,
prefix=f"{prefix}.mixer",
chunk_size=config.mamba_chunk_size,
)
# n_groups is overridden later by `MambaMixer2`
self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
3 changes: 1 addition & 2 deletions vllm/model_executor/models/granitemoehybrid.py
@@ -69,8 +69,7 @@ def __init__(self,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.mamba_chunk_size)
prefix=f"{prefix}.mixer")

self.block_sparse_moe = None
if getattr(config, "num_local_experts", 0) > 0:
3 changes: 1 addition & 2 deletions vllm/model_executor/models/mamba2.py
@@ -62,8 +62,7 @@ def __init__(self,
rms_norm_eps=config.layer_norm_epsilon,
activation=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.chunk_size)
prefix=f"{prefix}.mixer")

self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

1 change: 0 additions & 1 deletion vllm/model_executor/models/nemotron_h.py
@@ -154,7 +154,6 @@ def __init__(
activation=config.mamba_hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.chunk_size,
)

self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
3 changes: 1 addition & 2 deletions vllm/model_executor/models/zamba2.py
@@ -501,8 +501,7 @@ def __init__(self,
rms_norm_eps=config.rms_norm_eps,
activation="silu",
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.chunk_size)
prefix=f"{prefix}.mixer")

# Input normalization
self.input_layernorm = RMSNorm(config.hidden_size,
15 changes: 4 additions & 11 deletions vllm/v1/attention/backends/mamba_attn.py
@@ -7,7 +7,6 @@
import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import MambaSpec
@@ -19,15 +18,6 @@
from vllm.v1.worker.gpu_model_runner import GPUModelRunner


def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
chunk_sizes = set(layer.chunk_size for layer in layers.values())
assert len(
chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
return chunk_sizes.pop()


def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
total_seqlens: int):
@@ -102,7 +92,10 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
self.runner = runner
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table
self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)
self.chunk_size = runner.vllm_config.model_config.get_mamba_chunk_size(
)
assert self.chunk_size is not None, (
"chunk_size needs to be set in the model config for Mamba2 models")
Contributor review comment on lines +95 to +98 (severity: medium):

Consider logging a warning message instead of raising an assertion error. This will allow the program to continue running, while still informing the user that there might be an issue with their configuration.

if self.chunk_size is None:
    logger.warning(
        "chunk_size needs to be set in the model config for Mamba2 models")


def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
7 changes: 3 additions & 4 deletions vllm/v1/worker/gpu_model_runner.py
@@ -30,7 +30,7 @@
from vllm.forward_context import (DPMetadata, get_forward_context,
set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (has_step_pooler,
@@ -2623,8 +2623,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")

mamba_layers = get_layers_from_vllm_config(self.vllm_config,
MambaMixer2)
mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
if len(mamba_layers) > 0:
if self.vllm_config.speculative_config is not None:
raise NotImplementedError(
@@ -2655,7 +2654,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
def _maybe_pad_mamba_page_size(
self,
attn_layers: dict[str, Attention],
mamba_layers: dict[str, MambaMixer2],
mamba_layers: dict[str, MambaBase],
kv_cache_spec: dict[str, KVCacheSpec],
max_model_len: int,
block_size: int,
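
Since `get_kv_cache_spec` now looks layers up by `MambaBase` rather than the concrete `MambaMixer2`, any layer implementing the new interface is discovered automatically. A hedged sketch of that discovery pattern, modelled on the `get_mamba2_chunk_size` helper removed above (the helper name here is hypothetical, and a fully built `VllmConfig` is assumed):

```python
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.mamba.abstract import MambaBase


def collect_mamba_state_shapes(vllm_config: VllmConfig):
    """Sketch: gather the state shape of every Mamba-like layer.

    Mirrors how get_kv_cache_spec discovers layers via MambaBase instead
    of the concrete MambaMixer2 class, so future MambaBase subclasses
    (e.g. other linear-attention variants) are picked up automatically.
    """
    layers = get_layers_from_vllm_config(vllm_config, MambaBase)
    return {name: layer.get_state_shape() for name, layer in layers.items()}
```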