
[V1] [Hybrid] Enable piecewise CUDA Graph for mamba layers #21194


Merged (2 commits) on Jul 19, 2025
1 change: 0 additions & 1 deletion tests/models/language/generation/test_hybrid.py
@@ -104,7 +104,6 @@ def test_models(
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enforce_eager=True,
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
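Dropping enforce_eager=True means this hybrid-model test now runs through the default V1 execution path, where piecewise CUDA graphs are enabled. A hedged sketch for running just this test locally (assumes a vLLM dev checkout with a CUDA GPU; the -k filter matches the test name in the hunk header above):

# Repro sketch, illustrative only: run the hybrid generation test so it
# exercises the CUDA-graph path instead of eager execution.
import pytest

pytest.main([
    "tests/models/language/generation/test_hybrid.py",
    "-k", "test_models",
    "-x",
])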
1 change: 1 addition & 0 deletions vllm/config.py
@@ -4312,6 +4312,7 @@ def set_splitting_ops_for_v1(self):
self.splitting_ops = [] if self.full_cuda_graph else [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.mamba_mixer2",
]


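With "vllm.mamba_mixer2" added to splitting_ops, the piecewise compilation backend now cuts the captured graph at every mamba mixer call as well as at the attention ops, so the mixer runs outside the CUDA graph while the surrounding subgraphs can still be captured. A standalone sketch of the selection logic above (illustrative, not the actual vLLM config code):

# Sketch of the splitting-op selection after this change. The piecewise
# backend splits the compiled graph at these custom ops; everything between
# the splits stays eligible for CUDA graph capture.
def splitting_ops_for_v1(full_cuda_graph: bool) -> list[str]:
    return [] if full_cuda_graph else [
        "vllm.unified_attention",
        "vllm.unified_attention_with_output",
        "vllm.mamba_mixer2",  # new: also split around the mamba2 mixer
    ]

assert splitting_ops_for_v1(False)[-1] == "vllm.mamba_mixer2"
assert splitting_ops_for_v1(True) == []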
75 changes: 66 additions & 9 deletions vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -13,7 +13,7 @@
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@@ -33,6 +33,8 @@
LoaderFunction, composed_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata

# Added by the IBM Team, 2024
@@ -424,14 +426,36 @@ def __init__(
def forward_native(
self,
hidden_states: torch.Tensor,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
):
pass

def forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
):
if not envs.VLLM_USE_V1:
CustomOp.forward(self, hidden_states, output, mamba_cache_params,
mamba2_metadata, mup_vector)
else:
torch.ops.vllm.mamba_mixer2(
hidden_states,
output,
self.prefix,
mup_vector,
)

def forward_cuda(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
@@ -517,25 +541,26 @@ def forward_cuda(
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
num_actual_tokens = num_prefill_tokens + num_decodes

# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
if envs.VLLM_USE_V1:
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
hidden_states_B_C,
hidden_states_B_C[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
dt_d, dt_p = torch.split(
dt,
dt[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
state_indices_tensor[:num_actual_tokens],
[num_decodes, num_prefills],
dim=0,
)
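Because CUDA-graph capture pads the batch to a fixed size, only the first num_actual_tokens rows of each varlen tensor are real tokens; the slices above drop the padding before splitting decode rows from prefill rows (V1 orders decode before prefill). A small self-contained sketch of that split, with made-up sizes:

# Illustrative only: the shapes are invented, but the slicing and splitting
# mirror the V1 branch above (decode rows first, then prefill rows, then padding).
import torch

num_decodes, num_prefill_tokens = 2, 5
num_actual_tokens = num_decodes + num_prefill_tokens
padded_tokens = 16  # CUDA-graph capture size can exceed the real token count

hidden_states_B_C = torch.randn(padded_tokens, 8)

hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
    hidden_states_B_C[:num_actual_tokens],  # strip the padding rows first
    [num_decodes, num_prefill_tokens],
    dim=0,
)
assert hidden_states_B_C_d.shape == (num_decodes, 8)
assert hidden_states_B_C_p.shape == (num_prefill_tokens, 8)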
@@ -696,11 +721,10 @@ def forward_cuda(
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states = self.norm(hidden_states, gate)
hidden_states = self.norm(hidden_states, gate[:num_actual_tokens])

# 5. Final linear projection
out, _ = self.out_proj(hidden_states)
return out
output[:num_actual_tokens], _ = self.out_proj(hidden_states)

def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return get_mamba_state_shape(
@@ -712,3 +736,36 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
state_size=self.ssm_state_size,
conv_kernel=self.conv_kernel_size,
)


def mamba_mixer2(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
mup_vector: Optional[torch.Tensor] = None,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states,
output=output,
mamba_cache_params=None,
mamba2_metadata=None,
mup_vector=mup_vector)


def mamba_mixer2_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
mup_vector: Optional[torch.Tensor] = None,
) -> None:
return


direct_register_custom_op(
op_name="mamba_mixer2",
op_func=mamba_mixer2,
Review comment (Collaborator) on lines +766 to +767: We should use a common op name here (unified_ssm_mixer?) so we can avoid adding a bunch of cases to splitting_ops (fine for this PR though)

mutates_args=["output"],
fake_impl=mamba_mixer2_fake,
dispatch_key=current_platform.dispatch_key,
)
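The module-level wrapper above is what makes the mixer visible to the piecewise compiler: mamba_mixer2 is registered as the custom op vllm.mamba_mixer2 with a fake implementation for tracing and mutates_args=["output"], since the result is written into a caller-provided buffer rather than returned. At runtime the op recovers the layer instance from the forward context via its prefix and calls forward_cuda. The same contract can be illustrated with stock torch.library.custom_op; this is not what the PR uses (it uses vLLM's direct_register_custom_op helper), and the op name, doubling kernel, and layer_name string below are placeholders:

# Generic sketch of the "output-mutating custom op + fake impl" pattern,
# runnable on its own with PyTorch >= 2.4. Not vLLM-internal code.
import torch

@torch.library.custom_op("demo::mixer_with_output", mutates_args=("output",))
def mixer_with_output(hidden_states: torch.Tensor, output: torch.Tensor,
                      layer_name: str) -> None:
    # A real implementation would look the layer up by name (vLLM uses the
    # forward context) and run its kernels; here we just write a dummy result.
    output.copy_(hidden_states * 2)

@mixer_with_output.register_fake
def _(hidden_states, output, layer_name) -> None:
    # Fake/meta impl: no computation, only lets the compiler trace the op.
    return None

x = torch.randn(4, 8)
out = torch.empty_like(x)
torch.ops.demo.mixer_with_output(x, out, "layers.0.mixer")
assert torch.allclose(out, x * 2)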
11 changes: 6 additions & 5 deletions vllm/model_executor/models/bamba.py
@@ -11,6 +11,7 @@

from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
@@ -122,11 +123,10 @@ def forward(
hidden_states, residual = self.input_layernorm(
hidden_states, residual)

hidden_states = self.mamba(hidden_states, mamba_cache_params,
mamba2_metadata)
output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
Review comment (Collaborator) on lines +126 to +127: Is it possible to change the code back to output = self.mamba(...)?

# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
hidden_states, residual = self.pre_ff_layernorm(output, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
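The same calling convention is applied to every hybrid model touched by this PR (bamba, falcon_h1, granitemoehybrid, mamba2, nemotron_h, zamba2): the decoder layer pre-allocates an output buffer with torch.empty_like, the mixer writes into it, and the model class gains @support_torch_compile so the V1 compiler can trace it and split at the registered ops. A condensed stand-in for that shared pattern (norm, mixer, and the metadata objects are placeholders for the real submodules):

# Stand-in sketch of the decoder-layer pattern used across the hybrid models
# in this PR; not taken verbatim from any one model file.
import torch

def mamba_block_forward(norm, mixer, hidden_states, residual,
                        mamba_cache_params=None, mamba2_metadata=None):
    hidden_states, residual = norm(hidden_states, residual)
    output = torch.empty_like(hidden_states)  # buffer the custom op mutates
    mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
    return output, residual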

@@ -169,7 +169,7 @@ def __init__(
self.max_position_embeddings = max_position_embeddings

if hasattr(config, "partial_rotary_factor"):
rotary_dim = self.head_dim * config.partial_rotary_factor
rotary_dim = int(self.head_dim * config.partial_rotary_factor)
elif hasattr(config, "attn_rotary_emb"):
rotary_dim = config.attn_rotary_emb # for backward compatibility
else:
@@ -258,6 +258,7 @@ def forward(
}


@support_torch_compile
class BambaModel(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
8 changes: 6 additions & 2 deletions vllm/model_executor/models/falcon_h1.py
@@ -10,6 +10,7 @@

from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
@@ -179,13 +180,15 @@ def forward(
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
hidden_states = self.mamba(
output = torch.empty_like(hidden_states)
self.mamba(
hidden_states,
output,
mamba_cache_params,
mamba2_metadata=mamba2_metadata,
mup_vector=self.mup_vector,
)
return hidden_states, residual
return output, residual


class FalconH1AttentionDecoderLayer(nn.Module):
@@ -398,6 +401,7 @@ def forward(
return hidden_states


@support_torch_compile
class FalconH1Model(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
8 changes: 5 additions & 3 deletions vllm/model_executor/models/granitemoehybrid.py
@@ -11,6 +11,7 @@

from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
@@ -104,9 +105,9 @@ def forward(
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.mamba(hidden_states, mamba_cache_params,
mamba2_metadata)
hidden_states = residual + hidden_states * self.residual_multiplier
output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
hidden_states = residual + output * self.residual_multiplier

residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
@@ -307,6 +308,7 @@ def forward(
}


@support_torch_compile
class GraniteMoeHybridModel(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
8 changes: 5 additions & 3 deletions vllm/model_executor/models/mamba2.py
@@ -10,6 +10,7 @@

from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
@@ -79,11 +80,12 @@ def forward(
else:
hidden_states, residual = self.norm(hidden_states, residual)

hidden_states = self.mixer(hidden_states, mamba_cache_params,
mamba2_metadata)
return hidden_states, residual
output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
return output, residual


@support_torch_compile
class Mamba2Model(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
8 changes: 5 additions & 3 deletions vllm/model_executor/models/nemotron_h.py
@@ -25,6 +25,7 @@

from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
@@ -172,9 +173,9 @@ def forward(
else:
hidden_states, residual = self.norm(hidden_states, residual)

hidden_states = self.mixer(hidden_states, mamba_cache_params,
mamba2_metadata)
return hidden_states, residual
output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
return output, residual


class NemotronHAttention(nn.Module):
@@ -292,6 +293,7 @@ def forward(
}


@support_torch_compile
class NemotronHModel(nn.Module):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
8 changes: 6 additions & 2 deletions vllm/model_executor/models/zamba2.py
@@ -17,6 +17,7 @@

from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
@@ -548,14 +549,16 @@ def forward(
hidden_states = self.input_layernorm(hidden_states)

# Process through Mamba mixer
hidden_states = self.mamba(
output = torch.empty_like(hidden_states)
self.mamba(
hidden_states,
output,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)

# residual connection after mamba
hidden_states = residual + hidden_states
hidden_states = residual + output

return hidden_states

@@ -646,6 +649,7 @@ def forward(
return layer_outputs


@support_torch_compile
class Zamba2Model(nn.Module):
"""Core Zamba2 model combining transformer and Mamba architectures.

3 changes: 0 additions & 3 deletions vllm/v1/worker/gpu_model_runner.py
@@ -2716,9 +2716,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
if self.vllm_config.speculative_config is not None:
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet.")
if not self.vllm_config.model_config.enforce_eager:
raise NotImplementedError(
"Mamba with cuda graph is not supported yet.")
if self.vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError(
"Prefix caching is not supported for Mamba yet.")
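With the enforce_eager guard removed from get_kv_cache_spec, hybrid and mamba2 models can now be served on V1 without forcing eager mode, so the non-mamba portions of the graph run under piecewise CUDA graphs. A hedged end-to-end sketch (the model id, explicit VLLM_USE_V1 flag, and prompt are illustrative choices, not part of this PR; prefix caching still has to stay off per the remaining check above):

# Illustrative usage sketch, not taken from the PR. Assumes a CUDA GPU and
# that the chosen checkpoint is a supported hybrid attention/mamba2 model.
import os

os.environ["VLLM_USE_V1"] = "1"  # assumption: select the V1 engine explicitly

from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-ai-platform/Bamba-9B",  # hypothetical choice of hybrid model
    enable_prefix_caching=False,       # still unsupported for mamba layers
    # note: no enforce_eager=True, so piecewise CUDA graphs are used
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)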