
Commit 5d09152

[V1] Enable Mamba2 layers other than MambaMixer2 in the v1 engine (#20660)
Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
1 parent 31d5c17 commit 5d09152

File tree

11 files changed: +68 -45 lines changed

vllm/config.py

Lines changed: 11 additions & 0 deletions

@@ -1331,6 +1331,17 @@ def get_num_layers_by_block_type(
 
         return sum(t == 1 for t in attn_type_list[start:end])
 
+    def get_mamba_chunk_size(self) -> Optional[int]:
+        """
+        Returns the mamba chunk size if it exists
+        """
+        # used by e.g. Bamba, FalconH1, Granite, PLaMo2
+        chunk_size = getattr(self.hf_text_config, "mamba_chunk_size", None)
+        if chunk_size is None:
+            # used by e.g. Mamba2, NemotronH, Zamba
+            chunk_size = getattr(self.hf_text_config, "chunk_size", None)
+        return chunk_size
+
     def get_multimodal_config(self) -> "MultiModalConfig":
         """
         Get the multimodal configuration of the model.
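For context, a minimal standalone sketch of the same fallback logic the new helper applies; the SimpleNamespace objects below are illustrative stand-ins for hf_text_config and are not part of the commit:

from types import SimpleNamespace
from typing import Optional


def get_mamba_chunk_size(hf_text_config) -> Optional[int]:
    # Mirrors the fallback order added above: Bamba/FalconH1/Granite/PLaMo2
    # expose `mamba_chunk_size`; Mamba2/NemotronH/Zamba expose `chunk_size`.
    chunk_size = getattr(hf_text_config, "mamba_chunk_size", None)
    if chunk_size is None:
        chunk_size = getattr(hf_text_config, "chunk_size", None)
    return chunk_size


print(get_mamba_chunk_size(SimpleNamespace(mamba_chunk_size=256)))  # 256
print(get_mamba_chunk_size(SimpleNamespace(chunk_size=128)))        # 128
print(get_mamba_chunk_size(SimpleNamespace()))                      # None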
vllm/model_executor/layers/mamba/abstract.py

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
+
+import torch
+
+
+class MambaBase(ABC):
+    """
+    Base class for Mamba-like layers which support the v1 engine.
+    Inherit from this class if you implement a custom layer.
+    """
+
+    # Contains the KV cache (mamba state) for the layer
+    # in the shape specified by `self.get_state_shape`.
+    # The outer list is for v0 PP virtual engine. Though this code path
+    # only runs for v1, we have to do this to unify with the interface
+    # of Attention + v0 PP.
+    kv_cache: list[Iterable[torch.Tensor]]
+
+    @abstractmethod
+    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
+        """
+        Defines the shape of the state.
+        For mamba layers this is usually a (conv_state, ssm_state) tuple.
+        In this case, returns (conv_state_shape, ssm_state_shape).
+        """
+        pass
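To illustrate the intent of the new base class, here is a hypothetical custom layer that would satisfy the v1 interface. The class name and state shapes are made up for this sketch; only MambaBase and the kv_cache convention come from the diff above:

from collections.abc import Iterable

import torch

from vllm.model_executor.layers.mamba.abstract import MambaBase


class MyCustomSSMLayer(MambaBase):
    """Hypothetical Mamba-like layer, used only to illustrate the interface."""

    def __init__(self, conv_state_shape: tuple[int, ...],
                 ssm_state_shape: tuple[int, ...]):
        self._conv_state_shape = conv_state_shape
        self._ssm_state_shape = ssm_state_shape
        # Same convention as MambaMixer2: one (conv_state, ssm_state)
        # placeholder per v0 PP virtual engine.
        self.kv_cache = [(torch.tensor([]), torch.tensor([]))]

    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        return (self._conv_state_shape, self._ssm_state_shape)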

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 17 additions & 20 deletions

@@ -17,6 +17,7 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
@@ -219,7 +220,7 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
 
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 @CustomOp.register("mamba_mixer2")
-class MambaMixer2(CustomOp):
+class MambaMixer2(MambaBase, CustomOp):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute
     the `contextualized_states`. A, D are input independent
@@ -231,22 +232,21 @@ class MambaMixer2(CustomOp):
     """
 
     def __init__(
-        self,
-        hidden_size: int,
-        ssm_state_size: int,
-        conv_kernel_size: int,
-        intermediate_size: int,
-        use_conv_bias: bool,
-        use_bias: bool,
-        n_groups: int = 1,
-        num_heads: int = 128,
-        head_dim: int = 64,
-        rms_norm_eps: float = 1e-5,
-        activation: str = "silu",
-        use_rms_norm: bool = True,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        chunk_size: int = -1,  # the chunk size used by v1
+        self,
+        hidden_size: int,
+        ssm_state_size: int,
+        conv_kernel_size: int,
+        intermediate_size: int,
+        use_conv_bias: bool,
+        use_bias: bool,
+        n_groups: int = 1,
+        num_heads: int = 128,
+        head_dim: int = 64,
+        rms_norm_eps: float = 1e-5,
+        activation: str = "silu",
+        use_rms_norm: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
 
@@ -428,10 +428,7 @@ def __init__(
         # of Attention + v0 PP.
         # The inner tuple is (conv_state, ssm_state)
         self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
-        assert chunk_size != -1, "chunk_size must be set for v1"
 
-        # NOTE: chunk_size may be -1 for models without v1 support
-        self.chunk_size = chunk_size
         self.prefix = prefix
 
     def forward_native(
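With MambaMixer2 now inheriting from MambaBase, engine-side code can treat any Mamba-like layer uniformly instead of special-casing MambaMixer2, which is what lets the layer stop carrying its own chunk_size (see the backend change further below). A hedged sketch of that idea, not an actual vLLM API; `layers` is a hypothetical mapping of layer name to module:

from vllm.model_executor.layers.mamba.abstract import MambaBase


def mamba_state_shapes(layers: dict) -> dict:
    """Collect the state shapes of every Mamba-like layer, whatever its class."""
    return {
        name: tuple(layer.get_state_shape())
        for name, layer in layers.items()
        if isinstance(layer, MambaBase)
    }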

vllm/model_executor/models/bamba.py

Lines changed: 1 addition & 2 deletions

@@ -99,8 +99,7 @@ def __init__(self,
             rms_norm_eps=config.rms_norm_eps,
             activation=config.hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mixer",
-            chunk_size=config.mamba_chunk_size)
+            prefix=f"{prefix}.mixer")
 
         self.feed_forward = BambaMLP(config, quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size,

vllm/model_executor/models/falcon_h1.py

Lines changed: 0 additions & 1 deletion

@@ -109,7 +109,6 @@ def __init__(
             quant_config=quant_config,
             use_rms_norm=config.mamba_rms_norm,
             prefix=f"{prefix}.mixer",
-            chunk_size=config.mamba_chunk_size,
         )
         # n_groups is overridden later by `MambaMixer2`
         self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state

vllm/model_executor/models/granitemoehybrid.py

Lines changed: 1 addition & 2 deletions

@@ -69,8 +69,7 @@ def __init__(self,
             rms_norm_eps=config.rms_norm_eps,
             activation=config.hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mixer",
-            chunk_size=config.mamba_chunk_size)
+            prefix=f"{prefix}.mixer")
 
         self.block_sparse_moe = None
         if getattr(config, "num_local_experts", 0) > 0:

vllm/model_executor/models/mamba2.py

Lines changed: 1 addition & 2 deletions

@@ -62,8 +62,7 @@ def __init__(self,
             rms_norm_eps=config.layer_norm_epsilon,
             activation=config.hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mixer",
-            chunk_size=config.chunk_size)
+            prefix=f"{prefix}.mixer")
 
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

vllm/model_executor/models/nemotron_h.py

Lines changed: 0 additions & 1 deletion

@@ -154,7 +154,6 @@ def __init__(
             activation=config.mamba_hidden_act,
             quant_config=quant_config,
             prefix=f"{prefix}.mixer",
-            chunk_size=config.chunk_size,
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

vllm/model_executor/models/zamba2.py

Lines changed: 1 addition & 2 deletions

@@ -501,8 +501,7 @@ def __init__(self,
             rms_norm_eps=config.rms_norm_eps,
             activation="silu",
             quant_config=quant_config,
-            prefix=f"{prefix}.mixer",
-            chunk_size=config.chunk_size)
+            prefix=f"{prefix}.mixer")
 
         # Input normalization
         self.input_layernorm = RMSNorm(config.hidden_size,

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 4 additions & 11 deletions

@@ -7,7 +7,6 @@
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import MambaSpec
@@ -19,15 +18,6 @@
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 
-def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
-    from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
-    layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
-    chunk_sizes = set(layer.chunk_size for layer in layers.values())
-    assert len(
-        chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
-    return chunk_sizes.pop()
-
-
 def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
                                               chunk_size: int,
                                               total_seqlens: int):
@@ -102,7 +92,10 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
         self.runner = runner
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
-        self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)
+        self.chunk_size = runner.vllm_config.model_config.get_mamba_chunk_size(
+        )
+        assert self.chunk_size is not None, (
+            "chunk_size needs to be set in the model config for Mamba2 models")
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
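A simplified, self-contained sketch of the new chunk-size resolution path in the metadata builder; _FakeModelConfig is an illustrative stand-in for vLLM's ModelConfig, and only get_mamba_chunk_size and the assertion message mirror the diff above:

from typing import Optional


class _FakeModelConfig:
    """Illustrative stand-in for ModelConfig."""

    def __init__(self, chunk_size: Optional[int]):
        self._chunk_size = chunk_size

    def get_mamba_chunk_size(self) -> Optional[int]:
        return self._chunk_size


def resolve_chunk_size(model_config: _FakeModelConfig) -> int:
    # Same check the builder now performs at construction time.
    chunk_size = model_config.get_mamba_chunk_size()
    assert chunk_size is not None, (
        "chunk_size needs to be set in the model config for Mamba2 models")
    return chunk_size


print(resolve_chunk_size(_FakeModelConfig(256)))  # 256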
