
Enable V1 for Hybrid SSM/Attention Models #20016


Merged: 32 commits, merged Jul 4, 2025

Commits (32)
de4e3a2
working change
tdoublep Jun 23, 2025
617cd26
working changes
tdoublep Jun 24, 2025
9378c54
Merge branch 'main' into tpa-bamba-v1
tdoublep Jun 24, 2025
300d25f
Working version
tdoublep Jun 24, 2025
0822308
Add support + test for Zamba2
tdoublep Jun 24, 2025
a9fc73f
Fix memory layout for KV cache tensors in mamba case
tdoublep Jun 26, 2025
0e5b6de
kv_cache_interface.py: use utils.round_up
tdoublep Jun 26, 2025
89f504a
Fix unrelated CI test failing
tdoublep Jun 26, 2025
0b7783b
Enable bamba-9b in CI
tdoublep Jun 26, 2025
ded4833
Fix unrelated CI issue
tdoublep Jun 26, 2025
31db869
Add support for Nemotron-H
tdoublep Jun 27, 2025
c45e7e5
Enable Granite 4.0 (HybridMoE)
tdoublep Jun 27, 2025
0f20e11
Merge branch 'main' into tpa-bamba-v1
tdoublep Jun 27, 2025
e2c14ba
add support for Falcon H1
tdoublep Jun 27, 2025
c5a25eb
Fix overflow issue in mamba_ssm kernel
tdoublep Jun 30, 2025
cfc38c0
Resolve merge conflicts
tdoublep Jun 30, 2025
aaa6f0e
Add check for transformers min version
tdoublep Jun 30, 2025
d187bfd
Don't fail test if model is not in HF_EXAMPLE_MODELS
tdoublep Jun 30, 2025
fde28dc
Fix test_batching in same way
tdoublep Jun 30, 2025
1777fd1
Update tests/models/language/generation/test_hybrid.py
tdoublep Jul 1, 2025
58e66c9
Update vllm/model_executor/models/granitemoehybrid.py
tdoublep Jul 1, 2025
c2da03e
page_size -> num_element_per_page
tdoublep Jul 1, 2025
e0404c9
Clean up page size padding logic
tdoublep Jul 1, 2025
c74698d
gpu_model_runner.py: add TODO about batch reordering
tdoublep Jul 1, 2025
b72b729
Fix linting issue
tdoublep Jul 1, 2025
105737c
Validate memory layout for hybrid models against attention backends
tdoublep Jul 1, 2025
d8ff3b9
Adjust comment
tdoublep Jul 1, 2025
c857ec3
Merge branch 'main' into tpa-bamba-v1
tdoublep Jul 4, 2025
ea8cf32
Move memory layout check into separate function
tdoublep Jul 4, 2025
b38d3fb
Move logic to pad mamba page size into separate function
tdoublep Jul 4, 2025
e6b0015
Add extra todo
tdoublep Jul 4, 2025
14fd006
test_oracle.py: hybrid models now supported
tdoublep Jul 4, 2025
70 changes: 60 additions & 10 deletions tests/models/language/generation/test_hybrid.py
@@ -3,6 +3,7 @@

import pytest

from tests.models.registry import HF_EXAMPLE_MODELS
from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams
@@ -19,31 +20,55 @@
SSM_MODELS = [
"state-spaces/mamba-130m-hf",
"tiiuae/falcon-mamba-tiny-dev",
# TODO: Compare to a Mamba2 model. The HF transformers implementation of
# Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
# doesn't compare vLLM output with HF output.
# See https://github.com/huggingface/transformers/pull/35943
"mistralai/Mamba-Codestral-7B-v0.1",
]

HYBRID_MODELS = [
"ai21labs/Jamba-tiny-dev",
# NOTE: Currently the test fails due to HF transformers issue fixed in:
# https://github.com/huggingface/transformers/pull/39033
# We will enable vLLM test for Granite after next HF transformers release.
# "ibm-granite/granite-4.0-tiny-preview",
# NOTE: Running Plamo2 in the transformers implementation requires installing
# the causal-conv1d package, which is not listed as a test dependency as it's
# not compatible with pip-compile.
"pfnet/plamo-2-1b",
"Zyphra/Zamba2-1.2B-instruct",
"hmellor/tiny-random-BambaForCausalLM",
"ibm-ai-platform/Bamba-9B-v1",
"nvidia/Nemotron-H-8B-Base-8K",
Comment on lines +34 to +35

Collaborator: I think these will be too large for the CI, which runs on L4 GPUs.

Member Author: The models are close to the limit, but the tests do pass in CI.
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
]

HF_UNSUPPORTED_MODELS = [
# The HF transformers implementation of
# Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
# doesn't compare vLLM output with HF output.
# See https://github.com/huggingface/transformers/pull/35943
"mistralai/Mamba-Codestral-7B-v0.1",
# Note: I'm not seeing the same output from vLLM V0 vs. HF transformers
# for Nemotron-H-8B; currently only compare vLLM V0 vs. vLLM V1
"nvidia/Nemotron-H-8B-Base-8K",
Comment on lines +41 to +48

Collaborator: I wonder if the Nemotron issue is also caused by this n_groups > 1 issue, which was fixed in huggingface/transformers#35943 but only for zamba2.

Collaborator: @tlrmchlsmth @tdoublep is this a blocker to this PR?

Member Author: I don't think so.
# NOTE: Currently the test fails due to HF transformers issue fixed in:
# https://github.com/huggingface/transformers/pull/39033
# We will enable vLLM test for Granite after next HF transformers release.
"ibm-granite/granite-4.0-tiny-preview",
]

V1_SUPPORTED_MODELS = [
"mistralai/Mamba-Codestral-7B-v0.1",
"ibm-ai-platform/Bamba-9B-v1",
"Zyphra/Zamba2-1.2B-instruct",
"nvidia/Nemotron-H-8B-Base-8K",
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
]

ATTN_BLOCK_SIZES = {
"ibm-ai-platform/Bamba-9B-v1": 528,
"Zyphra/Zamba2-1.2B-instruct": 80,
"nvidia/Nemotron-H-8B-Base-8K": 528,
"ibm-granite/granite-4.0-tiny-preview": 400,
"tiiuae/Falcon-H1-0.5B-Base": 800,
Comment on lines +65 to +69

Collaborator: @tdoublep do you know what attention backends are used for these block sizes?

Member Author: I'm using FlashInfer in all of the tests. This is required because it reorders the batch in the same way as the Mamba backend. I saw that in some of the cases where the block size is really big (e.g., Falcon-H1) it triggers some (JIT?) compilation from FlashInfer when running the serving benchmark. The results still look good, though.

Collaborator: To my understanding, FlashAttention is compatible with the decode-first order needed by mamba2, as it accepts arbitrary order. The current _may_reorder_batch only allows reordering the first batch because it is easy to implement and is enough for models without mamba. It can be refactored to support FlashAttention + mamba. But due to the kv_cache_shape problem I mentioned below, we cannot use FlashAttention in this PR, so I think it's fine to use FlashInfer here.
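
A minimal sketch of the decode-first ordering being discussed (the helper and the request fields are illustrative assumptions, not vLLM internals): the mamba2 kernels expect all decode requests to be grouped ahead of prefill requests in the batch, which is the same ordering the FlashInfer backend applies.

```python
# Hypothetical sketch: group single-token (decode) requests before prefills.
def reorder_decode_first(requests: list[dict]) -> list[dict]:
    decodes = [r for r in requests if r["num_scheduled_tokens"] == 1]
    prefills = [r for r in requests if r["num_scheduled_tokens"] > 1]
    return decodes + prefills

batch = [
    {"id": "a", "num_scheduled_tokens": 37},  # prefill
    {"id": "b", "num_scheduled_tokens": 1},   # decode
    {"id": "c", "num_scheduled_tokens": 1},   # decode
]

print([r["id"] for r in reorder_decode_first(batch)])  # ['b', 'c', 'a']
```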

Collaborator: For the extremely big block sizes, we only need to ensure $\sum_{l \in \text{attention layers}} \text{page\_size\_bytes}(l) \ge \text{page\_size\_bytes}(\text{mamba})$. But as it also needs some memory layout changes, we can implement this optimization in a follow-up PR. And does the JIT compilation happen during engine initialization or during execution?

Member Author:

> But due to the kv_cache_shape problem I mentioned below, we cannot use FlashAttention in this PR, so I think it's fine to use FlashInfer in this PR.

Yes, you are right. The (2, num_blocks, ...) issue is the bigger reason why FlashAttention can't be supported right now.

Member Author (Jul 1, 2025): Don't entirely follow here - what is the variable l? (Never mind, I think I get it now: it's layers.)

Member Author (Jul 1, 2025):

> And does the JIT compilation happen during engine initialization or during execution?

In the serving benchmark, I think it happens for the first test prompt that is sent to warm things up. Will double check (confirmed).

Collaborator: Is it possible to call that kernel during the engine's warm-up stage instead of the warm-up prompt? (I'm OK with leaving a warning during engine warm-up in this PR when a strange block_size is used and fixing it later.)

Member Author: OK, so this is a false alarm, sorry. The JIT recompilation that I'm seeing is not related to the large attention block size. The only hybrid model I see it happening for is Falcon-H1-0.5B-Base, and I just checked that I see exactly the same JIT compilation happening in V0 if I enable the FlashInfer attention backend (e.g., also with block size 16). It must be something related to that specific model architecture (e.g., head size or number of heads).
}

# Avoid OOM
MAX_NUM_SEQS = 4

@@ -60,8 +85,16 @@ def test_models(
max_tokens: int,
num_logprobs: int,
) -> None:

try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass

with hf_runner(model) as hf_model:
if model != "mistralai/Mamba-Codestral-7B-v0.1":
if model not in HF_UNSUPPORTED_MODELS:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
else:
Expand All @@ -72,12 +105,21 @@ def test_models(
example_prompts, max_tokens, num_logprobs)

if model in V1_SUPPORTED_MODELS:
if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES:
block_size = ATTN_BLOCK_SIZES[model]
else:
block_size = 16

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
if model in HYBRID_MODELS:
# required due to reorder_batch behaviour
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enforce_eager=True,
enable_prefix_caching=False) as vllm_model:
enable_prefix_caching=False,
block_size=block_size) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
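
For reference, a hedged standalone usage sketch mirroring the V1 test setup above (the environment variables and engine flags come from the test; the model and block size are taken from the lists above, and this is not an officially documented recipe):

```python
import os

# Must be set before the engine is created; mirrors the monkeypatched test env.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(model="ibm-ai-platform/Bamba-9B-v1",
          max_num_seqs=4,
          enforce_eager=True,
          enable_prefix_caching=False,
          block_size=528)  # padded attention block size from ATTN_BLOCK_SIZES

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16, temperature=0.0))
print(outputs[0].outputs[0].text)
```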
else:
Expand Down Expand Up @@ -111,6 +153,14 @@ def test_batching(
max_tokens: int,
num_logprobs: int,
) -> None:

try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass

for_loop_outputs = []
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
for prompt in example_prompts:
2 changes: 1 addition & 1 deletion tests/models/registry.py
@@ -169,7 +169,7 @@ def check_available_online(
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct",
"FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base",
min_transformers_version="4.53"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
1 change: 0 additions & 1 deletion tests/v1/test_oracle.py
@@ -13,7 +13,6 @@
"openai/whisper-large-v3", # transcription
"facebook/bart-large-cnn", # encoder decoder
"state-spaces/mamba-130m-hf", # mamba1
"hmellor/tiny-random-BambaForCausalLM", # hybrid
"BAAI/bge-m3", # embedding
]

2 changes: 1 addition & 1 deletion vllm/model_executor/layers/mamba/ops/mamba_ssm.py
@@ -108,7 +108,7 @@ def _selective_scan_update_kernel(
# is the same as the batch id.
if HAS_STATE_BATCH_INDICES:
state_batch_indices_ptr += pid_b
state_batch_idx = tl.load(state_batch_indices_ptr)
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
state_ptr += (state_batch_idx * stride_state_batch +
pid_h * stride_state_head)
else:
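The one-line change above casts the loaded state index to 64-bit before it is multiplied by the batch stride; otherwise the 32-bit product can wrap around for large state tensors. A hedged illustration with made-up sizes (plain NumPy, not the Triton kernel itself):

```python
import numpy as np

state_batch_idx = np.int32(9_000)        # cache slot index for this request (assumed)
stride_state_batch = np.int32(300_000)   # elements between consecutive slots (assumed)

# int32 * int32 stays int32 and wraps past 2**31 - 1 (NumPy may warn here).
offset32 = state_batch_idx * stride_state_batch
# Casting first, as the kernel now does, keeps the full 64-bit offset.
offset64 = np.int64(state_batch_idx) * np.int64(stride_state_batch)

print(offset32)  # negative garbage due to int32 wraparound
print(offset64)  # 2700000000, the intended element offset
```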
45 changes: 30 additions & 15 deletions vllm/model_executor/models/bamba.py
@@ -9,6 +9,7 @@
from torch import nn
from transformers import BambaConfig

from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -36,7 +37,7 @@
from vllm.utils import LayerBlockType

from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsQuant, SupportsV0Only)
SupportsQuant)
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -97,7 +98,9 @@ def __init__(self,
head_dim=config.mamba_d_head,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.mamba_chunk_size)

self.feed_forward = BambaMLP(config, quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
@@ -313,10 +316,14 @@ def forward(

attn_metadata = get_forward_context().attn_metadata

mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None

if get_pp_group().is_first_rank:
if inputs_embeds is not None:
@@ -337,7 +344,8 @@
num_attn += 1

layer_mamba_cache_params = None
if isinstance(layer, BambaMixerDecoderLayer):
if isinstance(layer,
BambaMixerDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_attn)

@@ -411,7 +419,7 @@ def load_weights(self, weights: Iterable[tuple[str,


class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid, SupportsV0Only, SupportsQuant):
IsHybrid, SupportsQuant):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -475,15 +483,22 @@ def forward(self,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:

num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = \
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba
)

self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())

mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)

57 changes: 40 additions & 17 deletions vllm/model_executor/models/falcon_h1.py
@@ -8,6 +8,7 @@
from torch import nn
from transformers import FalconH1Config

from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -33,8 +34,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsV0Only)
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -85,6 +85,7 @@ def __init__(
config: FalconH1Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
Expand All @@ -107,6 +108,8 @@ def __init__(
activation=config.hidden_act,
quant_config=quant_config,
use_rms_norm=config.mamba_rms_norm,
prefix=f"{prefix}.mixer",
chunk_size=config.mamba_chunk_size,
)
# n_groups is overridden later by `MambaMixer2`
self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
@@ -316,18 +319,26 @@ def __init__(
prefix: str = "",
) -> None:
super().__init__()

# Instantiate the attention branch
self.self_attn = FalconH1AttentionDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
)

# In V1 all attention/ssm layers must have
# different index in prefix
ssm_layer_idx = config.num_hidden_layers + layer_idx
ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}"

# Instantiate the SSM branch
self.mamba = FalconH1SSMDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=ssm_prefix,
)
self.ssm_out_multiplier = config.ssm_out_multiplier
self.ssm_in_multiplier = config.ssm_in_multiplier
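
For context on the prefix offset a few lines up: in V1 every attention branch and every SSM branch must register its cache under a distinct layer index. A quick trace with assumed values (the concrete prefix string and layer count are illustrative; only the arithmetic mirrors the diff):

```python
num_hidden_layers = 36                         # assumed config.num_hidden_layers
layer_idx = 3
prefix = f"model.layers.{layer_idx}"           # assumed decoder-layer prefix

# Offsetting by num_hidden_layers guarantees the SSM index never collides
# with any attention layer index in 0..num_hidden_layers-1.
ssm_layer_idx = num_hidden_layers + layer_idx  # 39
ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}"

print(prefix)      # model.layers.3  -> used by the attention branch
print(ssm_prefix)  # model.39        -> used by the SSM branch (".mixer" appended inside)
```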
@@ -452,10 +463,16 @@ def forward(
# proper continuous batching computation including
# chunked prefill
attn_metadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)

if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None

if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds * self.embedding_multiplier
Expand All @@ -468,7 +485,9 @@ def forward(

for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
layer_mamba_cache_params = None
if mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
hidden_states = layer(
positions=positions,
hidden_states=hidden_states,
Expand All @@ -484,7 +503,7 @@ def forward(


class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid, SupportsV0Only):
IsHybrid):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
@@ -558,15 +577,19 @@ def forward(
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
):
if self.mamba_cache is None:
self.mamba_cache = MambaCacheManager(
self.vllm_config,
self.lm_head.weight.dtype
if hasattr(self.lm_head, 'weight') else torch.bfloat16,
self.config.num_hidden_layers,
*self._get_mamba_cache_shape(),
)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
self.mamba_cache = MambaCacheManager(
self.vllm_config,
self.lm_head.weight.dtype if hasattr(
self.lm_head, 'weight') else torch.bfloat16,
self.config.num_hidden_layers,
*self._get_mamba_cache_shape(),
)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

hidden_states = self.model(
input_ids,
positions,