[V1] [Doc] Update V1 docs for Mamba models #20499


Merged · 13 commits · Jul 9, 2025
12 changes: 6 additions & 6 deletions docs/models/supported_models.md
@@ -316,7 +316,7 @@ Specified using `--task generate`.
| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ |
| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | |
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | |
@DarkLight1337 (Member), Jul 9, 2025:

Can you make the checkmark icon consistent with the other entries in this table? Same for the other tables

Member Author:

The checkmark icon looks identical to me. Could you paste a screenshot of the difference you are seeing?

Member:
(screenshot)

Member:
It's also visible in the git diff

Member Author:
lol, wth. this is what git diff (and the table) shows for me:
(screenshot)

Member:
Hmm let me just add suggestions in GitHub and you can commit the changes

Member Author:
must be some unicode weirdness

Member Author:
done, thanks for that

| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ |
@@ -332,7 +332,7 @@ Specified using `--task generate`.
| `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ |
| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ |
| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | |
| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | |
| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
@@ -345,7 +345,7 @@ Specified using `--task generate`.
| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ |
| `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | |
| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | |
| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ |
| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | |
| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ |
@@ -357,14 +357,14 @@ Specified using `--task generate`.
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | |
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | |
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | |
| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | |
| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ |
| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | |
| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | |
| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ |
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ |
| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ |
@@ -389,7 +389,7 @@ Specified using `--task generate`.
| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | |
| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | |
| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | |
| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | |

!!! note
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
13 changes: 10 additions & 3 deletions docs/usage/v1_guide.md
@@ -83,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
| **Decoder-only Models** | <nobr>🚀 Optimized</nobr> |
| **Encoder-Decoder Models** | <nobr>🟠 Delayed</nobr> |
| **Embedding Models** | <nobr>🟢 Functional</nobr> |
| **Mamba Models** | <nobr>🚧 WIP (<gh-pr:19327>)</nobr> |
| **Mamba Models** | <nobr>🚧 WIP (<gh-pr:19327>, <gh-pr:20016>)</nobr> |
| **Multimodal Models** | <nobr>🟢 Functional</nobr> |

vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol.
@@ -104,8 +104,15 @@ to enable simultaneous generation and embedding using the same engine instance i

#### Mamba Models

Models using selective state-space mechanisms instead of standard transformer attention (e.g., `MambaForCausalLM`, `JambaForCausalLM`)
will be supported via <gh-pr:19327>.
Models using selective state-space mechanisms instead of standard transformer attention are partially supported.
Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers
(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet supported. Please note that these models currently require
enforcing eager mode and disabling prefix caching in V1.

Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that
these models currently require enforcing eager mode, disabling prefix caching, and using the FlashInfer attention
backend in V1.

#### Encoder-Decoder Models

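The Mamba and hybrid-model constraints documented in the guide above translate into a fairly small launch configuration. The sketch below is illustrative only: the flag names mirror the test changes later in this PR, the checkpoint name is taken from the supported-models table, and the `VLLM_ATTENTION_BACKEND=FLASHINFER` setting is an assumption about how the FlashInfer backend is selected, not something this PR specifies.

```python
# Minimal sketch: running a hybrid Mamba-2 + attention model on vLLM V1.
# Assumptions: env-var names and flag names follow this PR's test changes.
import os

os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # needed for hybrid models only

from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-ai-platform/Bamba-9B",   # hybrid Mamba-2 + attention checkpoint
    enforce_eager=True,                 # V1 currently requires eager mode
    enable_prefix_caching=False,        # prefix caching not yet supported for Mamba
    max_num_seqs=4,
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```

A pure Mamba-2 model such as `mistralai/Mamba-Codestral-7B-v0.1` would, per the guide text above, only need the eager-mode and prefix-caching settings.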
16 changes: 1 addition & 15 deletions tests/models/language/generation/test_hybrid.py
@@ -61,14 +61,6 @@
"tiiuae/Falcon-H1-0.5B-Base",
]

ATTN_BLOCK_SIZES = {
"ibm-ai-platform/Bamba-9B-v1": 528,
"Zyphra/Zamba2-1.2B-instruct": 80,
"nvidia/Nemotron-H-8B-Base-8K": 528,
"ibm-granite/granite-4.0-tiny-preview": 400,
"tiiuae/Falcon-H1-0.5B-Base": 800,
}

# Avoid OOM
MAX_NUM_SEQS = 4

@@ -105,11 +97,6 @@ def test_models(
example_prompts, max_tokens, num_logprobs)

if model in V1_SUPPORTED_MODELS:
if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES:
block_size = ATTN_BLOCK_SIZES[model]
else:
block_size = 16

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
if model in HYBRID_MODELS:
@@ -118,8 +105,7 @@
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enforce_eager=True,
enable_prefix_caching=False,
block_size=block_size) as vllm_model:
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
else:
173 changes: 114 additions & 59 deletions vllm/v1/worker/gpu_model_runner.py
@@ -2615,10 +2615,107 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
format. Layers that do not need KV cache are not included.
"""

block_size = self.vllm_config.cache_config.block_size
attn_block_size = self.vllm_config.cache_config.block_size

kv_cache_spec, attn_page_size, mamba_page_size = \
self._get_kv_cache_spec_for_block_size(
attn_block_size=attn_block_size,
)

if (attn_page_size is not None and mamba_page_size is not None
and attn_page_size != mamba_page_size):

kv_cache_spec = self._align_hybrid_page_size(
attn_page_size,
mamba_page_size,
attn_block_size,
)

return kv_cache_spec

def _align_hybrid_page_size(
self, attn_page_size: int, mamba_page_size: int,
attn_block_size: int) -> dict[str, KVCacheSpec]:
"""
Ensures that the page sizes for attention layers and mamba layers
are equal. This is achieved by (1) setting the attention block size
so that the attention page size is greater than or equal to the
mamba page size, and then (2) padding the mamba page size so that
the two are exactly equal.

Args:
attn_page_size: Page size (bytes) for attention layers.
mamba_page_size: Page size (bytes) for mamba layers.
attn_block_size: Block size (tokens) for attention layers.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
"""

# check if we need to increase attention block size
if attn_page_size < mamba_page_size:

# attention page size (for 1 token)
attn_page_size_1 = attn_page_size // attn_block_size

# some attention backends (e.g. FA) only support setting
# the block size to a multiple of 16, so let's suggest a value
# that would work (note: FA is currently not compatible
# with mamba layers, use FlashInfer instead).
attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1)

logger.info(
"Increasing attention block size from %d to %d tokens "
"to ensure that attention page size is >= mamba page size.",
self.vllm_config.cache_config.block_size, attn_block_size)

mamba_page_size_padded = attn_block_size * attn_page_size_1

mamba_padding_pct = 100 * (mamba_page_size_padded -
mamba_page_size) / mamba_page_size

logger.info(
"Padding mamba page size by %.2f%% to ensure "
"that mamba page size and attention page size are "
"exactly equal.", mamba_padding_pct)

kv_cache_spec, new_attn_page_size, new_mamba_page_size = \
self._get_kv_cache_spec_for_block_size(
attn_block_size=attn_block_size,
mamba_page_size_padded=mamba_page_size_padded,
)

assert new_attn_page_size == new_mamba_page_size, (
f"Attention page size ({new_attn_page_size}) does not equal "
f"mamba page size ({new_mamba_page_size}) after alignment.")

return kv_cache_spec

def _get_kv_cache_spec_for_block_size(
self,
attn_block_size: int,
mamba_page_size_padded: Optional[int] = None,
) -> tuple[dict[str, KVCacheSpec], Optional[int], Optional[int]]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
Attention module in the static forward context, assuming a given
block size (in tokens) for attention layers.

Args:
attn_block_size: Block size (in tokens) for attention layers.
mamba_page_size_padded: Padded page size for mamba layers.
Returns:
KVCacheSpec: A dictionary mapping layer names to their KV cache
format. Layers that do not need KV cache are not included.
Optional[int]: Page size for attention layers.
Optional[int]: Page size for mamba layers.
"""

use_mla = self.vllm_config.model_config.use_mla
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)

attn_page_size = None
for layer_name, attn_module in attn_layers.items():
if (kv_tgt_layer :=
attn_module.kv_sharing_target_layer_name) is not None:
@@ -2636,19 +2733,25 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
if attn_module.attn_type == AttentionType.DECODER:
if attn_module.sliding_window is not None:
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
block_size=attn_block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
use_mla=use_mla)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
block_size=attn_block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
page_size = kv_cache_spec[layer_name].page_size_bytes
if attn_page_size is None:
attn_page_size = page_size
else:
assert page_size == attn_page_size

elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.
@@ -2661,6 +2764,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:

mamba_layers = get_layers_from_vllm_config(self.vllm_config,
MambaMixer2)
mamba_page_size = None
if len(mamba_layers) > 0:
if self.vllm_config.speculative_config is not None:
raise NotImplementedError(
@@ -2673,68 +2777,19 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
"Prefix caching is not supported for Mamba yet.")
max_model_len = self.vllm_config.model_config.max_model_len

page_size_padded = self._maybe_pad_mamba_page_size(
attn_layers, mamba_layers, kv_cache_spec, max_model_len,
block_size)

# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
for layer_name, mamba_module in mamba_layers.items():
kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(),
dtype=self.kv_cache_dtype,
block_size=max_model_len,
page_size_padded=page_size_padded)

return kv_cache_spec

def _maybe_pad_mamba_page_size(
self,
attn_layers: dict[str, Attention],
mamba_layers: dict[str, MambaMixer2],
kv_cache_spec: dict[str, KVCacheSpec],
max_model_len: int,
block_size: int,
) -> Optional[int]:
"""
Ensure that page size of attention KV cache groups is greater than or
equal to the mamba KV cache groups. If not, we suggest to the user
how to set the attention block size to ensure that it is.

If the attention page size is strictly greater than the mamba page size,
we pad the mamba page size to make them equal.

Args:
attn_layers: Attention layers
mamba_layers: Mamba layers
kv_cache_spec: KV cache spec (populated with attention layers)

Returns:
Optional[int]: Mamba page size with padding (None if no padding).
"""
page_size_padded=mamba_page_size_padded)

if len(attn_layers) == 0:
return None
page_size = kv_cache_spec[layer_name].page_size_bytes
if mamba_page_size is None:
mamba_page_size = page_size
else:
assert page_size == mamba_page_size

attn_layer_name = next(iter(attn_layers))
attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes
mamba_layer_name = next(iter(mamba_layers))
mamba_page_size = MambaSpec(
shapes=mamba_layers[mamba_layer_name].get_state_shape(),
dtype=self.kv_cache_dtype,
block_size=max_model_len).page_size_bytes
if attn_page_size < mamba_page_size:
# attention page size (for 16 tokens)
attn_page_size_16 = 16 * attn_page_size // block_size
# some attention backends (e.g. FA) only support setting
# block size to multiple of 16, so let's suggest a value
# that would work (note: FA is currently not compatible
# with mamba layers, use FlashInfer instead).
suggest_attn_block_size = 16 * cdiv(mamba_page_size,
attn_page_size_16)
raise ValueError(
"Attention block size should be increased to at least "
f"{suggest_attn_block_size} in order to match "
"the mamba page size")

return attn_page_size
return kv_cache_spec, attn_page_size, mamba_page_size
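
To make the page-size alignment in `_align_hybrid_page_size` concrete, here is a small standalone sketch of the same arithmetic. The byte counts below are invented for illustration and do not correspond to any real model.

```python
# Standalone sketch of the page-size alignment performed above.
# The page sizes are illustrative assumptions, not real model values.

def cdiv(a: int, b: int) -> int:
    """Ceiling division, as used in the runner."""
    return -(-a // b)

attn_block_size = 16          # tokens per attention block (cache_config.block_size)
attn_page_size = 32_768       # bytes per 16-token attention block (assumed)
mamba_page_size = 1_000_000   # bytes per mamba state page (assumed)

# Bytes needed per single attention token.
attn_page_size_1 = attn_page_size // attn_block_size

# Round the block size up to a multiple of 16 so that the attention page
# is at least as large as the mamba page.
new_attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1)

# Pad the mamba page so both page sizes are exactly equal.
mamba_page_size_padded = new_attn_block_size * attn_page_size_1
padding_pct = 100 * (mamba_page_size_padded - mamba_page_size) / mamba_page_size

print(new_attn_block_size)      # 496 (tokens)
print(mamba_page_size_padded)   # 1015808 (bytes)
print(f"{padding_pct:.2f}%")    # 1.58%
```

With these numbers the runner would raise the attention block size from 16 to 496 tokens and pad each mamba page by about 1.58%, after which the padded value is passed back through `_get_kv_cache_spec_for_block_size` as `mamba_page_size_padded` so that both cache groups share an identical page size.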