Enable the super kernel feature under the multistream MoE feature #1641

Open · wants to merge 1 commit into base: v0.9.1-dev

4 changes: 3 additions & 1 deletion vllm_ascend/ascend_config.py
@@ -63,7 +63,9 @@ def __init__(self, torchair_graph_config):
self.enable_view_optimize = torchair_graph_config.get(
"enable_view_optimize", True)
self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False)

self.enable_super_kernel = torchair_graph_config.get(
"enable_super_kernel", False)

if not isinstance(self.graph_batch_sizes, list):
raise TypeError("graph_batch_sizes must be list[int]")
if self.graph_batch_sizes_init and len(self.graph_batch_sizes) > 0:
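
For orientation, a hedged sketch of how the new flag could be switched on at runtime, assuming vllm-ascend's existing additional_config / torchair_graph_config plumbing; the model name and the sibling keys are illustrative context, and only enable_super_kernel comes from this diff.

# Illustrative only: enabling the new option through additional_config, which
# ascend_config.py parses into torchair_graph_config. The model name and the
# other keys are assumptions shown for context, not part of this PR.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",    # example model
    additional_config={
        "torchair_graph_config": {
            "enabled": True,                  # torchair graph mode
            "enable_multistream_moe": True,   # super kernel sits under multistream MoE
            "enable_super_kernel": True,      # new flag from this PR (defaults to False)
        },
    },
)

With the flag left at its default of False, the new code paths are skipped and behavior is unchanged.
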
19 changes: 16 additions & 3 deletions vllm_ascend/models/deepseek_v2.py
@@ -29,6 +29,7 @@

import torch
import torch_npu
import torchair as tng
import vllm.envs as envs
from torch import nn
from transformers import PretrainedConfig
@@ -520,6 +521,9 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
) -> None:
nn.Module.__init__(self)
ascend_config = get_ascend_config()
self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel

self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
@@ -529,6 +533,7 @@ def __init__(
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.layer_idx = layer_idx
self.prefix = prefix
# TODO: enable mla in vllm-ascend
if model_config.use_mla:
attn_cls = CustomDeepseekV2MLAAttention
@@ -560,6 +565,7 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.is_moe = True
else:
self.mlp = CustomDeepseekV2MLP(
hidden_size=config.hidden_size,
@@ -568,6 +574,8 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.is_moe = False

self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -611,10 +619,15 @@ def forward(
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor

is_prefill = get_forward_context().with_prefill
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
if self.is_moe and not is_prefill and self.enable_super_kernel:
with tng.scope.super_kernel(self.prefix, 'stream-fusion=1'):
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
else:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)

if isinstance(self.mlp, CustomDeepseekV2MoE):
hidden_states = self.mlp(hidden_states, attn_metadata)
31 changes: 23 additions & 8 deletions vllm_ascend/ops/fused_moe.py
@@ -22,6 +22,7 @@
import torch
import torch.distributed as dist
import torch_npu
import torchair as tng
from torch import nn
from vllm.config import get_current_vllm_config
from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
@@ -1021,6 +1022,7 @@ def __init__(

AscendFusedMoE.moe_counter += 1
self.moe_instance_id = AscendFusedMoE.moe_counter
self.prefix = prefix

if params_dtype is None:
params_dtype = torch.get_default_dtype()
@@ -1077,7 +1079,8 @@ def __init__(
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_moe = \
ascend_config.torchair_graph_config.enable_multistream_moe

self.enable_super_kernel = \
ascend_config.torchair_graph_config.enable_super_kernel
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
@@ -1137,19 +1140,30 @@ def forward(self,
num_tokens, hidden_size = hidden_states.shape

fused_moe_state = get_forward_context().fused_moe_state
is_prefill = get_forward_context().with_prefill
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
quantized_x_for_share, dynamic_scale_for_share = None, None
from vllm_ascend.quantization.w8a8_dynamic import \
AscendW8A8DynamicFusedMoEMethod
if self.enable_multistream_moe:
assert gate is not None
router_logits, _ = gate(hidden_states)
if isinstance(self.quant_method.quant_method,
AscendW8A8DynamicFusedMoEMethod
) and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
hidden_states)
if not is_prefill and self.enable_super_kernel:
with tng.scope.super_kernel(self.prefix, 'stream-fusion=1'):
Contributor comment on this line: Please create a new context manager for the super kernel, with a parameter that controls whether it is enabled, so that this code is not duplicated; refer to npu_stream_switch. (A sketch of this idea follows the diff below.)
router_logits, _ = gate(hidden_states)
if isinstance(self.quant_method.quant_method,
AscendW8A8DynamicFusedMoEMethod
) and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
hidden_states)
else:
router_logits, _ = gate(hidden_states)
if isinstance(self.quant_method.quant_method,
AscendW8A8DynamicFusedMoEMethod
) and fused_moe_state == FusedMoEState.MC2:
with npu_stream_switch("moe_secondary", 0):
quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
hidden_states)

if shared_experts:
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
Expand Down Expand Up @@ -1209,6 +1223,7 @@ def forward(self,
and self.enable_multistream_moe and not is_prefill else None,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
prefix=self.prefix,
)

if shared_experts:
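
On the contributor comment above about avoiding the duplicated branches: a minimal sketch of such a conditional context manager, modeled loosely on the npu_stream_switch pattern the comment refers to. The helper name, its signature, and the use of contextlib are assumptions for illustration, not code from this PR.

# Sketch only: a conditional wrapper around torchair's super_kernel scope, so that
# call sites pass an `enabled` flag instead of duplicating the branches around it.
from contextlib import contextmanager

import torchair as tng


@contextmanager
def super_kernel(prefix: str, option: str, enabled: bool = True):
    # Enter the super-kernel scope only when enabled; otherwise act as a no-op.
    if enabled:
        with tng.scope.super_kernel(prefix, option):
            yield
    else:
        yield


# The forward() above could then keep a single code path, e.g.:
#   with super_kernel(self.prefix, 'stream-fusion=1',
#                     enabled=self.enable_super_kernel and not is_prefill):
#       router_logits, _ = gate(hidden_states)
#       ...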