From 61cee8dedc5d5107619597a5ba95613446b4bfd3 Mon Sep 17 00:00:00 2001 From: NNUCJ <616151263@qq.com> Date: Wed, 2 Jul 2025 16:05:38 +0800 Subject: [PATCH] add super kernel Signed-off-by: NNUCJ <616151263@qq.com> --- vllm_ascend/ascend_config.py | 4 +- vllm_ascend/models/deepseek_v2.py | 19 +- vllm_ascend/ops/fused_moe.py | 31 ++- vllm_ascend/quantization/w8a8_dynamic.py | 291 +++++++++++++++-------- 4 files changed, 239 insertions(+), 106 deletions(-) diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index c3043e7a73..7925d0e713 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -63,7 +63,9 @@ def __init__(self, torchair_graph_config): self.enable_view_optimize = torchair_graph_config.get( "enable_view_optimize", True) self.enable_kv_nz = torchair_graph_config.get("enable_kv_nz", False) - + self.enable_super_kernel = torchair_graph_config.get( + "enable_super_kernel", False) + if not isinstance(self.graph_batch_sizes, list): raise TypeError("graph_batch_sizes must be list[int]") if self.graph_batch_sizes_init and len(self.graph_batch_sizes) > 0: diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index cc7b914492..359495a871 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -29,6 +29,7 @@ import torch import torch_npu +import torchair as tng import vllm.envs as envs from torch import nn from transformers import PretrainedConfig @@ -520,6 +521,9 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, ) -> None: nn.Module.__init__(self) + ascend_config = get_ascend_config() + self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -529,6 +533,7 @@ def __init__( # with the layer's index. layer_idx = int(prefix.split(sep='.')[-1]) self.layer_idx = layer_idx + self.prefix = prefix # TODO: enable mla in vllm-ascend if model_config.use_mla: attn_cls = CustomDeepseekV2MLAAttention @@ -560,6 +565,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp", ) + self.is_moe = True else: self.mlp = CustomDeepseekV2MLP( hidden_size=config.hidden_size, @@ -568,6 +574,8 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp", ) + self.is_moe = False + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -611,10 +619,15 @@ def forward( # The residual is shared by all layers, we only scale it on # first layer. residual *= 1. 
/ self.routed_scaling_factor - + is_prefill = get_forward_context().with_prefill # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + if self.is_moe and not is_prefill and self.enable_super_kernel: + with tng.scope.super_kernel(self.prefix, 'stream-fusion=1'): + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + else: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) if isinstance(self.mlp, CustomDeepseekV2MoE): hidden_states = self.mlp(hidden_states, attn_metadata) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index fe1164fd4d..9f16a1357b 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -22,6 +22,7 @@ import torch import torch.distributed as dist import torch_npu +import torchair as tng from torch import nn from vllm.config import get_current_vllm_config from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank, @@ -1021,6 +1022,7 @@ def __init__( AscendFusedMoE.moe_counter += 1 self.moe_instance_id = AscendFusedMoE.moe_counter + self.prefix = prefix if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -1077,7 +1079,8 @@ def __init__( self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_multistream_moe = \ ascend_config.torchair_graph_config.enable_multistream_moe - + self.enable_super_kernel = \ + ascend_config.torchair_graph_config.enable_super_kernel if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") @@ -1137,19 +1140,30 @@ def forward(self, num_tokens, hidden_size = hidden_states.shape fused_moe_state = get_forward_context().fused_moe_state + is_prefill = get_forward_context().with_prefill # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel. 
         quantized_x_for_share, dynamic_scale_for_share = None, None
         from vllm_ascend.quantization.w8a8_dynamic import \
             AscendW8A8DynamicFusedMoEMethod
         if self.enable_multistream_moe:
             assert gate is not None
-            router_logits, _ = gate(hidden_states)
-            if isinstance(self.quant_method.quant_method,
-                          AscendW8A8DynamicFusedMoEMethod
-                          ) and fused_moe_state == FusedMoEState.MC2:
-                with npu_stream_switch("moe_secondary", 0):
-                    quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
-                        hidden_states)
+            if not is_prefill and self.enable_super_kernel:
+                with tng.scope.super_kernel(self.prefix, 'stream-fusion=1'):
+                    router_logits, _ = gate(hidden_states)
+                    if isinstance(self.quant_method.quant_method,
+                                  AscendW8A8DynamicFusedMoEMethod
+                                  ) and fused_moe_state == FusedMoEState.MC2:
+                        with npu_stream_switch("moe_secondary", 0):
+                            quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
+                                hidden_states)
+            else:
+                router_logits, _ = gate(hidden_states)
+                if isinstance(self.quant_method.quant_method,
+                              AscendW8A8DynamicFusedMoEMethod
+                              ) and fused_moe_state == FusedMoEState.MC2:
+                    with npu_stream_switch("moe_secondary", 0):
+                        quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant(
+                            hidden_states)
 
         if shared_experts:
             if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
@@ -1209,6 +1223,7 @@ def forward(self,
                 and self.enable_multistream_moe and not is_prefill else None,
                 quantized_x_for_share=quantized_x_for_share,
                 dynamic_scale_for_share=dynamic_scale_for_share,
+                prefix=self.prefix,
             )
 
         if shared_experts:
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index a9938c14f2..78417f4578 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -21,6 +21,7 @@
 import torch
 import torch.distributed as dist
 import torch_npu
+import torchair as tng
 from vllm.distributed import GroupCoordinator, get_ep_group, get_tp_group
 from vllm.forward_context import get_forward_context
 
@@ -215,6 +216,8 @@ def fused_experts_with_mc2(
     w2_scale_bias: torch.Tensor = None,
     quantized_x_for_share: Optional[Any] = None,
     dynamic_scale_for_share: Optional[Any] = None,
+    prefix: str = "",
+    use_super_kernel: bool = False,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     if log2phy:
         topk_ids = log2phy[topk_ids]
@@ -226,103 +229,199 @@ def fused_experts_with_mc2(
     # NOTE: `global_bs` should be equal to `max_num_tokens_across_dp` * `ep_world_size`,
     # and `max_num_tokens_across_dp` has been split into `tp_world_size` parts before.
- global_bs = math.ceil(get_forward_context().max_tokens_across_dp / + if shared_experts is not None and use_super_kernel: + with tng.scope.super_kernel(prefix, 'stream-fusion=1'): + global_bs = math.ceil(get_forward_context().max_tokens_across_dp / + tp_world_size) * ep_world_size + + # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine + need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 + or is_torchair) + + if (expert_map is not None): + moe_expert_num = len(expert_map) + global_redundant_expert_num + else: + moe_expert_num = global_redundant_expert_num + # hidden_states = hidden_states.bfloat16() + kwargs_mc2 = { + "x": hidden_states, + "expert_ids": topk_ids, + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": moe_expert_num, + "global_bs": global_bs, + } + + stage1_kwargs = { + "scales": None, + "quant_mode": quant_mode, + "group_ep": moe_all_to_all_group_name, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, + } + if need_extra_args: + stage1_kwargs.update({ + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + kwargs_mc2.update(stage1_kwargs) + + output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) + # comm_stream.wait_stream(torch.npu.current_stream()) + expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[ + 0:5] + + + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(quantized_x_for_share, expand_x) + shared_act_out = shared_experts.act_fn( + (quantized_x_for_share, dynamic_scale_for_share)) + shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1] + + # `expand_x` will be disposed in the `apply_mlp` function + down_out_list = apply_mlp_decode([expand_x], + w1, + w1_scale, + w2, + w2_scale, + expert_token_nums, + dynamic_scale=dynamic_scale) + + # moeCombine + kwargs_mc2 = { + "expand_x": down_out_list, + "expert_ids": topk_ids, + "expand_idx": expand_idx, + "expert_scales": topk_weights.to(torch.float32), + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": moe_expert_num, + "global_bs": 0, + } + tp_recv_counts = torch.empty(1, + dtype=torch.int32, + device=hidden_states.device) + stage3_kwargs = { + "ep_send_counts": ep_recv_counts, + "group_ep": moe_all_to_all_group_name, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, + } + if need_extra_args: + stage3_kwargs.update({ + "tp_send_counts": tp_recv_counts, + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + kwargs_mc2.update(stage3_kwargs) + + hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) + + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(shared_act, down_out_list) + shared_output, _ = shared_experts.down_proj( + (shared_act, swiglu_out_scale)) + return hidden_states, shared_output + else: + global_bs = math.ceil(get_forward_context().max_tokens_across_dp / tp_world_size) * ep_world_size - # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine - need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 - or is_torchair) + # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine + need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 + or is_torchair) - if (expert_map is not None): - moe_expert_num = len(expert_map) + global_redundant_expert_num - else: - moe_expert_num = 
global_redundant_expert_num - # hidden_states = hidden_states.bfloat16() - kwargs_mc2 = { - "x": hidden_states, - "expert_ids": topk_ids, - "expert_shard_type": 0, - "shared_expert_rank_num": 0, - "moe_expert_num": moe_expert_num, - "global_bs": global_bs, - } - - stage1_kwargs = { - "scales": None, - "quant_mode": quant_mode, - "group_ep": moe_all_to_all_group_name, - "ep_world_size": ep_world_size, - "ep_rank_id": ep_rank_id, - } - if need_extra_args: - stage1_kwargs.update({ - "group_tp": moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) - kwargs_mc2.update(stage1_kwargs) - - output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) - # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[ - 0:5] - - if shared_experts is not None: - with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(quantized_x_for_share, expand_x) - shared_act_out = shared_experts.act_fn( - (quantized_x_for_share, dynamic_scale_for_share)) - shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1] - - # `expand_x` will be disposed in the `apply_mlp` function - down_out_list = apply_mlp_decode([expand_x], - w1, - w1_scale, - w2, - w2_scale, - expert_token_nums, - dynamic_scale=dynamic_scale) - - # moeCombine - kwargs_mc2 = { - "expand_x": down_out_list, - "expert_ids": topk_ids, - "expand_idx": expand_idx, - "expert_scales": topk_weights.to(torch.float32), - "expert_shard_type": 0, - "shared_expert_rank_num": 0, - "moe_expert_num": moe_expert_num, - "global_bs": global_bs, - } - tp_recv_counts = torch.empty(1, - dtype=torch.int32, - device=hidden_states.device) - stage3_kwargs = { - "ep_send_counts": ep_recv_counts, - "group_ep": moe_all_to_all_group_name, - "ep_world_size": ep_world_size, - "ep_rank_id": ep_rank_id, - } - if need_extra_args: - stage3_kwargs.update({ - "tp_send_counts": tp_recv_counts, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) - kwargs_mc2.update(stage3_kwargs) - - hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) - - if shared_experts is None: - return hidden_states - else: - with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(shared_act, down_out_list) - shared_output, _ = shared_experts.down_proj( - (shared_act, swiglu_out_scale)) - return hidden_states, shared_output + if (expert_map is not None): + moe_expert_num = len(expert_map) + global_redundant_expert_num + else: + moe_expert_num = global_redundant_expert_num + # hidden_states = hidden_states.bfloat16() + kwargs_mc2 = { + "x": hidden_states, + "expert_ids": topk_ids, + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": moe_expert_num, + "global_bs": global_bs, + } + stage1_kwargs = { + "scales": None, + "quant_mode": quant_mode, + "group_ep": moe_all_to_all_group_name, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, + } + if need_extra_args: + stage1_kwargs.update({ + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + kwargs_mc2.update(stage1_kwargs) + + output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) + # comm_stream.wait_stream(torch.npu.current_stream()) + expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[ + 0:5] + + if shared_experts is not None: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(quantized_x_for_share, expand_x) + shared_act_out = shared_experts.act_fn( + (quantized_x_for_share, 
dynamic_scale_for_share)) + shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1] + + # `expand_x` will be disposed in the `apply_mlp` function + down_out_list = apply_mlp_decode([expand_x], + w1, + w1_scale, + w2, + w2_scale, + expert_token_nums, + dynamic_scale=dynamic_scale) + + # moeCombine + kwargs_mc2 = { + "expand_x": down_out_list, + "expert_ids": topk_ids, + "expand_idx": expand_idx, + "expert_scales": topk_weights.to(torch.float32), + "expert_shard_type": 0, + "shared_expert_rank_num": 0, + "moe_expert_num": moe_expert_num, + "global_bs": global_bs, + } + tp_recv_counts = torch.empty(1, + dtype=torch.int32, + device=hidden_states.device) + stage3_kwargs = { + "ep_send_counts": ep_recv_counts, + "group_ep": moe_all_to_all_group_name, + "ep_world_size": ep_world_size, + "ep_rank_id": ep_rank_id, + } + if need_extra_args: + stage3_kwargs.update({ + "tp_send_counts": tp_recv_counts, + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) + kwargs_mc2.update(stage3_kwargs) + + hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) + + if shared_experts is None: + return hidden_states + else: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(shared_act, down_out_list) + shared_output, _ = shared_experts.down_proj( + (shared_act, swiglu_out_scale)) + return hidden_states, shared_output + # currently expert parallelism implemented with all2all # is under-optimized. @@ -663,6 +762,7 @@ def __init__(self): ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_weight_nz_layout = ascend_config.enable_weight_nz_layout + self.enable_super_kernel = ascend_config.torchair_graph_config.enable_super_kernel try: device_group = self.ep_group.device_group @@ -738,6 +838,7 @@ def apply( shared_experts: Optional[Any] = None, quantized_x_for_share: Optional[Any] = None, dynamic_scale_for_share: Optional[Any] = None, + prefix:str ="", **kwargs, ) -> torch.Tensor: assert router_logits.shape[ @@ -807,7 +908,9 @@ def apply( shared_experts=shared_experts, is_torchair=self.torchair_graph_enabled, quantized_x_for_share=shared_gate_up, - dynamic_scale_for_share=shared_dequant_scale) + dynamic_scale_for_share=shared_dequant_scale, + prefix=prefix, + use_super_kernel=self.enable_super_kernel) elif fused_moe_state == FusedMoEState.AllGather: return fused_experts(hidden_states=x, w1=layer.w13_weight,
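
Note on the pattern above (not part of the diff): the decode-only gating added in deepseek_v2.py, fused_moe.py and w8a8_dynamic.py always has the same shape: when the request is not a prefill and enable_super_kernel is set, run the block inside tng.scope.super_kernel(prefix, 'stream-fusion=1'); otherwise run the identical block as-is. Below is a minimal sketch of that gating, assuming only the names this patch introduces (prefix, enable_super_kernel, is_prefill) and the torchair scope call used in the diff; the maybe_super_kernel helper itself is hypothetical and not part of the change.

# Illustrative sketch only: decode-path super-kernel gating factored into a
# helper so the guarded body need not be duplicated. `tng.scope.super_kernel`
# and the 'stream-fusion=1' option string are used exactly as in the diff;
# everything else here is an assumption.
import contextlib

import torchair as tng


def maybe_super_kernel(prefix: str, enable_super_kernel: bool,
                       is_prefill: bool) -> contextlib.AbstractContextManager:
    """Return the super-kernel scope for decode when enabled, else a no-op context."""
    if enable_super_kernel and not is_prefill:
        return tng.scope.super_kernel(prefix, 'stream-fusion=1')
    return contextlib.nullcontext()


# Usage mirroring the decoder-layer change:
#     with maybe_super_kernel(self.prefix, self.enable_super_kernel, is_prefill):
#         hidden_states, residual = self.post_attention_layernorm(
#             hidden_states, residual)

The flag itself is read in ascend_config.py via torchair_graph_config.get("enable_super_kernel", False), so it sits alongside the existing torchair graph options such as enabled, enable_multistream_moe and enable_kv_nz, and it defaults to off.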