
Enable V1 for Hybrid SSM/Attention Models #20016


Merged: 32 commits, merged on Jul 4, 2025
Commits
de4e3a2
working change
tdoublep Jun 23, 2025
617cd26
working changes
tdoublep Jun 24, 2025
9378c54
Merge branch 'main' into tpa-bamba-v1
tdoublep Jun 24, 2025
300d25f
Working version
tdoublep Jun 24, 2025
0822308
Add support + test for Zamba2
tdoublep Jun 24, 2025
a9fc73f
Fix memory layout for KV cache tensors in mamba case
tdoublep Jun 26, 2025
0e5b6de
kv_cache_interface.py: use utils.round_up
tdoublep Jun 26, 2025
89f504a
Fix unrelated CI test failing
tdoublep Jun 26, 2025
0b7783b
Enable bamba-9b in CI
tdoublep Jun 26, 2025
ded4833
Fix unrelated CI issue
tdoublep Jun 26, 2025
31db869
Add support for Nemotron-H
tdoublep Jun 27, 2025
c45e7e5
Enable Granite 4.0 (HybridMoE)
tdoublep Jun 27, 2025
0f20e11
Merge branch 'main' into tpa-bamba-v1
tdoublep Jun 27, 2025
e2c14ba
add support for Falcon H1
tdoublep Jun 27, 2025
c5a25eb
Fix overflow issue in mamba_ssm kernel
tdoublep Jun 30, 2025
cfc38c0
Resolve merge conflicts
tdoublep Jun 30, 2025
aaa6f0e
Add check for transformers min version
tdoublep Jun 30, 2025
d187bfd
Don't fail test if model is not in HF_EXAMPLE_MODELS
tdoublep Jun 30, 2025
fde28dc
Fix test_batching in same way
tdoublep Jun 30, 2025
1777fd1
Update tests/models/language/generation/test_hybrid.py
tdoublep Jul 1, 2025
58e66c9
Update vllm/model_executor/models/granitemoehybrid.py
tdoublep Jul 1, 2025
c2da03e
page_size -> num_element_per_page
tdoublep Jul 1, 2025
e0404c9
Clean up page size padding logic
tdoublep Jul 1, 2025
c74698d
gpu_model_runner.py: add TODO about batch reordering
tdoublep Jul 1, 2025
b72b729
Fix linting issue
tdoublep Jul 1, 2025
105737c
Validate memory layout for hybrid models against attention backends
tdoublep Jul 1, 2025
d8ff3b9
Adjust comment
tdoublep Jul 1, 2025
c857ec3
Merge branch 'main' into tpa-bamba-v1
tdoublep Jul 4, 2025
ea8cf32
Move memory layout check into separate function
tdoublep Jul 4, 2025
b38d3fb
Move logic to pad mamba page size into separate function
tdoublep Jul 4, 2025
e6b0015
Add extra todo
tdoublep Jul 4, 2025
14fd006
test_oracle.py: hybrid models now supported
tdoublep Jul 4, 2025
16 changes: 14 additions & 2 deletions tests/models/language/generation/test_hybrid.py
@@ -34,19 +34,25 @@
"pfnet/plamo-2-1b",
"Zyphra/Zamba2-1.2B-instruct",
"hmellor/tiny-random-BambaForCausalLM",
"ibm-ai-platform/Bamba-9B-v1",
]

V1_SUPPORTED_MODELS = [
"mistralai/Mamba-Codestral-7B-v0.1",
"ibm-ai-platform/Bamba-9B-v1",
]

ATTN_BLOCK_SIZES = {
"ibm-ai-platform/Bamba-9B-v1": 528,
}

# Avoid OOM
MAX_NUM_SEQS = 4


@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(
hf_runner,
vllm_runner,
@@ -70,10 +76,16 @@ def test_models(
if model in V1_SUPPORTED_MODELS:
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
if model in HYBRID_MODELS:
# required due to reorder_batch behaviour
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
# set attn block size to match mamba state
block_size = ATTN_BLOCK_SIZES.get(model, 16)
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enforce_eager=True,
enable_prefix_caching=False) as vllm_model:
enable_prefix_caching=False,
block_size=block_size) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
else:
45 changes: 30 additions & 15 deletions vllm/model_executor/models/bamba.py
@@ -9,6 +9,7 @@
from torch import nn
from transformers import BambaConfig

from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -36,7 +37,7 @@
from vllm.utils import LayerBlockType

from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsQuant, SupportsV0Only)
SupportsQuant)
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -97,7 +98,9 @@ def __init__(self,
head_dim=config.mamba_d_head,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.mamba_chunk_size)

self.feed_forward = BambaMLP(config, quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
@@ -313,10 +316,14 @@ def forward(

attn_metadata = get_forward_context().attn_metadata

mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None

if get_pp_group().is_first_rank:
if inputs_embeds is not None:
@@ -337,7 +344,8 @@
num_attn += 1

layer_mamba_cache_params = None
if isinstance(layer, BambaMixerDecoderLayer):
if isinstance(layer,
BambaMixerDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_attn)

@@ -411,7 +419,7 @@ def load_weights(self, weights: Iterable[tuple[str,


class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid, SupportsV0Only, SupportsQuant):
IsHybrid, SupportsQuant):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -475,15 +483,22 @@ def forward(self,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:

num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = \
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba
)

self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())

mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)

self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)

11 changes: 8 additions & 3 deletions vllm/v1/core/kv_cache_coordinator.py
@@ -26,6 +26,7 @@
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
self.enable_caching = enable_caching

self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
enable_kv_cache_events)
@@ -267,9 +268,13 @@ def verify_and_split_kv_cache_groups(self) -> None:

self.full_attention_block_size = self.full_attention_spec.block_size
self.other_block_size = self.other_spec.block_size
assert self.other_block_size % self.full_attention_block_size == 0, (
"KVCacheCoordinator assumes the block_size of full attention "
"layers is divisible by other layers now.")

if self.enable_caching:
# this requirement is only needed for the prefix caching logic
divisible = self.other_block_size % self.full_attention_block_size
assert divisible == 0, (
"KVCacheCoordinator assumes the block_size of full "
"attention layers is divisible by other layers now.")

if max(self.full_attention_group_ids) < min(self.other_group_ids):
self.full_attn_first = True
16 changes: 10 additions & 6 deletions vllm/v1/core/kv_cache_manager.py
@@ -84,12 +84,15 @@ def __init__(
self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
assert len(
set(g.kv_cache_spec.block_size
for g in kv_cache_config.kv_cache_groups)
) == 1, "Only one block size is supported for now"
self.block_size = kv_cache_config.kv_cache_groups[
0].kv_cache_spec.block_size

self.block_size: Optional[int] = None
if self.enable_caching:
assert len(
set(g.kv_cache_spec.block_size
for g in kv_cache_config.kv_cache_groups)
) == 1, "Only one block size is supported for now"
self.block_size = kv_cache_config.kv_cache_groups[
0].kv_cache_spec.block_size

self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
@@ -154,6 +157,7 @@ def get_computed_blocks(self,
# if the scheduler has tried to schedule the request before.
block_hashes = self.req_to_block_hashes[request.request_id]
if not block_hashes:
assert self.block_size is not None
block_hashes = hash_request_tokens(self.caching_hash_fn,
self.block_size, request)
self.req_to_block_hashes[request.request_id] = block_hashes
6 changes: 4 additions & 2 deletions vllm/v1/core/kv_cache_utils.py
@@ -864,9 +864,11 @@ def _get_kv_cache_config_uniform_page_size(
kv_cache_groups=kv_cache_groups,
)

min_block_size = min(
[group.kv_cache_spec.block_size for group in kv_cache_groups])

# Print the KV cache size and maximum concurrency.
num_tokens = num_blocks // len(
grouped_layers) * vllm_config.cache_config.block_size
num_tokens = num_blocks // len(grouped_layers) * min_block_size
num_tokens_str = f"{num_tokens:,}"
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
6 changes: 5 additions & 1 deletion vllm/v1/kv_cache_interface.py
@@ -159,6 +159,7 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
class MambaSpec(KVCacheSpec):
shapes: tuple[tuple[int, ...], ...]
dtype: torch.dtype
multiple_of: Optional[int]

def __post_init__(self):
self.num_elements = sum(prod(shape) for shape in self.shapes)
@@ -169,7 +170,10 @@ def type_id(self) -> str:

@property
def page_size_bytes(self) -> int:
return self.num_elements * get_dtype_size(self.dtype)
page_size = self.num_elements * get_dtype_size(self.dtype)
if self.multiple_of is not None:
page_size = cdiv(page_size, self.multiple_of) * self.multiple_of
return page_size

def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
# We allocate 1 block for each request now, so max_memory_usage_bytes is
28 changes: 26 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
@@ -2399,7 +2399,8 @@ def _reshape_kv_cache_tensors(
tensor = tensor.view(dtype).view(target_shape)
state_tensors.append(tensor)
start_pos += size_in_bytes
assert start_pos == raw_tensor.numel()
if kv_cache_spec.multiple_of is None:
assert start_pos == raw_tensor.numel()
kv_caches[layer_name] = tuple(state_tensors)
else:
raise NotImplementedError
@@ -2513,6 +2514,15 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
mamba_layers = get_layers_from_vllm_config(self.vllm_config,
MambaMixer2)
if len(mamba_layers) > 0:
if len(attn_layers) > 0:
# Mamba state must be padded to an integer number of
# 16th tokens worth of attention pages
attn_layer_name = next(iter(attn_layers))
attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes
Collaborator:

To clarify: is kv_cache_spec[attn_layer_name].page_size_bytes the size for a single token stored in the mamba cache?

Member Author:

Not exactly. kv_cache_spec[attn_layer_name].page_size_bytes gives us the size in bytes of an attention page that stores block_size tokens. The goal here is to figure out the size in bytes of an attention page that stores exactly 16 tokens, which is why, on the next line, we divide by block_size to normalize it and then multiply by 16.

Why do we want to know the size in bytes of an attention page that stores 16 tokens? Because we want to ensure that the mamba page size is padded up to a value that the user can align the attention page size with. Since the user can only set the attention block size in multiples of 16, a factor of 16 is needed here.
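
A minimal sketch of the arithmetic described above, using made-up sizes (the cdiv helper stands in for vllm.utils.cdiv; none of the numbers come from a real model config):

```python
# Sketch of the page-size padding arithmetic (illustrative numbers only).

def cdiv(a: int, b: int) -> int:
    """Ceiling division, as provided by vllm.utils.cdiv."""
    return -(-a // b)

block_size = 16                                  # user-configured attention block size
attn_page_size = block_size * 2 * 8 * 128 * 2    # bytes per attention page
                                                 # (2 for K/V, 8 kv heads, head_dim 128, fp16)
mamba_page_size = 2_725_888                      # hypothetical raw mamba state size in bytes

# Size in bytes of an attention page holding exactly 16 tokens,
# independent of the user-chosen block_size.
multiple_of = 16 * attn_page_size // block_size

# MambaSpec.page_size_bytes rounds the mamba page up to this multiple,
# so that some multiple-of-16 attention block size can match it exactly.
padded_mamba_page_size = cdiv(mamba_page_size, multiple_of) * multiple_of

# Attention block size the user would need for the two page sizes to align.
required_attn_block_size = cdiv(mamba_page_size, multiple_of) * 16   # -> 672 here
```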

Collaborator:
Can you add a comment to explain the magic number "16"?
And to confirm: an arbitrary block_size that is supported by the attention backend is OK, but since we don't know which block sizes each attention backend supports, we have to hardcode "16", which is supported by most attention backends.

Collaborator:
@tlrmchlsmth If I remember correctly, you told me that FlashMLA does not support block_size 16. Can you confirm? If it is true, we may need some other assertion here.

Member Author:
> And to confirm: an arbitrary block_size that is supported by the attention backend is OK, but since we don't know which block sizes each attention backend supports, we have to hardcode "16", which is supported by most attention backends.

So actually, I think this magic "16" is not strictly needed. The constraint that the block size must be a multiple of 16 only comes from the FlashAttention backend (which is not compatible with Mamba right now, for the reasons discussed). I just checked, and with FlashInfer it is possible to set the block size to any number.

Still, it probably makes sense to keep "16" since we want to support FlashAttention in the near future. Do you agree, @heheda12345?

Member Author:
> Can you add a comment to explain the magic number "16"?

Done

multiple_of = 16 * attn_page_size // block_size
else:
multiple_of = None

if self.vllm_config.speculative_config is not None:
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet.")
@@ -2529,5 +2539,19 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(),
dtype=self.kv_cache_dtype,
block_size=max_model_len)
block_size=max_model_len,
multiple_of=multiple_of)

if len(attn_layers) > 0:
mamba_layer_name = next(iter(mamba_layers))
mamba_page_size = kv_cache_spec[
mamba_layer_name].page_size_bytes
if attn_page_size < mamba_page_size:
required_attn_block_size = cdiv(mamba_page_size,
multiple_of) * 16
raise ValueError(
"Attention block size must be increased to "
f"{required_attn_block_size} in order to match "
"the mamba page size")
Collaborator:
I think this is a fairly reasonable approach, especially for a first pass

Member Author:
The main question is whether we are OK with vLLM V1 failing under default parameters for hybrid models. If not, we could automatically scale up the attention block size and log what is happening to inform the user, rather than explicitly asking the user to do it.

Collaborator:
I think that's a better option, paired with logging a warning. But that could also wait for a follow-up
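
A rough sketch of the auto-scaling alternative being discussed, with a hypothetical helper name and standard logging (this is not what the PR implements; the merged code raises an error instead):

```python
import logging

logger = logging.getLogger(__name__)

def maybe_bump_attn_block_size(block_size: int, attn_page_size: int,
                               mamba_page_size: int, multiple_of: int) -> int:
    """Hypothetical helper: return an attention block size large enough that
    the attention page covers the (padded) mamba page, warning if it changed."""
    if attn_page_size >= mamba_page_size:
        return block_size
    required = -(-mamba_page_size // multiple_of) * 16  # cdiv(...) * 16
    logger.warning(
        "Automatically increasing attention block size from %d to %d "
        "to match the mamba page size.", block_size, required)
    return required
```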

Collaborator:
I think the printed block_size is not a "must be" but rather a value that we are suggesting.

Member Author:
It needs to be at least this value in order to work, though, right? I can't really think of practical scenarios where we would want the attention page size to be bigger than the mamba page size. The mamba page size is typically orders of magnitude bigger than the attention page size (per token). If the attention page size is bigger, we will need to pad the mamba page to align it and waste more space.
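
For a rough sense of the scale difference being described, a back-of-the-envelope comparison with made-up per-layer fp16 sizes (not taken from any particular model):

```python
# Illustrative per-layer sizes only (hypothetical model dimensions, fp16).
dtype_bytes = 2

# Attention KV cache grows per token:
# 2 (K and V) * num_kv_heads * head_dim * dtype_bytes
attn_bytes_per_token = 2 * 8 * 128 * dtype_bytes     # 4 KiB per token
attn_page_bytes = 16 * attn_bytes_per_token          # 64 KiB for a 16-token block

# Mamba2 keeps a fixed-size state per sequence (conv state + SSM state);
# e.g. an SSM state of (num_heads, head_dim, d_state) elements alone:
mamba_state_bytes = 128 * 64 * 128 * dtype_bytes     # 2 MiB per sequence

# A single mamba "page" dwarfs a 16-token attention page, which is why the
# attention block size is scaled up (and the mamba page padded) to align them.
```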

Member Author:
I've changed the language in the exception, please take a look.


return kv_cache_spec