Commit 5a6e465

introduce Mamba2Layer abstraction to allow mamba layers other than MambaMixer2 in v1 engine
Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
1 parent d8ee5a2 commit 5a6e465

3 files changed: +32 −6 lines


vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 27 additions & 1 deletion
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from abc import ABC, abstractmethod
 from typing import Optional, Union
 
 import torch
@@ -216,9 +217,34 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
     return loader
 
 
+class Mamba2Layer(ABC):
+    """
+    Base class for all Mamba2 layers which support the v1 engine.
+    Inherit from this class if you implement a custom Mamba2 layer.
+    """
+
+    chunk_size: int
+
+    # Contains the KV cache (mamba state) for the layer
+    # in the shape specified by `self.get_state_shape`.
+    # The outer list is for v0 PP virtual engine. Though this code path
+    # only runs for v1, we have to do this to unify with the interface
+    # of Attention + v0 PP.
+    kv_cache: list[tuple[torch.Tensor]]
+
+    @abstractmethod
+    def get_state_shape(self) -> tuple[tuple[int, ...]]:
+        """
+        Defines the shape of the mamba state.
+        Usually, the mamba state is a (conv_state, ssm_state) tuple.
+        In this case, returns (conv_state_shape, ssm_state_shape).
+        """
+        pass
+
+
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 @CustomOp.register("mamba_mixer2")
-class MambaMixer2(CustomOp):
+class MambaMixer2(Mamba2Layer, CustomOp):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute
     the `contextualized_states`. A, D are input independent
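For orientation (not part of the commit): a custom layer that wants to be picked up by the v1 engine under this abstraction would subclass Mamba2Layer, set chunk_size and kv_cache, and implement get_state_shape(). The sketch below is hypothetical; the class name, constructor arguments, and state shapes are made up, and the forward pass is omitted.

import torch

from vllm.model_executor.layers.mamba.mamba_mixer2 import Mamba2Layer


class MyCustomMamba2(Mamba2Layer, torch.nn.Module):
    """Hypothetical custom Mamba2 layer discoverable by the v1 engine."""

    def __init__(self, conv_dim: int, conv_kernel: int, num_heads: int,
                 head_dim: int, state_size: int, chunk_size: int):
        super().__init__()
        self.chunk_size = chunk_size
        # Placeholder in the documented list-of-tuples layout; the actual
        # state buffers follow the shapes returned by get_state_shape().
        self.kv_cache = [(torch.tensor([]), )]
        self._conv_state_shape = (conv_dim, conv_kernel - 1)
        self._ssm_state_shape = (num_heads, head_dim, state_size)

    def get_state_shape(self) -> tuple[tuple[int, ...], ...]:
        # A (conv_state_shape, ssm_state_shape) pair, as in the docstring above.
        return (self._conv_state_shape, self._ssm_state_shape)

    # forward(), weight loading, etc. omitted.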

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 2 additions & 2 deletions
@@ -21,8 +21,8 @@
 
 
 def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
-    from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
-    layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
+    from vllm.model_executor.layers.mamba.mamba_mixer2 import Mamba2Layer
+    layers = get_layers_from_vllm_config(vllm_config, Mamba2Layer)
     chunk_sizes = set(layer.chunk_size for layer in layers.values())
     assert len(
         chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
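Because the lookup now keys on the Mamba2Layer base class, any subclass (including a custom one like the sketch above) contributes to, and is bound by, the single-chunk-size invariant asserted here. Below is a minimal standalone illustration of that invariant, with a toy subclass standing in for vLLM's real config plumbing (ToyMamba2 and the value 256 are made up):

from vllm.model_executor.layers.mamba.mamba_mixer2 import Mamba2Layer


class ToyMamba2(Mamba2Layer):
    """Toy stand-in; a real layer would also be a torch.nn.Module."""

    def __init__(self, chunk_size: int):
        self.chunk_size = chunk_size

    def get_state_shape(self) -> tuple[tuple[int, ...], ...]:
        return ((4, 3), (2, 8, 16))


# Stand-in for get_layers_from_vllm_config(vllm_config, Mamba2Layer), which
# returns the matching layers keyed by layer name.
layers = {"layers.0.mixer": ToyMamba2(256), "layers.1.mixer": ToyMamba2(256)}

chunk_sizes = set(layer.chunk_size for layer in layers.values())
assert len(chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"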

vllm/v1/worker/gpu_model_runner.py

Lines changed: 3 additions & 3 deletions
@@ -31,7 +31,7 @@
 from vllm.forward_context import (DPMetadata, get_forward_context,
                                   set_forward_context)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
+from vllm.model_executor.layers.mamba.mamba_mixer2 import Mamba2Layer
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.model_executor.models.interfaces import (has_step_pooler,
@@ -2660,7 +2660,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                     f"Unknown attention type: {attn_module.attn_type}")
 
         mamba_layers = get_layers_from_vllm_config(self.vllm_config,
-                                                   MambaMixer2)
+                                                   Mamba2Layer)
         if len(mamba_layers) > 0:
             if self.vllm_config.speculative_config is not None:
                 raise NotImplementedError(
@@ -2691,7 +2691,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
     def _maybe_pad_mamba_page_size(
         self,
         attn_layers: dict[str, Attention],
-        mamba_layers: dict[str, MambaMixer2],
+        mamba_layers: dict[str, Mamba2Layer],
         kv_cache_spec: dict[str, KVCacheSpec],
         max_model_len: int,
         block_size: int,
