diff --git a/vllm/config.py b/vllm/config.py
index 1a3ff9d42ff..c0f9df99cce 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1331,6 +1331,17 @@ def get_num_layers_by_block_type(
 
         return sum(t == 1 for t in attn_type_list[start:end])
 
+    def get_mamba_chunk_size(self) -> Optional[int]:
+        """
+        Returns the mamba chunk size if it exists
+        """
+        # used by e.g. Bamba, FalconH1, Granite, PLaMo2
+        chunk_size = getattr(self.hf_text_config, "mamba_chunk_size", None)
+        if chunk_size is None:
+            # used by e.g. Mamba2, NemotronH, Zamba
+            chunk_size = getattr(self.hf_text_config, "chunk_size", None)
+        return chunk_size
+
     def get_multimodal_config(self) -> "MultiModalConfig":
         """
         Get the multimodal configuration of the model.
diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py
new file mode 100644
index 00000000000..4c4997b4894
--- /dev/null
+++ b/vllm/model_executor/layers/mamba/abstract.py
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+
+import torch
+
+
+class MambaBase(ABC):
+    """
+    Base class for Mamba-like layers which support the v1 engine.
+    Inherit from this class if you implement a custom layer.
+    """
+
+    # Contains the KV cache (mamba state) for the layer
+    # in the shape specified by `self.get_state_shape`.
+    # The outer list is for v0 PP virtual engine. Though this code path
+    # only runs for v1, we have to do this to unify with the interface
+    # of Attention + v0 PP.
+    kv_cache: list[Iterable[torch.Tensor]]
+
+    @abstractmethod
+    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
+        """
+        Defines the shape of the state.
+        For mamba layers this is usually a (conv_state, ssm_state) tuple.
+        In this case, returns (conv_state_shape, ssm_state_shape).
+        """
+        pass
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 2cc30e4d3f7..4ca8e6b97fc 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -17,6 +17,7 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -219,7 +220,7 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
 
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 @CustomOp.register("mamba_mixer2")
-class MambaMixer2(CustomOp):
+class MambaMixer2(MambaBase, CustomOp):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute
     the `contextualized_states`. A, D are input independent
@@ -231,22 +232,21 @@ class MambaMixer2(CustomOp):
     """
 
     def __init__(
-            self,
-            hidden_size: int,
-            ssm_state_size: int,
-            conv_kernel_size: int,
-            intermediate_size: int,
-            use_conv_bias: bool,
-            use_bias: bool,
-            n_groups: int = 1,
-            num_heads: int = 128,
-            head_dim: int = 64,
-            rms_norm_eps: float = 1e-5,
-            activation: str = "silu",
-            use_rms_norm: bool = True,
-            quant_config: Optional[QuantizationConfig] = None,
-            prefix: str = "",
-            chunk_size: int = -1,  # the chunk size used by v1
+        self,
+        hidden_size: int,
+        ssm_state_size: int,
+        conv_kernel_size: int,
+        intermediate_size: int,
+        use_conv_bias: bool,
+        use_bias: bool,
+        n_groups: int = 1,
+        num_heads: int = 128,
+        head_dim: int = 64,
+        rms_norm_eps: float = 1e-5,
+        activation: str = "silu",
+        use_rms_norm: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
@@ -428,10 +428,7 @@ def __init__(
         # of Attention + v0 PP.
         # The inner tuple is (conv_state, ssm_state)
         self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
-        assert chunk_size != -1, "chunk_size must be set for v1"
 
-        # NOTE: chunk_size may be -1 for models without v1 support
-        self.chunk_size = chunk_size
         self.prefix = prefix
 
     def forward_native(
diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py
index d743c52074c..dfc55b0c341 100644
--- a/vllm/model_executor/models/bamba.py
+++ b/vllm/model_executor/models/bamba.py
@@ -99,8 +99,7 @@ def __init__(self,
                                  rms_norm_eps=config.rms_norm_eps,
                                  activation=config.hidden_act,
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.mixer",
-                                 chunk_size=config.mamba_chunk_size)
+                                 prefix=f"{prefix}.mixer")
 
         self.feed_forward = BambaMLP(config, quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size,
diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py
index a76e1f256e0..ad3f39793b6 100644
--- a/vllm/model_executor/models/falcon_h1.py
+++ b/vllm/model_executor/models/falcon_h1.py
@@ -109,7 +109,6 @@ def __init__(
             quant_config=quant_config,
             use_rms_norm=config.mamba_rms_norm,
             prefix=f"{prefix}.mixer",
-            chunk_size=config.mamba_chunk_size,
         )
         # n_groups is overridden later by `MambaMixer2`
         self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py
index 676ef24fc4d..1055fa0372b 100644
--- a/vllm/model_executor/models/granitemoehybrid.py
+++ b/vllm/model_executor/models/granitemoehybrid.py
@@ -69,8 +69,7 @@ def __init__(self,
                                  rms_norm_eps=config.rms_norm_eps,
                                  activation=config.hidden_act,
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.mixer",
-                                 chunk_size=config.mamba_chunk_size)
+                                 prefix=f"{prefix}.mixer")
 
         self.block_sparse_moe = None
         if getattr(config, "num_local_experts", 0) > 0:
diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py
index d2403ccbb97..b9fa5707393 100644
--- a/vllm/model_executor/models/mamba2.py
+++ b/vllm/model_executor/models/mamba2.py
@@ -62,8 +62,7 @@ def __init__(self,
                                  rms_norm_eps=config.layer_norm_epsilon,
                                  activation=config.hidden_act,
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.mixer",
-                                 chunk_size=config.chunk_size)
+                                 prefix=f"{prefix}.mixer")
 
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py
index 5d51b01df9d..60fb7254725 100644
--- a/vllm/model_executor/models/nemotron_h.py
+++ b/vllm/model_executor/models/nemotron_h.py
@@ -154,7 +154,6 @@ def __init__(
             activation=config.mamba_hidden_act,
             quant_config=quant_config,
             prefix=f"{prefix}.mixer",
-            chunk_size=config.chunk_size,
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py
index 54c80cfa592..4935fd9e6df 100644
--- a/vllm/model_executor/models/zamba2.py
+++ b/vllm/model_executor/models/zamba2.py
@@ -501,8 +501,7 @@ def __init__(self,
                                  rms_norm_eps=config.rms_norm_eps,
                                  activation="silu",
                                  quant_config=quant_config,
-                                 prefix=f"{prefix}.mixer",
-                                 chunk_size=config.chunk_size)
+                                 prefix=f"{prefix}.mixer")
 
         # Input normalization
         self.input_layernorm = RMSNorm(config.hidden_size,
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 9dea08b6583..7b4ecd7c359 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -7,7 +7,6 @@
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import MambaSpec
@@ -19,15 +18,6 @@
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 
-def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
-    from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
-    layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
-    chunk_sizes = set(layer.chunk_size for layer in layers.values())
-    assert len(
-        chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
-    return chunk_sizes.pop()
-
-
 def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
                                               chunk_size: int,
                                               total_seqlens: int):
@@ -102,7 +92,10 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
         self.runner = runner
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
-        self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)
+        self.chunk_size = runner.vllm_config.model_config.get_mamba_chunk_size(
+        )
+        assert self.chunk_size is not None, (
+            "chunk_size needs to be set in the model config for Mamba2 models")
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e264285859e..f3279fa5fa8 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -30,7 +30,7 @@
 from vllm.forward_context import (DPMetadata, get_forward_context,
                                   set_forward_context)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
+from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.model_executor.models.interfaces import (has_step_pooler,
@@ -2623,8 +2623,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                 raise ValueError(
                     f"Unknown attention type: {attn_module.attn_type}")
 
-        mamba_layers = get_layers_from_vllm_config(self.vllm_config,
-                                                   MambaMixer2)
+        mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
         if len(mamba_layers) > 0:
             if self.vllm_config.speculative_config is not None:
                 raise NotImplementedError(
@@ -2655,7 +2654,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
     def _maybe_pad_mamba_page_size(
         self,
         attn_layers: dict[str, Attention],
-        mamba_layers: dict[str, MambaMixer2],
+        mamba_layers: dict[str, MambaBase],
        kv_cache_spec: dict[str, KVCacheSpec],
         max_model_len: int,
         block_size: int,
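
Usage note (not part of the patch): the sketch below shows how a custom Mamba-style layer could plug into the new MambaBase interface. Only MambaBase, its get_state_shape() method, the kv_cache field, and the import path come from the hunks above; the layer name and its constructor arguments are hypothetical.

# Illustrative sketch only -- MambaBase, get_state_shape() and kv_cache are
# taken from the abstract.py / mamba_mixer2.py hunks above; the class name and
# constructor below are hypothetical.
from collections.abc import Iterable

import torch

from vllm.model_executor.layers.mamba.abstract import MambaBase


class MyCustomMambaLayer(MambaBase):
    """Toy Mamba-like layer that the v1 engine can discover via MambaBase."""

    def __init__(self, conv_state_shape: tuple[int, ...],
                 ssm_state_shape: tuple[int, ...]):
        self._conv_state_shape = conv_state_shape
        self._ssm_state_shape = ssm_state_shape
        # Same layout as MambaMixer2: a single (conv_state, ssm_state) tuple,
        # wrapped in a list to match the Attention + v0 PP interface.
        self.kv_cache = [(torch.tensor([]), torch.tensor([]))]

    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        # The shapes the KV cache manager should allocate for this layer.
        return self._conv_state_shape, self._ssm_state_shape

With layers registered this way, get_kv_cache_spec in gpu_model_runner.py collects them via get_layers_from_vllm_config(self.vllm_config, MambaBase) instead of matching MambaMixer2 specifically, and the chunk size is resolved once from the model config through get_mamba_chunk_size() (falling back from mamba_chunk_size to chunk_size) rather than being passed to every mixer.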