From 35a8f05028e89b579067c5596fa80451b266199b Mon Sep 17 00:00:00 2001
From: nopperl <54780682+nopperl@users.noreply.github.com>
Date: Wed, 9 Jul 2025 01:56:21 +0000
Subject: [PATCH 1/3] introduce Mamba2Layer abstraction to allow mamba layers other than MambaMixer2 in v1 engine

Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
---
 .../layers/mamba/mamba_mixer2.py          | 29 ++++++++++++++++++-
 vllm/v1/attention/backends/mamba_attn.py  |  4 +--
 vllm/v1/worker/gpu_model_runner.py        |  6 ++--
 3 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 9dcbcb2e6f2..eaeeeb577d2 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from typing import Optional, Union

 import torch
@@ -216,9 +218,34 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
     return loader


+class Mamba2Layer(ABC):
+    """
+    Base class for all Mamba2 layers which support the v1 engine.
+    Inherit from this class if you implement a custom Mamba2 layer.
+    """
+
+    chunk_size: int
+
+    # Contains the KV cache (mamba state) for the layer
+    # in the shape specified by `self.get_state_shape`.
+    # The outer list is for v0 PP virtual engine. Though this code path
+    # only runs for v1, we have to do this to unify with the interface
+    # of Attention + v0 PP.
+    kv_cache: list[Iterable[torch.Tensor]]
+
+    @abstractmethod
+    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
+        """
+        Defines the shape of the mamba state.
+        Usually, the mamba state is a (conv_state, ssm_state) tuple.
+        In this case, returns (conv_state_shape, ssm_state_shape).
+        """
+        pass
+
+
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 @CustomOp.register("mamba_mixer2")
-class MambaMixer2(CustomOp):
+class MambaMixer2(Mamba2Layer, CustomOp):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute
     the `contextualized_states`. A, D are input independent
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 74d619aadbd..27b99a5457c 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -21,8 +21,8 @@ def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
-    from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
-    layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
+    from vllm.model_executor.layers.mamba.mamba_mixer2 import Mamba2Layer
+    layers = get_layers_from_vllm_config(vllm_config, Mamba2Layer)
     chunk_sizes = set(layer.chunk_size for layer in layers.values())
     assert len(
         chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 8658d7d916f..2a6224d0cb9 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -31,7 +31,7 @@
 from vllm.forward_context import (DPMetadata, get_forward_context,
                                   set_forward_context)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
+from vllm.model_executor.layers.mamba.mamba_mixer2 import Mamba2Layer
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.model_executor.models.interfaces import (has_step_pooler,
@@ -2660,7 +2660,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                     f"Unknown attention type: {attn_module.attn_type}")

         mamba_layers = get_layers_from_vllm_config(self.vllm_config,
-                                                   MambaMixer2)
+                                                   Mamba2Layer)
         if len(mamba_layers) > 0:
             if self.vllm_config.speculative_config is not None:
                 raise NotImplementedError(
@@ -2691,7 +2691,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
     def _maybe_pad_mamba_page_size(
         self,
         attn_layers: dict[str, Attention],
-        mamba_layers: dict[str, MambaMixer2],
+        mamba_layers: dict[str, Mamba2Layer],
         kv_cache_spec: dict[str, KVCacheSpec],
         max_model_len: int,
         block_size: int,

From e3af68af96b26d5cf0d7f2da08759b19d421c8f9 Mon Sep 17 00:00:00 2001
From: nopperl <54780682+nopperl@users.noreply.github.com>
Date: Wed, 9 Jul 2025 02:10:35 +0000
Subject: [PATCH 2/3] retrieve chunk size from the model config instead of the mamba layers

Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
---
 vllm/config.py                                 | 11 ++++++
 .../layers/mamba/mamba_mixer2.py               | 36 ++++++++-----------
 vllm/model_executor/models/bamba.py            |  3 +-
 vllm/model_executor/models/falcon_h1.py        |  1 -
 .../model_executor/models/granitemoehybrid.py  |  3 +-
 vllm/model_executor/models/mamba2.py           |  3 +-
 vllm/model_executor/models/nemotron_h.py       |  1 -
 vllm/model_executor/models/zamba2.py           |  3 +-
 vllm/v1/attention/backends/mamba_attn.py       | 15 +++-----
 9 files changed, 34 insertions(+), 42 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 90cf885a40d..5b9723a01c1 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1330,6 +1330,17 @@ def get_num_layers_by_block_type(

         return sum(t == 1 for t in attn_type_list[start:end])

+    def get_mamba_chunk_size(self) -> Optional[int]:
+        """
+        Returns the mamba chunk size if it exists
+        """
+        # used by e.g. Bamba, FalconH1, Granite, PLaMo2
+        chunk_size = getattr(self.hf_text_config, "mamba_chunk_size", None)
+        if chunk_size is None:
+            # used by e.g. Mamba2, NemotronH, Zamba
+            chunk_size = getattr(self.hf_text_config, "chunk_size", None)
+        return chunk_size
+
     def get_multimodal_config(self) -> "MultiModalConfig":
         """
         Get the multimodal configuration of the model.
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index eaeeeb577d2..67cb030fab5 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -224,8 +224,6 @@ class Mamba2Layer(ABC):
     Inherit from this class if you implement a custom Mamba2 layer.
     """

-    chunk_size: int
-
     # Contains the KV cache (mamba state) for the layer
     # in the shape specified by `self.get_state_shape`.
     # The outer list is for v0 PP virtual engine. Though this code path
@@ -257,22 +255,21 @@ class MambaMixer2(Mamba2Layer, CustomOp):
     """

     def __init__(
-            self,
-            hidden_size: int,
-            ssm_state_size: int,
-            conv_kernel_size: int,
-            intermediate_size: int,
-            use_conv_bias: bool,
-            use_bias: bool,
-            n_groups: int = 1,
-            num_heads: int = 128,
-            head_dim: int = 64,
-            rms_norm_eps: float = 1e-5,
-            activation: str = "silu",
-            use_rms_norm: bool = True,
-            quant_config: Optional[QuantizationConfig] = None,
-            prefix: str = "",
-            chunk_size: int = -1,  # the chunk size used by v1
+        self,
+        hidden_size: int,
+        ssm_state_size: int,
+        conv_kernel_size: int,
+        intermediate_size: int,
+        use_conv_bias: bool,
+        use_bias: bool,
+        n_groups: int = 1,
+        num_heads: int = 128,
+        head_dim: int = 64,
+        rms_norm_eps: float = 1e-5,
+        activation: str = "silu",
+        use_rms_norm: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
@@ -454,10 +451,7 @@ def __init__(
         # of Attention + v0 PP.
         # The inner tuple is (conv_state, ssm_state)
         self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
-        assert chunk_size != -1, "chunk_size must be set for v1"

-        # NOTE: chunk_size may be -1 for models without v1 support
-        self.chunk_size = chunk_size
         self.prefix = prefix

     def forward_native(
diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py
index d743c52074c..dfc55b0c341 100644
--- a/vllm/model_executor/models/bamba.py
+++ b/vllm/model_executor/models/bamba.py
@@ -99,8 +99,7 @@ def __init__(self,
                                  rms_norm_eps=config.rms_norm_eps,
                                  activation=config.hidden_act,
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.mixer",
-                                 chunk_size=config.mamba_chunk_size)
+                                 prefix=f"{prefix}.mixer")

         self.feed_forward = BambaMLP(config, quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size,
diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py
index a76e1f256e0..ad3f39793b6 100644
--- a/vllm/model_executor/models/falcon_h1.py
+++ b/vllm/model_executor/models/falcon_h1.py
@@ -109,7 +109,6 @@ def __init__(
             quant_config=quant_config,
             use_rms_norm=config.mamba_rms_norm,
             prefix=f"{prefix}.mixer",
-            chunk_size=config.mamba_chunk_size,
         )
         # n_groups is overridden later by `MambaMixer2`
         self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py
index 676ef24fc4d..1055fa0372b 100644
--- a/vllm/model_executor/models/granitemoehybrid.py
+++ b/vllm/model_executor/models/granitemoehybrid.py
@@ -69,8 +69,7 @@ def __init__(self,
                                  rms_norm_eps=config.rms_norm_eps,
                                  activation=config.hidden_act,
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.mixer",
-                                 chunk_size=config.mamba_chunk_size)
+                                 prefix=f"{prefix}.mixer")

         self.block_sparse_moe = None
         if getattr(config, "num_local_experts", 0) > 0:
diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py
index d2403ccbb97..b9fa5707393 100644
--- a/vllm/model_executor/models/mamba2.py
+++ b/vllm/model_executor/models/mamba2.py
@@ -62,8 +62,7 @@ def __init__(self,
                                 rms_norm_eps=config.layer_norm_epsilon,
                                 activation=config.hidden_act,
                                 quant_config=quant_config,
-                                prefix=f"{prefix}.mixer",
-                                chunk_size=config.chunk_size)
+                                prefix=f"{prefix}.mixer")

         self.norm = RMSNorm(config.hidden_size,
                             eps=config.layer_norm_epsilon)
diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py
index 5d51b01df9d..60fb7254725 100644
--- a/vllm/model_executor/models/nemotron_h.py
+++ b/vllm/model_executor/models/nemotron_h.py
@@ -154,7 +154,6 @@ def __init__(
             activation=config.mamba_hidden_act,
             quant_config=quant_config,
             prefix=f"{prefix}.mixer",
-            chunk_size=config.chunk_size,
         )

         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py
index 54c80cfa592..4935fd9e6df 100644
--- a/vllm/model_executor/models/zamba2.py
+++ b/vllm/model_executor/models/zamba2.py
@@ -501,8 +501,7 @@ def __init__(self,
                                  rms_norm_eps=config.rms_norm_eps,
                                  activation="silu",
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.mixer",
-                                 chunk_size=config.chunk_size)
+                                 prefix=f"{prefix}.mixer")

         # Input normalization
         self.input_layernorm = RMSNorm(config.hidden_size,
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 27b99a5457c..482901c3f10 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -6,7 +6,6 @@
 import torch

 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     _query_start_loc_to_chunk_indices_offsets)
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -20,15 +19,6 @@
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner


-def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
-    from vllm.model_executor.layers.mamba.mamba_mixer2 import Mamba2Layer
-    layers = get_layers_from_vllm_config(vllm_config, Mamba2Layer)
-    chunk_sizes = set(layer.chunk_size for layer in layers.values())
-    assert len(
-        chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
-    return chunk_sizes.pop()
-
-
 class Mamba2AttentionBackend(AttentionBackend):

     @staticmethod
@@ -63,7 +53,10 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
         self.runner = runner
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
-        self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)
+        self.chunk_size = runner.vllm_config.model_config.get_mamba_chunk_size(
+        )
+        assert self.chunk_size is not None, (
+            "chunk_size needs to be set in the model config for Mamba2 models")

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:

From 9f7606e3908ec7c54ee6250825b3b7d3e16846d1 Mon Sep 17 00:00:00 2001
From: nopperl <54780682+nopperl@users.noreply.github.com>
Date: Fri, 11 Jul 2025 09:36:03 +0900
Subject: [PATCH 3/3] rename Mamba2Layer to MambaBase

Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
---
 vllm/model_executor/layers/mamba/abstract.py  | 29 +++++++++++++++++++
 .../layers/mamba/mamba_mixer2.py              | 28 ++----------------
 vllm/v1/worker/gpu_model_runner.py            |  7 ++---
 3 files changed, 34 insertions(+), 30 deletions(-)
 create mode 100644 vllm/model_executor/layers/mamba/abstract.py

diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py
new file mode 100644
index 00000000000..4c4997b4894
--- /dev/null
+++ b/vllm/model_executor/layers/mamba/abstract.py
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+
+import torch
+
+
+class MambaBase(ABC):
+    """
+    Base class for Mamba-like layers which support the v1 engine.
+    Inherit from this class if you implement a custom layer.
+    """
+
+    # Contains the KV cache (mamba state) for the layer
+    # in the shape specified by `self.get_state_shape`.
+    # The outer list is for v0 PP virtual engine. Though this code path
+    # only runs for v1, we have to do this to unify with the interface
+    # of Attention + v0 PP.
+    kv_cache: list[Iterable[torch.Tensor]]
+
+    @abstractmethod
+    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
+        """
+        Defines the shape of the state.
+        For mamba layers this is usually a (conv_state, ssm_state) tuple.
+        In this case, returns (conv_state_shape, ssm_state_shape).
+        """
+        pass
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index b1ffee25654..4ca8e6b97fc 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from abc import ABC, abstractmethod
-from collections.abc import Iterable
 from typing import Optional, Union

 import torch
@@ -19,6 +17,7 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -219,32 +218,9 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
     return loader


-class Mamba2Layer(ABC):
-    """
-    Base class for all Mamba2 layers which support the v1 engine.
-    Inherit from this class if you implement a custom Mamba2 layer.
-    """
-
-    # Contains the KV cache (mamba state) for the layer
-    # in the shape specified by `self.get_state_shape`.
-    # The outer list is for v0 PP virtual engine. Though this code path
-    # only runs for v1, we have to do this to unify with the interface
-    # of Attention + v0 PP.
-    kv_cache: list[Iterable[torch.Tensor]]
-
-    @abstractmethod
-    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
-        """
-        Defines the shape of the mamba state.
-        Usually, the mamba state is a (conv_state, ssm_state) tuple.
-        In this case, returns (conv_state_shape, ssm_state_shape).
-        """
-        pass
-
-
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 @CustomOp.register("mamba_mixer2")
-class MambaMixer2(Mamba2Layer, CustomOp):
+class MambaMixer2(MambaBase, CustomOp):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute
     the `contextualized_states`. A, D are input independent
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 2a2f3931f1f..f3279fa5fa8 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -30,7 +30,7 @@
 from vllm.forward_context import (DPMetadata, get_forward_context,
                                   set_forward_context)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.mamba.mamba_mixer2 import Mamba2Layer
+from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.model_executor.models.interfaces import (has_step_pooler,
@@ -2623,8 +2623,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                 raise ValueError(
                     f"Unknown attention type: {attn_module.attn_type}")

-        mamba_layers = get_layers_from_vllm_config(self.vllm_config,
-                                                   Mamba2Layer)
+        mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
         if len(mamba_layers) > 0:
             if self.vllm_config.speculative_config is not None:
                 raise NotImplementedError(
@@ -2655,7 +2654,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
     def _maybe_pad_mamba_page_size(
         self,
         attn_layers: dict[str, Attention],
-        mamba_layers: dict[str, Mamba2Layer],
+        mamba_layers: dict[str, MambaBase],
         kv_cache_spec: dict[str, KVCacheSpec],
         max_model_len: int,
         block_size: int,
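For illustration, a custom layer would plug into the abstraction introduced by this series roughly as follows. This is a minimal sketch, not part of the patches: only `MambaBase`, its `kv_cache` attribute, `get_state_shape()`, and the `vllm.model_executor.layers.mamba.abstract` import path come from PATCH 3/3; the class name, dimensions, and state layout below are hypothetical.

# Minimal sketch of a custom v1 state-space layer built on MambaBase.
# Only MambaBase, kv_cache and get_state_shape() come from the patch series;
# the class name, constructor arguments and state shapes are illustrative.
from collections.abc import Iterable

import torch
import torch.nn as nn

from vllm.model_executor.layers.mamba.abstract import MambaBase


class ShortConvRecurrentLayer(MambaBase, nn.Module):
    """Hypothetical layer other than MambaMixer2 usable by the v1 engine."""

    def __init__(self, hidden_size: int, state_size: int,
                 conv_kernel_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.state_size = state_size
        self.conv_kernel_size = conv_kernel_size
        # Same convention as MambaMixer2: the outer list mirrors the
        # Attention + v0 PP interface, the inner tuple holds the state.
        self.kv_cache = [(torch.tensor([]), torch.tensor([]))]

    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        # Per-request state shapes; the engine reads these when building
        # the KV cache spec for the layer (cf. get_kv_cache_spec in
        # gpu_model_runner.py, which now collects all MambaBase layers).
        conv_state_shape = (self.hidden_size, self.conv_kernel_size - 1)
        recurrent_state_shape = (self.hidden_size, self.state_size)
        return conv_state_shape, recurrent_state_shape

Note that after PATCH 2/3, such a layer no longer needs to carry a chunk_size attribute: the metadata builder in mamba_attn.py reads it via ModelConfig.get_mamba_chunk_size(), which checks mamba_chunk_size on the HF text config first and falls back to chunk_size.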