
Commit d754b7a

retrieve chunk size from the model config instead of the mamba layers
Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
1 parent 5a6e465 commit d754b7a

File tree

9 files changed: +34 -42 lines changed


vllm/config.py

Lines changed: 11 additions & 0 deletions
@@ -1330,6 +1330,17 @@ def get_num_layers_by_block_type(
 
         return sum(t == 1 for t in attn_type_list[start:end])
 
+    def get_mamba_chunk_size(self) -> Optional[int]:
+        """
+        Returns the mamba chunk size if it exists
+        """
+        # used by e.g. Bamba, FalconH1, Granite, PLaMo2
+        chunk_size = getattr(self.hf_text_config, "mamba_chunk_size", None)
+        if chunk_size is None:
+            # used by e.g. Mamba2, NemotronH, Zamba
+            chunk_size = getattr(self.hf_text_config, "chunk_size", None)
+        return chunk_size
+
     def get_multimodal_config(self) -> "MultiModalConfig":
         """
         Get the multimodal configuration of the model.
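
For reference, a minimal standalone sketch of the lookup order the new helper implements. `resolve_mamba_chunk_size` and the SimpleNamespace configs are illustrative stand-ins, not part of the commit:

from types import SimpleNamespace
from typing import Optional

def resolve_mamba_chunk_size(hf_text_config) -> Optional[int]:
    # Bamba-/FalconH1-/Granite-/PLaMo2-style configs expose `mamba_chunk_size`
    chunk_size = getattr(hf_text_config, "mamba_chunk_size", None)
    if chunk_size is None:
        # Mamba2-/NemotronH-/Zamba-style configs expose `chunk_size`
        chunk_size = getattr(hf_text_config, "chunk_size", None)
    return chunk_size

# A Mamba2-style config resolves via `chunk_size` ...
assert resolve_mamba_chunk_size(SimpleNamespace(chunk_size=256)) == 256
# ... a Bamba-style config via `mamba_chunk_size` ...
assert resolve_mamba_chunk_size(SimpleNamespace(mamba_chunk_size=128)) == 128
# ... and a config without either attribute yields None.
assert resolve_mamba_chunk_size(SimpleNamespace()) is None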

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 15 additions & 21 deletions
@@ -223,8 +223,6 @@ class Mamba2Layer(ABC):
223223
Inherit from this class if you implement a custom Mamba2 layer.
224224
"""
225225

226-
chunk_size: int
227-
228226
# Contains the KV cache (mamba state) for the layer
229227
# in the shape specified by `self.get_state_shape`.
230228
# The outer list is for v0 PP virtual engine. Though this code path
@@ -256,22 +254,21 @@ class MambaMixer2(Mamba2Layer, CustomOp):
256254
"""
257255

258256
def __init__(
259-
self,
260-
hidden_size: int,
261-
ssm_state_size: int,
262-
conv_kernel_size: int,
263-
intermediate_size: int,
264-
use_conv_bias: bool,
265-
use_bias: bool,
266-
n_groups: int = 1,
267-
num_heads: int = 128,
268-
head_dim: int = 64,
269-
rms_norm_eps: float = 1e-5,
270-
activation: str = "silu",
271-
use_rms_norm: bool = True,
272-
quant_config: Optional[QuantizationConfig] = None,
273-
prefix: str = "",
274-
chunk_size: int = -1, # the chunk size used by v1
257+
self,
258+
hidden_size: int,
259+
ssm_state_size: int,
260+
conv_kernel_size: int,
261+
intermediate_size: int,
262+
use_conv_bias: bool,
263+
use_bias: bool,
264+
n_groups: int = 1,
265+
num_heads: int = 128,
266+
head_dim: int = 64,
267+
rms_norm_eps: float = 1e-5,
268+
activation: str = "silu",
269+
use_rms_norm: bool = True,
270+
quant_config: Optional[QuantizationConfig] = None,
271+
prefix: str = "",
275272
):
276273
super().__init__()
277274

@@ -453,10 +450,7 @@ def __init__(
453450
# of Attention + v0 PP.
454451
# The inner tuple is (conv_state, ssm_state)
455452
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
456-
assert chunk_size != -1, "chunk_size must be set for v1"
457453

458-
# NOTE: chunk_size may be -1 for models without v1 support
459-
self.chunk_size = chunk_size
460454
self.prefix = prefix
461455

462456
def forward_native(
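
For context, a sketch of the call shape after this change. The hyperparameter values below are placeholders, not taken from any real model config (the per-model diffs below pass values from their HF configs), and instantiating the layer still requires a normal vLLM model-loading environment:

from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2

mixer = MambaMixer2(
    hidden_size=2048,        # placeholder value
    ssm_state_size=128,      # placeholder value
    conv_kernel_size=4,
    intermediate_size=4096,
    use_conv_bias=True,
    use_bias=False,
    n_groups=1,
    num_heads=64,
    head_dim=32,
    activation="silu",
    prefix="model.layers.0.mixer",
    # chunk_size is no longer accepted here; the v1 backend reads it from
    # ModelConfig.get_mamba_chunk_size() instead.
)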

vllm/model_executor/models/bamba.py

Lines changed: 1 addition & 2 deletions
@@ -99,8 +99,7 @@ def __init__(self,
             rms_norm_eps=config.rms_norm_eps,
             activation=config.hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mixer",
-            chunk_size=config.mamba_chunk_size)
+            prefix=f"{prefix}.mixer")
 
         self.feed_forward = BambaMLP(config, quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size,

vllm/model_executor/models/falcon_h1.py

Lines changed: 0 additions & 1 deletion
@@ -109,7 +109,6 @@ def __init__(
             quant_config=quant_config,
             use_rms_norm=config.mamba_rms_norm,
             prefix=f"{prefix}.mixer",
-            chunk_size=config.mamba_chunk_size,
         )
         # n_groups is overridden later by `MambaMixer2`
         self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state

vllm/model_executor/models/granitemoehybrid.py

Lines changed: 1 addition & 2 deletions
@@ -69,8 +69,7 @@ def __init__(self,
             rms_norm_eps=config.rms_norm_eps,
             activation=config.hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mixer",
-            chunk_size=config.mamba_chunk_size)
+            prefix=f"{prefix}.mixer")
 
         self.block_sparse_moe = None
         if getattr(config, "num_local_experts", 0) > 0:

vllm/model_executor/models/mamba2.py

Lines changed: 1 addition & 2 deletions
@@ -62,8 +62,7 @@ def __init__(self,
             rms_norm_eps=config.layer_norm_epsilon,
             activation=config.hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mixer",
-            chunk_size=config.chunk_size)
+            prefix=f"{prefix}.mixer")
 
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
vllm/model_executor/models/nemotron_h.py

Lines changed: 0 additions & 1 deletion
@@ -154,7 +154,6 @@ def __init__(
             activation=config.mamba_hidden_act,
             quant_config=quant_config,
             prefix=f"{prefix}.mixer",
-            chunk_size=config.chunk_size,
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

vllm/model_executor/models/zamba2.py

Lines changed: 1 addition & 2 deletions
@@ -501,8 +501,7 @@ def __init__(self,
             rms_norm_eps=config.rms_norm_eps,
             activation="silu",
             quant_config=quant_config,
-            prefix=f"{prefix}.mixer",
-            chunk_size=config.chunk_size)
+            prefix=f"{prefix}.mixer")
 
         # Input normalization
         self.input_layernorm = RMSNorm(config.hidden_size,

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 4 additions & 11 deletions
@@ -6,7 +6,6 @@
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     _query_start_loc_to_chunk_indices_offsets)
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
@@ -20,15 +19,6 @@
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 
-def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
-    from vllm.model_executor.layers.mamba.mamba_mixer2 import Mamba2Layer
-    layers = get_layers_from_vllm_config(vllm_config, Mamba2Layer)
-    chunk_sizes = set(layer.chunk_size for layer in layers.values())
-    assert len(
-        chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
-    return chunk_sizes.pop()
-
-
 class Mamba2AttentionBackend(AttentionBackend):
 
     @staticmethod
@@ -63,7 +53,10 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
         self.runner = runner
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
-        self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)
+        self.chunk_size = runner.vllm_config.model_config.get_mamba_chunk_size(
+        )
+        assert self.chunk_size is not None, (
+            "chunk_size needs to be set in the model config for Mamba2 models")
 
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
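
As background (not part of the diff): the metadata builder needs a single, model-wide chunk size because the v1 Mamba2 path processes prefills in fixed-size chunks of tokens. A toy illustration of the resulting chunk count; `num_chunks` is a hypothetical helper, not a vLLM function:

import math

def num_chunks(seq_len: int, chunk_size: int) -> int:
    # Number of fixed-size chunks needed to cover a prefill of `seq_len` tokens.
    return math.ceil(seq_len / chunk_size)

# e.g. a 1000-token prefill with a chunk_size of 256 spans 4 chunks
assert num_chunks(1000, 256) == 4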
