
Commit 881e3cb

[V1] [Hybrid] Enable piecewise CUDA Graph for mamba layers (#21194)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent 9f414a1 commit 881e3cb

10 files changed: +100 additions, -31 deletions

tests/models/language/generation/test_hybrid.py

Lines changed: 0 additions & 1 deletion
@@ -104,7 +104,6 @@ def test_models(
         m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
     with vllm_runner(model,
                      max_num_seqs=MAX_NUM_SEQS,
-                     enforce_eager=True,
                      enable_prefix_caching=False) as vllm_model:
         vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)

vllm/config.py

Lines changed: 1 addition & 0 deletions
@@ -4312,6 +4312,7 @@ def set_splitting_ops_for_v1(self):
         self.splitting_ops = [] if self.full_cuda_graph else [
             "vllm.unified_attention",
             "vllm.unified_attention_with_output",
+            "vllm.mamba_mixer2",
         ]

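With "vllm.mamba_mixer2" added to splitting_ops, vLLM's piecewise compilation breaks the graph at every mamba call, so CUDA graphs are captured for the subgraphs between attention/mamba ops while the mamba kernels themselves stay eager. A rough user-facing sketch of the same configuration, assuming the LLM(..., compilation_config=...) entry point; passing the list explicitly and the model name are illustrative only, since this commit makes it the default:

from vllm import LLM
from vllm.config import CompilationConfig

# Sketch only: mirrors the default splitting_ops set above after this change.
llm = LLM(
    model="ibm-ai-platform/Bamba-9B",  # illustrative hybrid model
    compilation_config=CompilationConfig(
        splitting_ops=[
            "vllm.unified_attention",
            "vllm.unified_attention_with_output",
            "vllm.mamba_mixer2",  # new in this commit
        ],
    ),
)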
vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 66 additions & 9 deletions
@@ -13,7 +13,7 @@
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce)
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -33,6 +33,8 @@
     LoaderFunction, composed_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata

 # Added by the IBM Team, 2024
@@ -424,14 +426,36 @@ def __init__(
     def forward_native(
         self,
         hidden_states: torch.Tensor,
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        output: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
+        mup_vector: Optional[torch.Tensor] = None,
     ):
         pass

+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
+        mup_vector: Optional[torch.Tensor] = None,
+    ):
+        if not envs.VLLM_USE_V1:
+            CustomOp.forward(self, hidden_states, output, mamba_cache_params,
+                             mamba2_metadata, mup_vector)
+        else:
+            torch.ops.vllm.mamba_mixer2(
+                hidden_states,
+                output,
+                self.prefix,
+                mup_vector,
+            )
+
     def forward_cuda(
         self,
         hidden_states: torch.Tensor,
+        output: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
         mamba2_metadata: Mamba2Metadata,
         mup_vector: Optional[torch.Tensor] = None,
@@ -517,25 +541,26 @@ def forward_cuda(
         num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
         has_prefill = num_prefills > 0
         has_decode = num_decodes > 0
+        num_actual_tokens = num_prefill_tokens + num_decodes

         # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
         # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         if envs.VLLM_USE_V1:
             hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
-                hidden_states_B_C,
+                hidden_states_B_C[:num_actual_tokens],
                 [num_decodes, num_prefill_tokens],
                 dim=0,
             )
             dt_d, dt_p = torch.split(
-                dt,
+                dt[:num_actual_tokens],
                 [num_decodes, num_prefill_tokens],
                 dim=0,
             )
             # Split along batch dimension
             state_indices_tensor_d, state_indices_tensor_p = torch.split(
-                state_indices_tensor,
+                state_indices_tensor[:num_actual_tokens],
                 [num_decodes, num_prefills],
                 dim=0,
             )
@@ -696,11 +721,10 @@ def forward_cuda(
         # GatedRMSNorm internally applying SiLU to the gate
         # SiLU is applied internally before normalization, unlike standard
         # norm usage
-        hidden_states = self.norm(hidden_states, gate)
+        hidden_states = self.norm(hidden_states, gate[:num_actual_tokens])

         # 5. Final linear projection
-        out, _ = self.out_proj(hidden_states)
-        return out
+        output[:num_actual_tokens], _ = self.out_proj(hidden_states)

     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return get_mamba_state_shape(
@@ -712,3 +736,36 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
             state_size=self.ssm_state_size,
             conv_kernel=self.conv_kernel_size,
         )
+
+
+def mamba_mixer2(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+    mup_vector: Optional[torch.Tensor] = None,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    self.forward_cuda(hidden_states=hidden_states,
+                      output=output,
+                      mamba_cache_params=None,
+                      mamba2_metadata=None,
+                      mup_vector=mup_vector)
+
+
+def mamba_mixer2_fake(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+    mup_vector: Optional[torch.Tensor] = None,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="mamba_mixer2",
+    op_func=mamba_mixer2,
+    mutates_args=["output"],
+    fake_impl=mamba_mixer2_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
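The key enabler is registering the mixer as a custom op that writes into a caller-provided output buffer and has a no-op fake implementation, so torch.compile treats it as an opaque node and the graph can be split around it. A self-contained sketch of the same pattern using plain torch.library.custom_op (PyTorch 2.4+) instead of vLLM's direct_register_custom_op helper; the op name and toy math are assumptions:

import torch

# Toy stand-in for the mixer: the op mutates `output` in place and returns
# None, mirroring the mutates_args=["output"] registration above.
@torch.library.custom_op("demo::mixer_forward", mutates_args=("output",))
def mixer_forward(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    output.copy_(hidden_states * 2.0)  # placeholder for the real SSM math

@mixer_forward.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Fake (meta) impl: nothing to compute, output shape/dtype is already fixed.
    return

def layer(hidden_states: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(hidden_states)  # caller preallocates the buffer
    torch.ops.demo.mixer_forward(hidden_states, output)
    return output

compiled = torch.compile(layer)
print(compiled(torch.randn(4, 8)))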

vllm/model_executor/models/bamba.py

Lines changed: 6 additions & 5 deletions
@@ -11,6 +11,7 @@

 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -122,11 +123,10 @@ def forward(
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)

-        hidden_states = self.mamba(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
+        output = torch.empty_like(hidden_states)
+        self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
         # Fully Connected
-        hidden_states, residual = self.pre_ff_layernorm(
-            hidden_states, residual)
+        hidden_states, residual = self.pre_ff_layernorm(output, residual)
         hidden_states = self.feed_forward(hidden_states)
         return hidden_states, residual

@@ -169,7 +169,7 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings

         if hasattr(config, "partial_rotary_factor"):
-            rotary_dim = self.head_dim * config.partial_rotary_factor
+            rotary_dim = int(self.head_dim * config.partial_rotary_factor)
         elif hasattr(config, "attn_rotary_emb"):
             rotary_dim = config.attn_rotary_emb  # for backward compatibility
         else:
@@ -258,6 +258,7 @@ def forward(
         }


+@support_torch_compile
 class BambaModel(nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
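Each hybrid model receives the same two changes: the decoder layer preallocates `output` and lets the custom op fill it in place, and the top-level model class is wrapped with @support_torch_compile so its forward can be compiled and captured piecewise. A rough sketch of what such a class decorator amounts to, assuming only stock torch.compile (vLLM's real decorator additionally threads its compilation config and dynamic-shape handling through):

import torch

def compile_forward(cls):
    # Illustrative decorator: replace the class's forward with a wrapper that
    # compiles the original forward lazily on first call and caches it.
    orig_forward = cls.forward

    def forward(self, *args, **kwargs):
        if not hasattr(self, "_compiled_forward"):
            self._compiled_forward = torch.compile(orig_forward)
        return self._compiled_forward(self, *args, **kwargs)

    cls.forward = forward
    return cls

@compile_forward
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.lin(x))

print(TinyModel()(torch.randn(2, 8)).shape)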

vllm/model_executor/models/falcon_h1.py

Lines changed: 6 additions & 2 deletions
@@ -10,6 +10,7 @@

 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -179,13 +180,15 @@ def forward(
         mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
-        hidden_states = self.mamba(
+        output = torch.empty_like(hidden_states)
+        self.mamba(
             hidden_states,
+            output,
             mamba_cache_params,
             mamba2_metadata=mamba2_metadata,
             mup_vector=self.mup_vector,
         )
-        return hidden_states, residual
+        return output, residual


 class FalconH1AttentionDecoderLayer(nn.Module):
@@ -398,6 +401,7 @@ def forward(
         return hidden_states


+@support_torch_compile
 class FalconH1Model(nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

vllm/model_executor/models/granitemoehybrid.py

Lines changed: 5 additions & 3 deletions
@@ -11,6 +11,7 @@

 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -104,9 +105,9 @@ def forward(
     ):
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-        hidden_states = self.mamba(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
-        hidden_states = residual + hidden_states * self.residual_multiplier
+        output = torch.empty_like(hidden_states)
+        self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
+        hidden_states = residual + output * self.residual_multiplier

         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
@@ -307,6 +308,7 @@ def forward(
         }


+@support_torch_compile
 class GraniteMoeHybridModel(nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

vllm/model_executor/models/mamba2.py

Lines changed: 5 additions & 3 deletions
@@ -10,6 +10,7 @@

 from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
@@ -79,11 +80,12 @@ def forward(
         else:
             hidden_states, residual = self.norm(hidden_states, residual)

-        hidden_states = self.mixer(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
-        return hidden_states, residual
+        output = torch.empty_like(hidden_states)
+        self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
+        return output, residual


+@support_torch_compile
 class Mamba2Model(nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

vllm/model_executor/models/nemotron_h.py

Lines changed: 5 additions & 3 deletions
@@ -25,6 +25,7 @@

 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -172,9 +173,9 @@ def forward(
         else:
             hidden_states, residual = self.norm(hidden_states, residual)

-        hidden_states = self.mixer(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
-        return hidden_states, residual
+        output = torch.empty_like(hidden_states)
+        self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
+        return output, residual


 class NemotronHAttention(nn.Module):
@@ -292,6 +293,7 @@ def forward(
         }


+@support_torch_compile
 class NemotronHModel(nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

vllm/model_executor/models/zamba2.py

Lines changed: 6 additions & 2 deletions
@@ -17,6 +17,7 @@

 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.forward_context import get_forward_context
@@ -548,14 +549,16 @@ def forward(
         hidden_states = self.input_layernorm(hidden_states)

         # Process through Mamba mixer
-        hidden_states = self.mamba(
+        output = torch.empty_like(hidden_states)
+        self.mamba(
             hidden_states,
+            output,
             mamba_cache_params=mamba_cache_params,
             mamba2_metadata=mamba2_metadata,
         )

         # residual connection after mamba
-        hidden_states = residual + hidden_states
+        hidden_states = residual + output

         return hidden_states

@@ -646,6 +649,7 @@ def forward(
         return layer_outputs


+@support_torch_compile
 class Zamba2Model(nn.Module):
     """Core Zamba2 model combining transformer and Mamba architectures.

vllm/v1/worker/gpu_model_runner.py

Lines changed: 0 additions & 3 deletions
@@ -2753,9 +2753,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
             if self.vllm_config.speculative_config is not None:
                 raise NotImplementedError(
                     "Mamba with speculative decoding is not supported yet.")
-            if not self.vllm_config.model_config.enforce_eager:
-                raise NotImplementedError(
-                    "Mamba with cuda graph is not supported yet.")
             if self.vllm_config.cache_config.enable_prefix_caching:
                 raise NotImplementedError(
                     "Prefix caching is not supported for Mamba yet.")
