[V1] [Hybrid] Enable piecewise CUDA Graph for mamba layers #21194
File: vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -13,7 +13,7 @@
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce)
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -33,6 +33,8 @@
     LoaderFunction, composed_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata
 
 # Added by the IBM Team, 2024
@@ -424,14 +426,36 @@ def __init__(
     def forward_native(
         self,
         hidden_states: torch.Tensor,
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        output: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
+        mup_vector: Optional[torch.Tensor] = None,
     ):
         pass
 
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
+        mup_vector: Optional[torch.Tensor] = None,
+    ):
+        if not envs.VLLM_USE_V1:
+            CustomOp.forward(self, hidden_states, output, mamba_cache_params,
+                             mamba2_metadata, mup_vector)
+        else:
+            torch.ops.vllm.mamba_mixer2(
+                hidden_states,
+                output,
+                self.prefix,
+                mup_vector,
+            )
+
     def forward_cuda(
         self,
         hidden_states: torch.Tensor,
+        output: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
         mamba2_metadata: Mamba2Metadata,
         mup_vector: Optional[torch.Tensor] = None,
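Under V1 the mixer's `forward` now hands the work to `torch.ops.vllm.mamba_mixer2`, a custom op registered at the bottom of this file, so `torch.compile` sees a single opaque node and piecewise CUDA graphs can be captured around it. Below is a minimal stand-alone analogue of such a mutating custom op, sketched with plain `torch.library` (hypothetical `mylib` namespace and demo body; the real code uses vLLM's `direct_register_custom_op` helper instead):

```python
import torch


@torch.library.custom_op("mylib::mamba_mixer2_demo", mutates_args=["output"])
def mamba_mixer2_demo(hidden_states: torch.Tensor, output: torch.Tensor,
                      layer_name: str) -> None:
    # Toy body; the real op looks the layer up by name and runs forward_cuda.
    output.copy_(hidden_states * 2.0)


@mamba_mixer2_demo.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor, layer_name: str) -> None:
    # Fake (meta) implementation: no data access, only shape/dtype semantics.
    return


x = torch.randn(4, 8)
out = torch.empty_like(x)
torch.ops.mylib.mamba_mixer2_demo(x, out, "model.layers.0.mamba")  # fills `out` in place
```

Because the op returns nothing and declares that it mutates `output`, the compiler does not try to trace through the mamba kernels; it only needs the fake implementation to know the call leaves shapes and dtypes unchanged.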
@@ -517,25 +541,26 @@ def forward_cuda(
         num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
         has_prefill = num_prefills > 0
         has_decode = num_decodes > 0
+        num_actual_tokens = num_prefill_tokens + num_decodes
 
-        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         # Separate prefill and decode by splitting varlen input
         # Split along token dimension
+        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         if envs.VLLM_USE_V1:
             hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
-                hidden_states_B_C,
+                hidden_states_B_C[:num_actual_tokens],
                 [num_decodes, num_prefill_tokens],
                 dim=0,
             )
             dt_d, dt_p = torch.split(
-                dt,
+                dt[:num_actual_tokens],
                 [num_decodes, num_prefill_tokens],
                 dim=0,
             )
             # Split along batch dimension
             state_indices_tensor_d, state_indices_tensor_p = torch.split(
-                state_indices_tensor,
+                state_indices_tensor[:num_actual_tokens],
                 [num_decodes, num_prefills],
                 dim=0,
             )
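The slice before the split matters under piecewise CUDA graphs: the batch handed to the layer may be padded past the real token count, so only the first `num_actual_tokens` rows are real and get partitioned into the decode-first/prefill-second layout used by V1. A tiny illustration with made-up sizes:

```python
import torch

num_decodes, num_prefill_tokens = 3, 5           # assumed example sizes
num_actual_tokens = num_decodes + num_prefill_tokens
hidden = torch.randn(12, 16)                     # 12 padded rows, only 8 are real

decode_part, prefill_part = torch.split(
    hidden[:num_actual_tokens],                  # drop the padded tail first
    [num_decodes, num_prefill_tokens],
    dim=0,
)
assert decode_part.shape[0] == num_decodes
assert prefill_part.shape[0] == num_prefill_tokens
```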
@@ -696,11 +721,10 @@ def forward_cuda(
         # GatedRMSNorm internally applying SiLU to the gate
         # SiLU is applied internally before normalization, unlike standard
         # norm usage
-        hidden_states = self.norm(hidden_states, gate)
+        hidden_states = self.norm(hidden_states, gate[:num_actual_tokens])
 
         # 5. Final linear projection
-        out, _ = self.out_proj(hidden_states)
-        return out
+        output[:num_actual_tokens], _ = self.out_proj(hidden_states)
 
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return get_mamba_state_shape(
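Because the op now reports its result by mutating `output` (see `mutates_args=["output"]` below) instead of returning a tensor, the final projection writes into the caller's buffer, and only into its first `num_actual_tokens` rows, so any padding rows are left untouched. A small sketch of that out-parameter write, with a stand-in `out_proj` that mimics a vLLM linear layer returning `(result, bias)`:

```python
import torch

def out_proj(x: torch.Tensor):
    # stand-in for a vLLM linear layer, which returns (result, bias)
    return x @ torch.ones(x.shape[-1], 4), None

hidden_states = torch.randn(7, 8)       # 7 real tokens
output = torch.zeros(10, 4)             # preallocated, padded to 10 rows
num_actual_tokens = hidden_states.shape[0]

# Same pattern as the diff: assign into a slice of the shared buffer; the tuple
# unpacking drops the bias, and rows past num_actual_tokens stay untouched.
output[:num_actual_tokens], _ = out_proj(hidden_states)
assert output[num_actual_tokens:].abs().sum() == 0
```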
@@ -712,3 +736,36 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
             state_size=self.ssm_state_size,
             conv_kernel=self.conv_kernel_size,
         )
+
+
+def mamba_mixer2(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+    mup_vector: Optional[torch.Tensor] = None,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    self.forward_cuda(hidden_states=hidden_states,
+                      output=output,
+                      mamba_cache_params=None,
+                      mamba2_metadata=None,
+                      mup_vector=mup_vector)
+
+
+def mamba_mixer2_fake(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+    mup_vector: Optional[torch.Tensor] = None,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="mamba_mixer2",
+    op_func=mamba_mixer2,
+    mutates_args=["output"],
+    fake_impl=mamba_mixer2_fake,
+    dispatch_key=current_platform.dispatch_key,
+)

Review comment on lines +766 to +767: We should use a common op name here (…)
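The module-level `mamba_mixer2` function recovers the layer object by name from the forward context and forwards to its `forward_cuda`, while `mamba_mixer2_fake` gives the compiler a shape-only stand-in and `mutates_args=["output"]` declares that results arrive through the preallocated buffer rather than a return value. A small sketch of that lookup-by-name pattern, with a plain dict registry standing in for `forward_context.no_compile_layers` and hypothetical demo names:

```python
from typing import Dict

import torch

# stand-in for forward_context.no_compile_layers: layer prefix -> layer object
_LAYER_REGISTRY: Dict[str, "DemoLayer"] = {}


class DemoLayer:
    def __init__(self, prefix: str):
        self.prefix = prefix
        _LAYER_REGISTRY[prefix] = self          # register under its prefix

    def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor) -> None:
        output.copy_(hidden_states)             # real code runs the mamba kernels here


def demo_mixer_op(hidden_states: torch.Tensor, output: torch.Tensor,
                  layer_name: str) -> None:
    # Mirrors the registered op body: resolve the layer by name, then run it.
    layer = _LAYER_REGISTRY[layer_name]
    layer.forward_cuda(hidden_states, output)


def demo_mixer_fake(hidden_states: torch.Tensor, output: torch.Tensor,
                    layer_name: str) -> None:
    # Shape-only stand-in used during tracing; it must not touch real data.
    return


layer = DemoLayer("model.layers.0.mamba")
x = torch.randn(2, 4)
out = torch.empty_like(x)
demo_mixer_op(x, out, "model.layers.0.mamba")
```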
File: vllm/model_executor/models/bamba.py
@@ -11,6 +11,7 @@
 
 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -122,11 +123,10 @@ def forward(
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
 
-        hidden_states = self.mamba(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
+        output = torch.empty_like(hidden_states)
+        self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
 
         # Fully Connected
-        hidden_states, residual = self.pre_ff_layernorm(
-            hidden_states, residual)
+        hidden_states, residual = self.pre_ff_layernorm(output, residual)
         hidden_states = self.feed_forward(hidden_states)
         return hidden_states, residual

Review comment on lines +126 to +127: Is it possible to change the code back to (…)
@@ -169,7 +169,7 @@ def __init__(
         self.max_position_embeddings = max_position_embeddings
 
         if hasattr(config, "partial_rotary_factor"):
-            rotary_dim = self.head_dim * config.partial_rotary_factor
+            rotary_dim = int(self.head_dim * config.partial_rotary_factor)
         elif hasattr(config, "attn_rotary_emb"):
             rotary_dim = config.attn_rotary_emb  # for backward compatibility
         else:
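The `int(...)` cast is needed because `partial_rotary_factor` is a float, so the product is a float even when its value is whole, and the downstream rotary-embedding code expects an integer dimension. For example:

```python
head_dim, partial_rotary_factor = 128, 0.5
rotary_dim = head_dim * partial_rotary_factor        # 64.0, a float
rotary_dim = int(head_dim * partial_rotary_factor)   # 64, usable as a dimension
```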
@@ -258,6 +258,7 @@ def forward(
         }
 
 
+@support_torch_compile
 class BambaModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):