
Commit 14a4fa7

introduce Mamba2Layer abstraction to allow mamba layers other than MambaMixer2 in v1 engine
Signed-off-by: nopperl <54780682+nopperl@users.noreply.github.com>
1 parent d8ee5a2 commit 14a4fa7

File tree

3 files changed: +33 -6 lines changed


vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 28 additions & 1 deletion

@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from abc import ABC, abstractmethod
+from collections.abc import Iterable
 from typing import Optional, Union
 
 import torch
@@ -216,9 +218,34 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
     return loader
 
 
+class Mamba2Layer(ABC):
+    """
+    Base class for all Mamba2 layers which support the v1 engine.
+    Inherit from this class if you implement a custom Mamba2 layer.
+    """
+
+    chunk_size: int
+
+    # Contains the KV cache (mamba state) for the layer
+    # in the shape specified by `self.get_state_shape`.
+    # The outer list is for v0 PP virtual engine. Though this code path
+    # only runs for v1, we have to do this to unify with the interface
+    # of Attention + v0 PP.
+    kv_cache: list[tuple[torch.Tensor]]
+
+    @abstractmethod
+    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
+        """
+        Defines the shape of the mamba state.
+        Usually, the mamba state is a (conv_state, ssm_state) tuple.
+        In this case, returns (conv_state_shape, ssm_state_shape).
+        """
+        pass
+
+
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 @CustomOp.register("mamba_mixer2")
-class MambaMixer2(CustomOp):
+class MambaMixer2(Mamba2Layer, CustomOp):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute
     the `contextualized_states`. A, D are input independent
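
To illustrate the contract this abstraction introduces, here is a minimal sketch of a custom layer built on it. The class name, constructor parameters, and state shapes below are hypothetical and not part of this commit; only the `chunk_size` / `kv_cache` attributes and `get_state_shape` come from the interface itself.

from collections.abc import Iterable

import torch

from vllm.model_executor.layers.mamba.mamba_mixer2 import Mamba2Layer


class MyMamba2Variant(Mamba2Layer):
    """Hypothetical Mamba2 variant; the shapes here are illustrative only."""

    def __init__(self, conv_dim: int, conv_kernel: int, num_heads: int,
                 head_dim: int, state_size: int, chunk_size: int) -> None:
        # Required by the Mamba2Layer interface.
        self.chunk_size = chunk_size
        # Populated by the engine with the actual state buffers.
        self.kv_cache = [(torch.tensor([]), )]
        self._conv_state_shape = (conv_dim, conv_kernel - 1)
        self._ssm_state_shape = (num_heads, head_dim, state_size)

    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        # A (conv_state_shape, ssm_state_shape) tuple, as the base class
        # docstring suggests.
        return (self._conv_state_shape, self._ssm_state_shape)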

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 2 additions & 2 deletions

@@ -21,8 +21,8 @@
 
 
 def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
-    from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
-    layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
+    from vllm.model_executor.layers.mamba.mamba_mixer2 import Mamba2Layer
+    layers = get_layers_from_vllm_config(vllm_config, Mamba2Layer)
     chunk_sizes = set(layer.chunk_size for layer in layers.values())
     assert len(
         chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
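
The one-line widening works because `get_layers_from_vllm_config` selects layers by type, and `isinstance` checks match subclasses, so requesting the abstract `Mamba2Layer` still finds every `MambaMixer2` plus any future variant. A rough sketch of that kind of filtering, with a hypothetical helper name:

def filter_layers_by_type(layers: dict[str, object],
                          layer_type: type) -> dict[str, object]:
    # isinstance() matches subclasses too, so requesting the abstract
    # Mamba2Layer base also returns concrete MambaMixer2 layers.
    return {
        name: layer
        for name, layer in layers.items() if isinstance(layer, layer_type)
    }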

vllm/v1/worker/gpu_model_runner.py

Lines changed: 3 additions & 3 deletions

@@ -31,7 +31,7 @@
 from vllm.forward_context import (DPMetadata, get_forward_context,
                                   set_forward_context)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
+from vllm.model_executor.layers.mamba.mamba_mixer2 import Mamba2Layer
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.model_executor.models.interfaces import (has_step_pooler,
@@ -2660,7 +2660,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                 f"Unknown attention type: {attn_module.attn_type}")
 
         mamba_layers = get_layers_from_vllm_config(self.vllm_config,
-                                                   MambaMixer2)
+                                                   Mamba2Layer)
         if len(mamba_layers) > 0:
             if self.vllm_config.speculative_config is not None:
                 raise NotImplementedError(
@@ -2691,7 +2691,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
     def _maybe_pad_mamba_page_size(
         self,
         attn_layers: dict[str, Attention],
-        mamba_layers: dict[str, MambaMixer2],
+        mamba_layers: dict[str, Mamba2Layer],
         kv_cache_spec: dict[str, KVCacheSpec],
         max_model_len: int,
         block_size: int,
