From 7ff288ec66c481d648620342004dcbe9df945c01 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Mon, 30 Jun 2025 20:02:20 +0800 Subject: [PATCH 01/60] [Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo Signed-off-by: weijinqian_v1 --- vllm_ascend/attention/attention_v1.py | 13 + vllm_ascend/distributed/tensor_parallel.py | 246 +++++++ vllm_ascend/envs.py | 13 +- vllm_ascend/models/__init__.py | 4 + vllm_ascend/models/deepseek_dbo.py | 284 ++++++- vllm_ascend/models/qwen3_dbo.py | 511 +++++++++++++ vllm_ascend/multistream/ms_split.py | 107 ++- vllm_ascend/ops/comm_utils.py | 127 ++++ vllm_ascend/ops/fused_moe.py | 347 +++++---- vllm_ascend/ops/moe_dispatcher/__init__.py | 0 vllm_ascend/ops/moe_dispatcher/moe_utils.py | 379 ++++++++++ .../ops/moe_dispatcher/token_dispatcher.py | 696 ++++++++++++++++++ 12 files changed, 2572 insertions(+), 155 deletions(-) create mode 100644 vllm_ascend/distributed/tensor_parallel.py create mode 100644 vllm_ascend/models/qwen3_dbo.py create mode 100644 vllm_ascend/ops/comm_utils.py create mode 100644 vllm_ascend/ops/moe_dispatcher/__init__.py create mode 100644 vllm_ascend/ops/moe_dispatcher/moe_utils.py create mode 100644 vllm_ascend/ops/moe_dispatcher/token_dispatcher.py diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index adb6de2af4..795020c69e 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -29,6 +29,7 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch +from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.ops.attention import vanilla_chunked_prefill @@ -135,6 +136,18 @@ class AscendMetadata: enable_dbo_across_dp: bool = False + def split_metadata_for_multistream( + self, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> list["AscendMetadata"]: + """Split metadata for multi-stream with AscendMetadata""" + from vllm_ascend.multistream.ms_split import model_input_split_v1_attn + return model_input_split_v1_attn( + ms_split_config=ms_split_config, + attn_metadata=self, + _metadata_cls=AscendMetadata, + ) + class AscendAttentionMetadataBuilder: diff --git a/vllm_ascend/distributed/tensor_parallel.py b/vllm_ascend/distributed/tensor_parallel.py new file mode 100644 index 0000000000..70aa820094 --- /dev/null +++ b/vllm_ascend/distributed/tensor_parallel.py @@ -0,0 +1,246 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +import torch + + +def _gather_along_first_dim(input_, group, output_split_sizes=None): + """Gather tensors and concatenate along the first dimension. + + Args: + input_tensor (torch.Tensor): + A tensor to be gathered. + output_split_sizes (List[int], optional): + A list specifying the sizes of the output splits along the first dimension. 
+ If None, equal splitting is assumed. Default: None. + + Returns: + torch.Tensor: Gathered tensor. + """ + world_size = torch.distributed.get_world_size(group) + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + if output_split_sizes is None: + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + torch.distributed.all_gather_into_tensor(output, + input_.contiguous(), + group=group) + else: + dim_size[0] = sum(output_split_sizes) + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + output_tensor_list = list( + torch.split(output, output_split_sizes, dim=0)) + torch.distributed.all_gather(output_tensor_list, input_, group=group) + + return output + + +def _gather_along_last_dim(input_, group): + """Gather tensors and concatenate along the last dimension.""" + + world_size = torch.distributed.get_world_size(group) + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + torch.distributed.all_gather_into_tensor(output, + input_.contiguous(), + group=group) + tensor_list = output.chunk(world_size, dim=0) + output = torch.cat(tensor_list, dim=-1).contiguous() + + return output + + +def _reduce_scatter_along_first_dim(input_, + group, + input_split_sizes=None, + use_global_buffer=False): + """Reduce-scatter the input tensor across model parallel group. + + Args: + input_ (torch.Tensor): The input tensor to be reduce-scattered. + input_split_sizes (List[int], optional): A list specifying the sizes of + the input splits along the first dimension for each rank. If None, + equal splitting is assumed. Default: None. + """ + world_size = torch.distributed.get_world_size(group) + # Bypass the function if we are using only 1 GPU. 
+ if world_size == 1: + return input_ + + if input_split_sizes is None: + dim_size = list(input_.size()) + assert ( + dim_size[0] % world_size == 0 + ), "First dimension of the tensor should be divisible by tensor parallel size" + + dim_size[0] = dim_size[0] // world_size + + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + torch.distributed.reduce_scatter_tensor(output, + input_.contiguous(), + group=group) + else: + rank = torch.distributed.get_rank(group) + input_tensor_list = list(torch.split(input_, input_split_sizes, dim=0)) + + output = torch.empty_like(input_tensor_list[rank]) + torch.distributed.reduce_scatter(output, + input_tensor_list, + group=group) + return output + + +def _reduce_scatter_along_last_dim(input_, group): + """Reduce-scatter tensors on the last dimension.""" + world_size = torch.distributed.get_world_size(group) + target_shape = list(input_.size()) + target_shape[-1] = target_shape[-1] // world_size + input_ = input_.reshape(-1, input_.shape[-1]) + split_tensors = torch.split(input_, + split_size_or_sections=input_.shape[-1] // + world_size, + dim=1) + concat_tensor = torch.cat(split_tensors, dim=0) + output = _reduce_scatter_along_first_dim(concat_tensor, + group).reshape(target_shape) + return output + + +def all_gather_last_dim_from_tensor_parallel_region(input_, group): + """Wrapper for autograd function: forward: AG, backward RS """ + return _gather_along_last_dim(input_, group) + + +def reduce_scatter_to_sequence_parallel_region(input_, + group, + input_split_sizes=None): + """Wrapper for autograd function: forward: RS, backward AG """ + return _reduce_scatter_along_first_dim(input_, group, input_split_sizes) + + +def reduce_scatter_last_dim_to_tensor_parallel_region(input_, group): + """Wrapper for autograd function: forward: RS, backward AG: AG """ + return _reduce_scatter_along_last_dim(input_, group) + + +def gather_from_sequence_parallel_region( + input_, + group, + output_split_sizes=None, +): + """Wrapper for autograd function: forward: AG, backward: RS """ + return _gather_along_first_dim(input_, group, output_split_sizes) + + +def all_to_all(group, input, output_split_sizes=None, input_split_sizes=None): + world_size = torch.distributed.get_world_size(group=group) + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input + + input = input.contiguous() + if output_split_sizes is None: + # Equal split (all2all) + output = torch.empty_like(input) + else: + # Unequal split (all2all-v) + output = input.new_empty( + size=[sum(output_split_sizes)] + list(input.size()[1:]), + dtype=input.dtype, + device=torch.npu.current_device(), + ) + torch.distributed.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + ) + return output + + +def all_to_all_sp2hp(input_, group): + """ + Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape + [num_tokens/TP, H] to [num_tokens, H/TP]. + + Args: + input_ (torch.Tensor): + The input tensor which has been distributed along the sequence + dimension. + + Returns: + torch.Tensor: The output tensor with shape [num_tokens, H/TP]. 
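+
+    Example (illustrative shapes only, assuming a TP world size of 2):
+        a per-rank input of shape [num_tokens/TP, H] = [4, 8] becomes a
+        per-rank output of shape [num_tokens, H/TP] = [8, 4].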
+ + """ + if group is None: + return input_ + world_size = torch.distributed.get_world_size(group=group) + tp_group = group + input_ = input_.reshape(-1, input_.shape[-1]) + split_tensors = torch.split(input_, + split_size_or_sections=input_.shape[-1] // + world_size, + dim=1) + concat_tensor = torch.cat(split_tensors, dim=0) + output = all_to_all(tp_group, concat_tensor) + return output + + +def all_to_all_hp2sp(input_, group): + """ + Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape + [num_tokens, H/TP] to [num_tokens/TP, H]. + + Args: + input_ (torch.Tensor): + The input tensor which has been distributed along the hidden + dimension. + + Returns: + torch.Tensor: The output tensor with shape [num_tokens/TP, H]. + """ + if group is None: + return input_ + world_size = torch.distributed.get_world_size(group=group) + input_ = input_.reshape(-1, input_.shape[-1]) + tp_group = group + input_exchanged = all_to_all(tp_group, input_) + input_reshaped = input_exchanged.reshape(-1, input_exchanged.shape[-1]) + split_tensors = torch.split( + input_reshaped, + split_size_or_sections=input_reshaped.shape[0] // world_size, + dim=0) + output = torch.cat(split_tensors, dim=-1) + return output diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 27d0131720..eefc78d647 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -106,11 +106,11 @@ "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0')) ), - # MOE_ALL2ALL_BUFFER: + # VLLM_ASCEND_MOE_ALL2ALL_BUFFER: # 0: default, normal init. # 1: enable moe_all2all_buffer. - "MOE_ALL2ALL_BUFFER": - lambda: bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))), + "VLLM_ASCEND_MOE_ALL2ALL_BUFFER": + lambda: bool(int(os.getenv("VLLM_ASCEND_MOE_ALL2ALL_BUFFER", '0'))), # Some models are optimized by vllm ascend. While in some case, e.g. rlhf # training, the optimized model may not be suitable. In this case, set this # value to False to disable the optimized model. @@ -136,7 +136,12 @@ # Whether to enable mla_pa for deepseek mla decode, this flag will be removed after its available torch_npu is public accessible # and the mla_pa will be the default path of deepseek decode path. "VLLM_ASCEND_MLA_PA": - lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0)) + lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0)), + # VLLM_ASCEND_ENABLE_MOE_ALL2ALLV: + # 0: default, normal init. + # 1: enable moe all2allv. 
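+    # e.g. `export VLLM_ASCEND_ENABLE_MOE_ALL2ALLV=1` turns it on.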
+ "VLLM_ASCEND_ENABLE_MOE_ALL2ALLV": + lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALLV', '0'))), } # end-env-vars-definition diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index d85572b32c..c0e8c5be54 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -40,6 +40,10 @@ def register_model(): "DeepseekV3ForCausalLM", "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM") + ModelRegistry.register_model( + "Qwen3MoeForCausalLM", + "vllm_ascend.models.qwen3_dbo:CustomQwen3MoeForCausalLMDBO") + else: ModelRegistry.register_model( "DeepseekV2ForCausalLM", diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index c0ef61c125..45ccd84c50 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -39,7 +39,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import get_dp_group +from vllm.distributed.parallel_state import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -66,6 +66,7 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.context import ( @@ -76,11 +77,12 @@ from vllm_ascend.multistream.metadata import (MultiStreamConfig, MultiStreamStepMetadata, make_multistream_metadata_ds) -from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.ops.fused_moe import AscendFusedMoE, apply_mlp, select_experts from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.utils import dispose_tensor VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO +VLLM_ASCEND_ENABLE_MOE_ALL2ALLV: bool = envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALLV class CustomDeepseekDBOMLP(CustomDeepseekV2MLP): @@ -170,7 +172,7 @@ def __init__( top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - reduce_results=False, + reduce_results=True if not VLLM_ASCEND_ENABLE_MOE_ALL2ALLV else False, renormalize=config.norm_topk_prob, quant_config=quant_config, use_grouped_topk=True, @@ -205,6 +207,7 @@ def __init__( ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.config = config def forward( self, @@ -256,6 +259,131 @@ def _forward_ms_op_gate( router_logits, _ = self.gate(hidden_states) return router_logits + def _forward_op_gating( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None + ) -> torch.Tensor: + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + # TODO: need a better flag to indicate whether in profile run or not. 
+ if attn_metadata is None: + # for profile run + self.is_prefill = True + self.enable_force_load_balance = True + else: + is_prefill = attn_metadata.num_prefills > 0 + self.enable_force_load_balance = False + if hasattr(attn_metadata, 'with_prefill_across_dp'): + self.is_prefill = is_prefill or attn_metadata.with_prefill_across_dp + + num_tokens, hidden_dim = hidden_states.shape + + if self.tp_size > 1: + # pass + num_tokens, hidden_size = hidden_states.shape + if num_tokens < self.tp_size: + hidden_states = nn.functional.pad( + hidden_states, (0, 0, 0, self.tp_size - num_tokens)) + chunk_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + chunked_hidden_states_sizes = [x.shape[0] for x in chunk_hidden_states] + local_hidden_states = chunk_hidden_states[self.tp_rank] + else: + local_hidden_states = hidden_states + chunked_hidden_states_sizes = None + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(local_hidden_states) + + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern + if self.config.n_routed_experts == 256: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=self.config.num_experts_per_tok, # topk当前写8 + bias=self.gate.e_score_correction_bias, + k_group=self.config.topk_group, # fix: 4 + group_count=self.config.n_group, # fix 8 + group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; 第三个输出是否输出 + # y2_flag=False, # old api; 第三个输出是否输出 + routed_scaling_factor=1, + eps=float(1e-20)) + else: + topk_weights, topk_ids = select_experts( + hidden_states=local_hidden_states, + router_logits=router_logits, + top_k=self.config.num_experts_per_tok, + use_grouped_topk=True, + renormalize=self.config.norm_topk_prob, + topk_group=self.config.topk_group, + num_expert_group=self.config.n_group, + custom_routing_function=None, + scoring_func=self.config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias, + ) + + topk_weights = topk_weights.to(hidden_states.dtype) + # this is a naive implementation for experts load balance so as + # to avoid accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. 
+ if self.enable_force_load_balance: + topk_ids = torch.randint_like(topk_ids, 0, self.config.n_routed_experts) + + return topk_weights, topk_ids, local_hidden_states, chunked_hidden_states_sizes + + def _forward_dispatch_comm( + self, hidden_states, topk_weights, topk_ids, microbatch_id + ): + token_dispatcher = self.experts.token_dispatchers[microbatch_id] + _, hidden_states, tokens_per_expert = token_dispatcher.token_permutation(hidden_states, topk_weights, topk_ids) + return hidden_states, tokens_per_expert + + def _forward_op_shared_experts( + self, hidden_states + ): + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + + return shared_output + + def _forward_op_grouped_mlp( + self, dispatched_input, tokens_per_expert + ): + return apply_mlp( + [dispatched_input], + self.experts.w13_weight, + self.experts.w2_weight, + tokens_per_expert + ) + + def _forward_combine_comm( + self, hidden_states, microbatch_id, num_tokens, chunked_hidden_states_sizes + ): + token_dispatcher = self.experts.token_dispatchers[microbatch_id] + token_dispatcher.combine_alltoall() + final_hidden_states = token_dispatcher.unpermute2() * self.routed_scaling_factor + + if self.tp_size > 1: + final_hidden_states = gather_from_sequence_parallel_region(final_hidden_states, self.tp_group, + chunked_hidden_states_sizes) + if num_tokens < self.tp_size: + final_hidden_states = final_hidden_states[:num_tokens] + + if self.shared_experts is not None: + final_hidden_states = final_hidden_states + token_dispatcher.cached_shared_expert_output + token_dispatcher.cached_shared_expert_output.untyped_storage().resize_(0) + token_dispatcher.cached_shared_expert_output = None + + final_hidden_states = final_hidden_states.view(num_tokens, -1) + + return final_hidden_states + class CustomDeepseekDBOMLAAttention(DeepseekV2MLAAttention): @@ -707,6 +835,139 @@ def _forward_ms_layer( context.after_comm_event.record() return hidden_states, residual + # ----------------------------------------- TBO-related -------------------------------------------- + def _forward_ms_layer_alltoallv_finegrained( + self, + positions: List[torch.Tensor], + hidden_states: List[torch.Tensor], + residual: List[torch.Tensor], + attn_metadata: List[AttentionMetadata], + kv_cache: Optional[torch.Tensor] = None, + is_prefill: bool = False, + ) -> tuple[List[torch.Tensor], List[torch.Tensor]]: + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) + assert layer_index >= 0 and ms_metadata is not None + num_micro_batchs = ms_metadata.ms_config.num_micro_batches + assert isinstance(self.mlp, CustomDeepseekDBOMoE) + assert len(positions) == num_micro_batchs + assert len(hidden_states) == num_micro_batchs + assert residual is not None + assert attn_metadata is not None + num_tokens = [None] * num_micro_batchs + hidden_dims = [None] * num_micro_batchs + topk_weights, topk_ids = [None] * num_micro_batchs, [None] * num_micro_batchs + tokens_per_expert = [None] * num_micro_batchs + dispatched_input = [None] * num_micro_batchs + shared_expert_output = [None] * num_micro_batchs + router_expert_output = [None] * num_micro_batchs + chunked_hidden_states_sizes = [None] * num_micro_batchs + token_dispatchers = self.mlp.experts.token_dispatchers + + def print_with_sync(*args, **kwargs): + torch.npu.synchronize() + print(*args, **kwargs) + + def discard_tensor(tensor): + if isinstance(tensor, torch.Tensor): + tensor = [tensor] + for t in tensor: + t.untyped_storage().resize_(0) + + # print_with_sync('begin 
layer...', torch.distributed.get_rank()) + + # block 1 : attention + # block 2 : Router Gating + # block 3 : Token DisPatch + # the attn computation of microbatch 1 can be overlapped with the moe + # communication in the previous layer, and the attn computation of microbatch 2 + # can be overlapped with the attn communication of microbatch 1 + for i in range(num_micro_batchs): + # wait last layer moe finishing communication + ms_metadata.try_wait_event(layer_index - 1, i, + MSEventKey.MOE_AFTER_COMM) + + forward_context = get_forward_context() + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) + forward_context.attn_metadata = attn_metadata[i] + + # input layernorm + hidden_states[i], residual[ + i] = self._forward_ms_op_input_layernorm( + hidden_states[i], residual[i]) + # attention and tp allreduce + hidden_states[i], residual[i] = self._forward_ms_op_attn( + positions[i], hidden_states[i], residual[i], kv_cache, + attn_metadata[i]) + # post attention layer norm + hidden_states[i], residual[i] = self._forward_ms_op_post_attn_layernorm( + hidden_states[i], residual[i] + ) + num_tokens[i], hidden_dims[i] = hidden_states[i].shape + # If TP is enabled, hidden_states will be chunked. + topk_weights[i], topk_ids[i], dispatched_input[i], chunked_hidden_states_sizes[ + i] = self.mlp._forward_op_gating(hidden_states[i], attn_metadata[i]) + token_dispatchers[i].preprocess_and_permtute1( + dispatched_input[i], topk_weights[i], topk_ids[i], + self.mlp.shared_experts, + shared_experts_input=hidden_states[i] if self.mlp.n_shared_experts else None + ) + # Launch DisPatch Comm in a New Stream. + dispatch_context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_BEFORE_COMM], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_AFTER_COMM], + ) + dispatch_context.before_comm_event.record() + # print_with_sync(f'begin token dispatch{i}...', torch.distributed.get_rank()) + with torch.npu.stream(dispatch_context.comm_stream): + dispatch_context.comm_stream.wait_event(dispatch_context.before_comm_event) + token_dispatchers[i].dispatch_alltoall() + dispatch_context.after_comm_event.record() + + if self.mlp.n_shared_experts: + token_dispatchers[i].cached_shared_expert_output = tensor_model_parallel_all_reduce( + token_dispatchers[i].cached_shared_expert_output + ) + ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMM_FINISH].record() + + # print_with_sync('begin experts...', torch.distributed.get_rank()) + # block 4 : Router Experts Computation + # block 5 : Token Combine Communication + for i in range(num_micro_batchs): + + ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_AFTER_COMM) + discard_tensor(hidden_states[i]) + + dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2() + router_expert_output[i] = self.mlp._forward_op_grouped_mlp(dispatched_input[i], tokens_per_expert[i]) + discard_tensor(dispatched_input[i]) + token_dispatchers[i].unpermute1(router_expert_output[i]) + if router_expert_output[i].shape[0] > 0 and token_dispatchers[i].num_local_experts > 1: + discard_tensor(router_expert_output[i]) + + # Launch Combine Comm in a New Stream. 
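+            # Before launching the combine all-to-all on the communication
+            # stream, wait for this micro-batch's shared-expert all-reduce
+            # (MOE_SE_COMM_FINISH, recorded in the dispatch loop above), since
+            # _forward_combine_comm adds the cached shared-expert output.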
+ combine_context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_BEFORE_COMM], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_AFTER_COMM], + ) + combine_context.before_comm_event.record() + ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH) + with torch.npu.stream(combine_context.comm_stream): + combine_context.comm_stream.wait_event(combine_context.before_comm_event) + hidden_states[i] = self.mlp._forward_combine_comm( + router_expert_output[i], i, num_tokens[i], chunked_hidden_states_sizes[i] + ) + combine_context.after_comm_event.record() + + return hidden_states, residual + # should split ops in Decoder Layer def _forward_ms_op_input_layernorm( self, @@ -879,6 +1140,16 @@ def can_run_ms(self): return False return True + def all_can_run_ms(self): + can_run_ms_local = self.can_run_ms() + ep_group = get_ep_group().cpu_group + flag = torch.ones(1, dtype=torch.int) if can_run_ms_local else torch.zeros(1, dtype=torch.int) + torch.distributed.all_reduce(flag, group=ep_group) + if flag.item() == torch.distributed.get_world_size(ep_group): + return True + else: + return False + def _forward_ms_layers(self, positions: torch.Tensor, hidden_states: torch.Tensor, @@ -896,7 +1167,12 @@ def _forward_ms_layers(self, # the rest layers for i in range(moe_start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer._forward_ms_layer( + ms_layer_forward_func = layer._forward_ms_layer + if VLLM_ASCEND_ENABLE_MOE_ALL2ALLV: + # ms_layer_forward_func = layer._forward_ms_layer_alltoallv + ms_layer_forward_func = layer._forward_ms_layer_alltoallv_finegrained + # print("get_called......") + hidden_states, residual = ms_layer_forward_func( positions=positions, hidden_states=hidden_states, residual=residual, diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py new file mode 100644 index 0000000000..042f4dc400 --- /dev/null +++ b/vllm_ascend/models/qwen3_dbo.py @@ -0,0 +1,511 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# # Adapted from +# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py +# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py +# """Inference-only DeepseekV2/DeepseekV3 model.""" + +from collections.abc import Iterable +from typing import Any, Optional, Union, List +from types import SimpleNamespace + +import torch +import torch_npu +from torch import nn +from transformers import PretrainedConfig + +from vllm.model_executor.models.qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeModel +from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.attention import AttentionMetadata +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.distributed import tensor_model_parallel_all_reduce, get_tensor_model_parallel_world_size, get_tp_group, \ + get_pp_group +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.models.utils import (make_empty_intermediate_tensors_factory, make_layers) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.sequence import IntermediateTensors + +from vllm_ascend.multistream.context import ( + advance_step_multistream_layer_context, get_multistream_comm_context, + get_multistream_layer_context, set_multistream_context) +from vllm_ascend.multistream.base import MSEventKey +from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, + MultiStreamPreTransformerLayer) +from vllm_ascend.multistream.metadata import (MultiStreamConfig, + MultiStreamStepMetadata, + make_multistream_metadata_ds) +from vllm_ascend.ops.fused_moe import AscendFusedMoE, select_experts, apply_mlp +from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region +import vllm_ascend.envs as envs_ascend + +VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO +VLLM_ASCEND_ENABLE_MOE_ALL2ALLV: bool = envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALLV + + +class Qwen3MoeDecoderLayerDBO(Qwen3MoeDecoderLayer): + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super(Qwen3MoeDecoderLayerDBO, self).__init__(config, cache_config, quant_config, prefix) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tp_group().rank_in_group + self.tp_group = get_tp_group().device_group + self.dummy_vllm_config = SimpleNamespace( + parallel_config=SimpleNamespace( + data_parallel_size=1, + ), + compilation_config=SimpleNamespace( + static_forward_context=None, + ), + other_setting="value" + ) + self.config = config + + def forward(self, *args, **kwargs): + return super().forward(*args, **kwargs) + + # should split ops in Decoder Layer + def _forward_ms_op_input_layernorm( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + return hidden_states, residual + + def _forward_ms_op_attn( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: 
Optional[AttentionMetadata] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + self.dummy_vllm_config.compilation_config.static_forward_context = get_forward_context().no_compile_layers + with set_forward_context(attn_metadata, self.dummy_vllm_config): + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + return hidden_states, residual + + def _forward_ms_op_post_attn_layernorm( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ): + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + return hidden_states, residual + + def _forward_op_gating( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None + ) -> torch.Tensor: + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + # TODO: need a better flag to indicate whether in profile run or not. + if attn_metadata is None: + # for profile run + self.is_prefill = True + self.enable_force_load_balance = True + else: + # is_prefill = attn_metadata.num_prefills > 0 + is_prefill = False + self.enable_force_load_balance = False + if hasattr(attn_metadata, 'with_prefill_across_dp'): + self.is_prefill = is_prefill or attn_metadata.with_prefill_across_dp + + num_tokens, hidden_dim = hidden_states.shape + + if self.tp_size > 1: + # pass + num_tokens, hidden_size = hidden_states.shape + if num_tokens < self.tp_size: + hidden_states = nn.functional.pad( + hidden_states, (0, 0, 0, self.tp_size - num_tokens)) + chunk_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + chunked_hidden_states_sizes = [x.shape[0] for x in chunk_hidden_states] + local_hidden_states = chunk_hidden_states[self.tp_rank] + else: + local_hidden_states = hidden_states + chunked_hidden_states_sizes = None + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.mlp.gate(local_hidden_states) + + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern + mlp_config = self.config + if mlp_config.num_experts == 256: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=mlp_config.num_experts_per_tok, # topk当前写8 + bias=self.mlp.gate.e_score_correction_bias, + k_group=mlp_config.topk_group, # fix: 4 + group_count=mlp_config.n_group, # fix 8 + group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # todo new api; 第三个输出是否输出 + # y2_flag=False, # old api; 第三个输出是否输出 + routed_scaling_factor=1, + eps=float(1e-20)) + else: + topk_weights, topk_ids = select_experts( + hidden_states=local_hidden_states, + router_logits=router_logits, + top_k=mlp_config.num_experts_per_tok, + use_grouped_topk=False, + renormalize=mlp_config.norm_topk_prob, + topk_group=getattr(mlp_config, "topk_group", None), + num_expert_group=getattr(mlp_config, "n_group", None), + custom_routing_function=None, + scoring_func=getattr(mlp_config, "scoring_func", 'softmax'), + 
e_score_correction_bias=getattr(self.mlp.gate, "e_score_correction_bias", None) + ) + + topk_weights = topk_weights.to(hidden_states.dtype) + # this is a naive implementation for experts load balance so as + # to avoid accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. + if self.enable_force_load_balance: + topk_ids = torch.randint_like(topk_ids, 0, self.config.num_experts) + + return topk_weights, topk_ids, local_hidden_states, chunked_hidden_states_sizes + + def _forward_op_grouped_mlp( + self, dispatched_input, tokens_per_expert + ): + return apply_mlp( + [dispatched_input], + self.mlp.experts.w13_weight, + self.mlp.experts.w2_weight, + tokens_per_expert + ) + + def _forward_combine_comm( + self, hidden_states, microbatch_id, num_tokens, chunked_hidden_states_sizes + ): + token_dispatcher = self.mlp.experts.token_dispatchers[microbatch_id] + token_dispatcher.combine_alltoall() + final_hidden_states = token_dispatcher.unpermute2() + if hasattr(self.mlp, 'routed_scaling_factor'): + final_hidden_states = final_hidden_states * self.mlp.routed_scaling_factor + + if self.tp_size > 1: + final_hidden_states = gather_from_sequence_parallel_region(final_hidden_states, self.tp_group, + chunked_hidden_states_sizes) + if num_tokens < self.tp_size: + final_hidden_states = final_hidden_states[:num_tokens] + + if hasattr(self.mlp, "shared_experts"): + final_hidden_states = final_hidden_states + token_dispatcher.cached_shared_expert_output + token_dispatcher.cached_shared_expert_output.untyped_storage().resize_(0) + token_dispatcher.cached_shared_expert_output = None + + final_hidden_states = final_hidden_states.view(num_tokens, -1) + + return final_hidden_states + + def _forward_ms_layer_alltoallv_finegrained( + self, + positions: List[torch.Tensor], + hidden_states: List[torch.Tensor], + residual: List[torch.Tensor], + attn_metadata: List[AttentionMetadata], + kv_cache: Optional[torch.Tensor] = None, + ): + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) + assert layer_index >= 0 and ms_metadata is not None + num_micro_batchs = ms_metadata.ms_config.num_micro_batches + assert len(positions) == num_micro_batchs + assert len(hidden_states) == num_micro_batchs + assert residual is not None + assert attn_metadata is not None + num_tokens = [None] * num_micro_batchs + hidden_dims = [None] * num_micro_batchs + topk_weights, topk_ids = [None] * num_micro_batchs, [None] * num_micro_batchs + tokens_per_expert = [None] * num_micro_batchs + dispatched_input = [None] * num_micro_batchs + shared_expert_output = [None] * num_micro_batchs + router_expert_output = [None] * num_micro_batchs + chunked_hidden_states_sizes = [None] * num_micro_batchs + token_dispatchers = self.mlp.experts.token_dispatchers + has_shared_expert = hasattr(self.mlp, 'shared_experts') + + def discard_tensor(tensor): + if isinstance(tensor, torch.Tensor): + tensor = [tensor] + for t in tensor: + t.untyped_storage().resize_(0) + + # block 1 : attention + # block 2 : Router Gating + # block 3 : Token DisPatch + # the attn computation of microbatch 1 can be overlapped with the moe + # communication in the previous layer, and the attn computation of microbatch 2 + # can be overlapped with the attn communication of microbatch 1 + for i in range(num_micro_batchs): + # wait last layer moe finishing communication + ms_metadata.try_wait_event(layer_index - 1, i, + MSEventKey.MOE_AFTER_COMM) + + forward_context = get_forward_context() + layer_index, ms_metadata, 
attn_metadata = get_multistream_layer_context( + ) + forward_context.attn_metadata = attn_metadata[i] + + # input layernorm + hidden_states[i], residual[ + i] = self._forward_ms_op_input_layernorm( + hidden_states[i], residual[i]) + # attention and tp allreduce + hidden_states[i], residual[i] = self._forward_ms_op_attn( + positions[i], hidden_states[i], residual[i], kv_cache, + attn_metadata[i]) + # post attention layer norm + hidden_states[i], residual[i] = self._forward_ms_op_post_attn_layernorm( + hidden_states[i], residual[i] + ) + num_tokens[i], hidden_dims[i] = hidden_states[i].shape + # If TP is enabled, hidden_states will be chunked. + topk_weights[i], topk_ids[i], dispatched_input[i], chunked_hidden_states_sizes[i] = self._forward_op_gating( + hidden_states[i], attn_metadata[i]) + token_dispatchers[i].preprocess_and_permtute1( + dispatched_input[i], topk_weights[i], topk_ids[i], + shared_experts=None, shared_experts_input=None + ) + # Launch DisPatch Comm in a New Stream. + dispatch_context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_BEFORE_COMM], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_AFTER_COMM], + ) + dispatch_context.before_comm_event.record() + # print_with_sync(f'begin token dispatch{i}...', torch.distributed.get_rank()) + with torch.npu.stream(dispatch_context.comm_stream): + dispatch_context.comm_stream.wait_event(dispatch_context.before_comm_event) + token_dispatchers[i].dispatch_alltoall() + dispatch_context.after_comm_event.record() + + if has_shared_expert: + token_dispatchers[i].cached_shared_expert_output = tensor_model_parallel_all_reduce( + token_dispatchers[i].cached_shared_expert_output + ) + ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMM_FINISH].record() + + # print_with_sync('begin experts...', torch.distributed.get_rank()) + # block 4 : Router Experts Computation + # block 5 : Token Combine Communication + for i in range(num_micro_batchs): + + ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_AFTER_COMM) + discard_tensor(hidden_states[i]) + + dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2() + router_expert_output[i] = self._forward_op_grouped_mlp(dispatched_input[i], tokens_per_expert[i]) + discard_tensor(dispatched_input[i]) + token_dispatchers[i].unpermute1(router_expert_output[i]) + if router_expert_output[i].shape[0] > 0 and token_dispatchers[i].num_local_experts > 1: + discard_tensor(router_expert_output[i]) + + # Launch Combine Comm in a New Stream. 
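+            # Mirrors the DeepSeek DBO path: wait on this micro-batch's
+            # MOE_SE_COMM_FINISH event before issuing the combine all-to-all
+            # on the communication stream.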
+ combine_context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_BEFORE_COMM], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_AFTER_COMM], + ) + combine_context.before_comm_event.record() + ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH) + with torch.npu.stream(combine_context.comm_stream): + combine_context.comm_stream.wait_event(combine_context.before_comm_event) + hidden_states[i] = self._forward_combine_comm( + router_expert_output[i], i, num_tokens[i], chunked_hidden_states_sizes[i] + ) + combine_context.after_comm_event.record() + + return hidden_states, residual + + +class CustomQwen3DBOMoEModel(Qwen3MoeModel): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens") + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Qwen3MoeDecoderLayerDBO(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + # dbo related members + if VLLM_ASCEND_ENABLE_DBO: + self.use_mla = False + self.multistream_config = MultiStreamConfig() + multistream_metadata = make_multistream_metadata_ds( + start_layer=self.start_layer, + end_layer=self.end_layer, + causal_lm=getattr(config, "causal_lm", True), + multistream_config=self.multistream_config, + ) + self.ms_pre_layer = MultiStreamPreTransformerLayer( + multistream_metadata) + self.ms_post_layer = MultiStreamPostTransformerLayer( + multistream_metadata) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + num_normal_layers = (0 if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms() + else self.end_layer - self.start_layer) + + moe_start_layer = self.start_layer + num_normal_layers + for i in range(self.start_layer, min(moe_start_layer, self.end_layer)): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual) + + if moe_start_layer < self.end_layer: + # if we enable multistream/dbo, process sparse layers here + hidden_states, residual = self._forward_ms_layers( + positions=positions, + hidden_states=hidden_states, + residual=residual, + moe_start_layer=moe_start_layer + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = 
self.norm(hidden_states, residual) + return hidden_states + + def can_run_ms(self): + attn_metadata = get_forward_context().attn_metadata + # enable prefill overlap + with_prefill = getattr(attn_metadata, "with_prefill_across_dp", False) + if attn_metadata is None or not with_prefill or not attn_metadata.enable_dbo_across_dp: + return False + # if torch.distributed.get_rank() == 0: + # print(attn_metadata) + return True + + def _forward_ms_layers( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + moe_start_layer: int, + kv_caches: Optional[List[torch.Tensor]] = None, + ): + + if moe_start_layer == self.end_layer: + return hidden_states, residual + + attn_metadata, [positions, hidden_states, + residual] = self.ms_pre_layer( + [positions, hidden_states, residual], ) + # if torch.distributed.get_rank() == 0: + # print(attn_metadata[0], attn_metadata[1]) + # exit() + # the rest layers + for i in range(moe_start_layer, self.end_layer): + layer = self.layers[i] + ms_layer_forward_func = layer._forward_ms_layer_alltoallv_finegrained + # print("get_called......") + hidden_states, residual = ms_layer_forward_func( + positions=positions, + hidden_states=hidden_states, + residual=residual, + attn_metadata=attn_metadata, + ) + advance_step_multistream_layer_context() + + [hidden_states, + residual] = self.ms_post_layer([hidden_states, residual], ) + return hidden_states, residual + + diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index fd32a18abb..a41bb7bccb 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -4,7 +4,7 @@ import numpy as np import torch -from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.attention_v1 import AscendAttentionState, AscendMetadata from .base import MSAttentionMetadataSplitConfig @@ -245,3 +245,108 @@ def model_input_split_v1_mla_attn( enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, ) return [attention_metadata_pre, attention_metadata_post] + + + def model_input_split_v1_attn( + attn_metadata: AscendMetadata, + _metadata_cls, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> List[Any]: + assert 0 < ms_split_config.num_micro_batches < 3 + if attn_metadata is None: + return [attn_metadata] + [token_index, + seq_index] = compute_split_seq_index(attn_metadata.query_lens, + attn_metadata.attn_state, + attn_metadata.num_actual_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len( + attn_metadata.query_lens): + return [attn_metadata] + + + # split attn metadata + + + [block_table_pre, block_table_post] = split_attn_tensor_type(attn_metadata.block_tables, seq_index) + + query_start_loc_pre = query_start_loc_post = None + if attn_metadata.query_start_loc is not None: + query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1] + query_start_loc_post = deepcopy( + attn_metadata.query_start_loc[seq_index:] + ) - attn_metadata.query_start_loc[seq_index] + + [query_lens_pre, query_lens_post] = split_attn_tensor_type(attn_metadata.query_lens, seq_index) + [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(attn_metadata.seq_lens, seq_index) + + max_query_len_pre = max_query_len_post = None + if attn_metadata.max_query_len is not None: + max_query_len_pre, max_query_len_post = max(query_lens_pre), max(query_lens_post) + + [slot_mapping_pre, slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping, token_index) + + is_only_prefill_pre = is_only_prefill_post = 
attn_metadata.is_only_prefill + has_prefill_pre, has_prefill_post = torch.any(query_lens_pre > 1).item(), torch.any(query_lens_post > 1).item() + + if not attn_metadata.is_only_prefill: + is_only_prefill_post = torch.all(query_lens_post > 1).item() + + + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: + # the attn_mla kernel in torch npu only accept 128*128 attn mask + attn_mask_pre = attn_mask_post = attn_metadata.attn_mask + attn_state_pre = attn_state_post = attn_metadata.attn_state + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + # should be none in decode only state + attn_mask_pre = attn_mask_post = attn_metadata.attn_mask + attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly + else: + # chunked prefill + if has_prefill_pre: + attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill + attn_mask_pre = attn_metadata.attn_mask[:token_index, :max( + seq_lens_pre)].contiguous() + attn_state_post = AscendAttentionState.ChunkedPrefill + attn_mask_post = attn_metadata.attn_mask[ + token_index:, :max(seq_lens_post)].contiguous() + else: + attn_state_pre = AscendAttentionState.DecodeOnly + attn_mask_pre = None + attn_state_post = AscendAttentionState.ChunkedPrefill + attn_mask_post = attn_metadata.attn_mask[ + token_index:, :max(seq_lens_post)].contiguous() + + # construct metadata + attention_metadata_pre = _metadata_cls( + num_actual_tokens=token_index, + block_tables=block_table_pre, + query_start_loc=query_start_loc_pre, + query_lens=query_lens_pre, + seq_lens=seq_lens_pre, + max_query_len=max_query_len_pre, + slot_mapping=slot_mapping_pre, + is_only_prefill=is_only_prefill_pre, + attn_state=attn_state_pre, + attn_mask=attn_mask_pre, + num_input_tokens=token_index, + with_prefill_across_dp=attn_metadata.with_prefill_across_dp, + enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, + ) + + attention_metadata_post = _metadata_cls( + num_actual_tokens=attn_metadata.num_actual_tokens - token_index, + block_tables=block_table_post, + query_start_loc=query_start_loc_post, + query_lens=query_lens_post, + seq_lens=seq_lens_post, + max_query_len=max_query_len_post, + slot_mapping=slot_mapping_post, + is_only_prefill=is_only_prefill_post, + attn_state=attn_state_post, + attn_mask=attn_mask_post, + num_input_tokens=attn_metadata.num_input_tokens - token_index, + with_prefill_across_dp=attn_metadata.with_prefill_across_dp, + enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, + ) + + return [attention_metadata_pre, attention_metadata_post] \ No newline at end of file diff --git a/vllm_ascend/ops/comm_utils.py b/vllm_ascend/ops/comm_utils.py new file mode 100644 index 0000000000..6c43773308 --- /dev/null +++ b/vllm_ascend/ops/comm_utils.py @@ -0,0 +1,127 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
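+# Async wrappers around torch.distributed collectives (all-gather,
+# reduce-scatter, all-to-all) that launch with async_op=True and return a
+# (input, output, handle) triple so callers can overlap compute with
+# communication. When an `event` is given, the collective is issued on a
+# shared secondary NPU stream (COMM_STREAM) after waiting on that event.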
+import torch +import torch.distributed +import torch.distributed as dist +import torch_npu + +COMM_STREAM = None + + +def async_all_gather(input_, + group, + event=None, + is_use_get_global_memory_buffer=False): + world_size = torch.distributed.get_world_size(group) + dim_size = list(input_.size()) + new_dim_size = dim_size[0] * world_size + dim_size[0] = new_dim_size + + ag_out = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + if event: + # multi stream wait event + global COMM_STREAM + if COMM_STREAM is None: + COMM_STREAM = torch_npu.npu.Stream( + device=torch.npu.current_device()) + with torch_npu.npu.stream(COMM_STREAM): + event.wait() + handle = torch.distributed._all_gather_base(ag_out, + input_.contiguous(), + group=group, + async_op=True) + else: + handle = torch.distributed._all_gather_base(ag_out, + input_.contiguous(), + group=group, + async_op=True) + return input_, ag_out, handle + + +def async_reduce_scatter(input_, + group, + event=None, + stream=None, + is_use_get_global_memory_buffer=False): + world_size = dist.get_world_size(group) + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] // world_size + + rs_out = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + if event or stream: + # multi stream wait event + global COMM_STREAM + if COMM_STREAM is None: + COMM_STREAM = torch_npu.npu.Stream( + device=torch.npu.current_device()) + with torch_npu.npu.stream(COMM_STREAM): + if event: + event.wait() + if stream: + torch.npu.current_stream().wait_stream(stream) + handle = torch.distributed.reduce_scatter_tensor( + rs_out, input_.contiguous(), group=group, async_op=True) + else: + handle = torch.distributed.reduce_scatter_tensor(rs_out, + input_.contiguous(), + group=group, + async_op=True) + return input_, rs_out, handle + + +def async_all_to_all(input_, + output_split_sizes, + input_split_sizes, + group, + event=None): + if output_split_sizes is None: + # Equal split (all2all) + a2a_out = torch.empty_like(input_) + else: + # Unequal split (all2all-v) + a2a_out = input_.new_empty( + size=[sum(output_split_sizes)] + list(input_.size()[1:]), + dtype=input_.dtype, + device=torch.npu.current_device(), + ) + + if event: + # multi stream wait event + global COMM_STREAM + if COMM_STREAM is None: + COMM_STREAM = torch_npu.npu.Stream( + device=torch.npu.current_device()) + with torch_npu.npu.stream(COMM_STREAM): + event.wait() + handle = dist.all_to_all_single( + a2a_out, + input_.contiguous(), + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=True) + else: + handle = dist.all_to_all_single(a2a_out, + input_.contiguous(), + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=True) + return input_, a2a_out, handle diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 680e474701..4ef1b1030a 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -43,8 +43,12 @@ from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, get_ascend_soc_version, npu_stream_switch, npu_wait_tensor) +from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( + MoEAlltoAllSeqOverLapDispatcher, MoeDispatcherConfig) -MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER +VLLM_ASCEND_MOE_ALL2ALL_BUFFER: bool = envs_ascend.VLLM_ASCEND_MOE_ALL2ALL_BUFFER +VLLM_ASCEND_ENABLE_MOE_ALL2ALLV: bool = envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALLV +VLLM_ASCEND_ENABLE_DBO: bool = 
envs_ascend.VLLM_ASCEND_ENABLE_DBO def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, @@ -56,11 +60,11 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, if original_total_elements == 0: output_len = ep_size * max_row_per_ep_rank - topk_ids_pad = torch.full((output_len, ), + topk_ids_pad = torch.full((output_len,), expert_num, dtype=original_dtype, device=device) - unpad_indices = torch.full((original_total_elements, ), + unpad_indices = torch.full((original_total_elements,), -1, dtype=torch.long, device=device) @@ -91,13 +95,13 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, is_kept_mask, indices_in_rec_cond_list_for_all, torch.tensor(-1, device=device, dtype=torch.long)) output_len = ep_size * max_row_per_ep_rank - topk_ids_pad = torch.full((output_len, ), + topk_ids_pad = torch.full((output_len,), expert_num, dtype=original_dtype, device=device) if topk_ids.shape[0] > 0: all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx - temp_pad_buffer = torch.full((output_len + 1, ), + temp_pad_buffer = torch.full((output_len + 1,), expert_num, dtype=original_dtype, device=device) @@ -112,16 +116,16 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, def fused_experts_with_mc2( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - moe_all_to_all_group_name: Optional[str] = None, - shared_experts: Optional[Any] = None, - is_torchair: bool = False, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + moe_all_to_all_group_name: Optional[str] = None, + shared_experts: Optional[Any] = None, + is_torchair: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: quant_mode = 0 ep_group = get_ep_group() @@ -167,7 +171,7 @@ def fused_experts_with_mc2( output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) # comm_stream.wait_stream(torch.npu.current_stream()) expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[ - 0:5] + 0:5] if shared_experts is not None: with npu_stream_switch("moe_secondary", 0): @@ -243,7 +247,7 @@ def fused_experts_with_mc2( return hidden_states, shared_hidden_states -def apply_mlp(hidden_states_wrapper: List[torch.Tensor], +def apply_mlp(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, group_list: torch.Tensor, @@ -269,9 +273,6 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor], hidden_states: output hidden states after MLP. """ - assert len(hidden_states_wrapper) == 1 - hidden_states = hidden_states_wrapper.pop() - w1 = w1.transpose(1, 2) hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], @@ -299,15 +300,17 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor], return hidden_states +# currently expert parallelism implemented with all2all +# is under-optimized. 
def fused_experts_with_all2all( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - ep_group: GroupCoordinator = None, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, ): original_shape = hidden_states.shape if len(original_shape) == 3: @@ -325,7 +328,7 @@ def fused_experts_with_all2all( row_idx_len, dtype=torch.int32, device=device).view(top_k, -1).permute( - 1, 0).contiguous()) + 1, 0).contiguous()) hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, @@ -364,7 +367,7 @@ def fused_experts_with_all2all( row_idx_len, dtype=torch.int32, device=topk_weights.device).view( - top_k, -1).permute(1, 0).contiguous() + top_k, -1).permute(1, 0).contiguous() hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, @@ -437,16 +440,16 @@ def fused_experts_with_all2all( # currently expert parallelism implemented with all2all # is under-optimized. def fused_experts_with_all2all_buffer( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - max_model_len: int, - global_batch_size: int, - expert_map: torch.Tensor = None, - ep_group: GroupCoordinator = None, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + max_model_len: int, + global_batch_size: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, ): original_shape = hidden_states.shape if len(original_shape) == 3: @@ -467,9 +470,9 @@ def fused_experts_with_all2all_buffer( expert_idx=topk_ids, active_num=num_tokens) - max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) * - max_model_len // ep_group.world_size + - 1) * top_k * 2 + max_row_per_ep_rank = ( + -(-global_batch_size // ep_group.world_size) * max_model_len * + get_dp_group().world_size // ep_group.world_size + 1) * top_k * 2 expert_idx_buffer_scatter, unpad_indices = process_topk_ids( expanded_expert_idx, global_num_experts, ep_group.world_size, max_row_per_ep_rank, num_tokens, top_k) @@ -481,9 +484,9 @@ def fused_experts_with_all2all_buffer( (expert_idx_buffer_scatter != global_num_experts).to(torch.int32)) hidden_states_pad_idx[ expert_idx_buffer_scatter != global_num_experts] = torch.arange( - non_pad_len, - dtype=expert_idx_buffer_scatter.dtype, - device=hidden_states.device) + non_pad_len, + dtype=expert_idx_buffer_scatter.dtype, + device=hidden_states.device) hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] expert_idx_buffer_gather = torch.empty_like( @@ -502,7 +505,7 @@ def fused_experts_with_all2all_buffer( group=ep_group.device_group) mask = expert_idx_buffer_gather != global_num_experts local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( - global_num_experts // ep_group.world_size) + global_num_experts // ep_group.world_size) hidden_states = hidden_states_buffer_gather[mask] idx_type = local_expert_idx.dtype sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float()) @@ -513,10 +516,7 @@ def fused_experts_with_all2all_buffer( hidden_states = hidden_states[sorted_idx] group_list_type = 0 - 
hidden_states_wrapper = [hidden_states] - del hidden_states - - hidden_states = apply_mlp(hidden_states_wrapper, + hidden_states = apply_mlp(hidden_states, w1, w2, expert_tokens, @@ -561,15 +561,33 @@ def fused_experts_with_all2all_buffer( return final_hidden_states +def fused_experts_with_all2allv(token_dispatcher, probs, routing_map, hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor): + # Enable moe alltoallv, it's a balanced policy for precision and efficiency. + (share_experts_output, dispatched_input, tokens_per_expert) = token_dispatcher.token_permutation( + hidden_states, probs, routing_map + ) + hidden_states_wrapper = [dispatched_input] + del dispatched_input + + expert_output = apply_mlp(hidden_states_wrapper, + w1, + w2, + tokens_per_expert) + output, mlp_bias = token_dispatcher.token_unpermutation(expert_output) + return output + + def fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - apply_router_weight_on_input: bool = False, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + apply_router_weight_on_input: bool = False, ) -> torch.Tensor: """ Fused experts with top-k routing. @@ -613,7 +631,7 @@ def fused_experts( ), "`topk_weights` should be in shape (num_tokens, topk)" _, topk = topk_weights.shape assert ( - topk == 1 + topk == 1 ), "Only support topk=1 when `apply_router_weight_on_input` is True" hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) @@ -622,7 +640,7 @@ def fused_experts( token_indices = (torch.arange(num_tokens, device=device, dtype=torch.int64).unsqueeze(1).expand( - -1, top_k).reshape(-1)) + -1, top_k).reshape(-1)) # Flatten token-to-expert mappings and map to local experts weights_flat = topk_weights.view(-1) @@ -662,7 +680,7 @@ def fused_experts( row_idx_len, dtype=torch.int32, device=device).view(top_k, -1).permute( - 1, 0).contiguous()) + 1, 0).contiguous()) sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, @@ -736,9 +754,9 @@ def fused_experts( def native_grouped_topk( - topk_weights: torch.Tensor, - num_expert_group: Optional[int], - topk_group: Optional[int], + topk_weights: torch.Tensor, + num_expert_group: Optional[int], + topk_group: Optional[int], ): topk_group = 0 if topk_group is None else topk_group num_expert_group = 0 if num_expert_group is None else num_expert_group @@ -761,16 +779,16 @@ def native_grouped_topk( def select_experts( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Select top-k experts based on router logits. 
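For orientation, here is a minimal sketch of how the pieces added in this file fit together when VLLM_ASCEND_ENABLE_MOE_ALL2ALLV is enabled. It is illustrative only and not part of the patch; ep_size, the expert counts, hidden_states, router_logits and the weight tensors are placeholders rather than values taken from this change.

    # Illustrative sketch: build the dispatcher once per layer, route with
    # select_experts(), then run the alltoallv expert path.
    ep_size = 16                                    # placeholder EP world size
    config = (MoeDispatcherConfig()
              .set_num_moe_experts(256)             # placeholder expert count
              .set_num_local_experts(256 // ep_size)
              .set_moe_router_topk(8)
              .set_group_topk(4)
              .set_num_groups(8)
              .set_expert_bias(None)
              .set_scaling_factor(1.0)
              .build())
    token_dispatcher = MoEAlltoAllSeqOverLapDispatcher(config)

    topk_weights, topk_ids = select_experts(hidden_states, router_logits,
                                            top_k=8, use_grouped_topk=True,
                                            renormalize=True, topk_group=4,
                                            num_expert_group=8)
    output = fused_experts_with_all2allv(token_dispatcher=token_dispatcher,
                                         probs=topk_weights,
                                         routing_map=topk_ids,
                                         hidden_states=hidden_states,
                                         w1=w13_weight,
                                         w2=w2_weight)

In the patch itself this wiring lives in AscendFusedMoE.__init__ and AscendUnquantizedFusedMoEMethod.apply, where the dispatcher is only exercised on the prefill path.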
@@ -880,30 +898,30 @@ def process_weights_after_loading(self, layer): self).process_weights_after_loading(layer) layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight( layer.w13_weight.data), - requires_grad=False) + requires_grad=False) layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight( layer.w2_weight.data), - requires_grad=False) + requires_grad=False) def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - is_prefill: bool = False, - enable_force_load_balance: bool = False, - shared_experts: Optional[Any] = None, - **kwargs, + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = False, + enable_force_load_balance: bool = False, + shared_experts: Optional[Any] = None, + **kwargs, ) -> torch.Tensor: # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern @@ -943,6 +961,8 @@ def apply( topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) fused_moe_state = get_forward_context().fused_moe_state + use_alltoallv = 'token_dispatcher' in kwargs and kwargs.get('token_dispatcher') is not None + if fused_moe_state == FusedMoEState.MC2: return fused_experts_with_mc2( hidden_states=x, @@ -963,7 +983,7 @@ def apply( topk_ids=topk_ids, top_k=top_k, expert_map=expert_map) - elif MOE_ALL2ALL_BUFFER: + elif VLLM_ASCEND_MOE_ALL2ALL_BUFFER: return fused_experts_with_all2all_buffer( hidden_states=x, w1=layer.w13_weight, @@ -975,6 +995,14 @@ def apply( global_batch_size=self.global_batch_size, expert_map=expert_map, ep_group=get_ep_group()) + elif use_alltoallv and is_prefill: + token_dispatcher = kwargs.get('token_dispatcher') + return fused_experts_with_all2allv(token_dispatcher=token_dispatcher, + probs=topk_weights, + routing_map=topk_ids, + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight) else: return fused_experts_with_all2all(hidden_states=x, w1=layer.w13_weight, @@ -987,33 +1015,32 @@ def apply( class AscendFusedMoE(FusedMoE): - # The moe_counter parameter is required during the initialization of EPLB # to identify the current layer index within the MOE model. 
moe_counter = -1 def __init__( - self, - num_experts: int, # Global number of experts - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = False, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - ep_size: Optional[int] = None, - dp_size: Optional[int] = None, - prefix: str = "", - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", - apply_router_weight_on_input: bool = False, + self, + num_experts: int, # Global number of experts + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + ep_size: Optional[int] = None, + dp_size: Optional[int] = None, + prefix: str = "", + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, ): # TODO: This could not initialize FusedMoE baseclass, # fixme and make __init__() of AscendFusedMoE more clear @@ -1062,17 +1089,19 @@ def __init__( expert_load_balancer = ExpertLoadBalancer(expert_map_path, self.global_num_experts) self.local_num_experts, self.expert_map = \ - expert_load_balancer.get_rank_placement_map( - self.moe_instance_id, - self.ep_rank) + expert_load_balancer.get_rank_placement_map( + self.moe_instance_id, + get_ep_group().rank_in_group) self.log2phy = expert_load_balancer.get_rank_log2phy_map( - self.moe_instance_id, self.ep_rank) + self.moe_instance_id, + get_ep_group().rank_in_group) self.global_redundant_expert_num = \ - expert_load_balancer.get_global_redundant_expert_num() + expert_load_balancer.get_global_redundant_expert_num() else: # Create a tensor of size num_experts filled with -1 self.local_num_experts, self.expert_map = determine_expert_map( - self.ep_size, self.ep_rank, self.global_num_experts) + self.ep_size, + get_ep_group().rank_in_group, self.global_num_experts) self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_multistream_moe = \ @@ -1106,7 +1135,7 @@ def __init__( "num_experts": local_num_experts, "hidden_size": hidden_size, "intermediate_size_per_partition": - self.intermediate_size_per_partition, + self.intermediate_size_per_partition, "params_dtype": params_dtype, "weight_loader": self.weight_loader, } @@ -1115,9 +1144,25 @@ def __init__( in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): moe_quant_params["intermediate_size_full"] = intermediate_size + self.ep_group = get_ep_group() # NOTE: self.tp_group is not expert_tp_group self.tp_group = get_tp_group().device_group self.quant_method.create_weights(layer=self, **moe_quant_params) + self.token_dispatcher = None + if VLLM_ASCEND_ENABLE_MOE_ALL2ALLV and isinstance( + self.quant_method, AscendUnquantizedFusedMoEMethod): + moe_dispatcher_config = ( + MoeDispatcherConfig().set_num_moe_experts(self.global_num_experts) + .set_num_local_experts(self.local_num_experts) + 
.set_moe_router_topk(top_k) + .set_group_topk(topk_group) + .set_num_groups(num_expert_group) + .set_expert_bias(e_score_correction_bias) + .set_scaling_factor(1.0).build()) + self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher(moe_dispatcher_config) + if VLLM_ASCEND_ENABLE_DBO: + token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher(moe_dispatcher_config) + self.token_dispatchers = [self.token_dispatcher, token_dispatcher1] def forward(self, hidden_states: torch.Tensor, @@ -1125,7 +1170,8 @@ def forward(self, is_prefill: bool, enable_force_load_balance: bool = False, top_k: Optional[int] = None, - shared_experts: Optional[Any] = None): + shared_experts: Optional[Any] = None, + replace_allreduce: bool = False): assert self.quant_method is not None if top_k: @@ -1134,14 +1180,18 @@ def forward(self, real_top_k = self.top_k num_tokens, hidden_size = hidden_states.shape + is_deepseek_v3_r1 = self.global_num_experts == 256 - fused_moe_state = get_forward_context().fused_moe_state + fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size, + is_prefill, is_deepseek_v3_r1) if shared_experts: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: shared_hidden_states = shared_experts(hidden_states) tp_size = get_tensor_model_parallel_world_size() - if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: + if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather + and fused_moe_state != FusedMoEState.AllGatherEP + and not replace_allreduce): if num_tokens < tp_size: hidden_states = nn.functional.pad( hidden_states, (0, 0, 0, tp_size - num_tokens)) @@ -1159,15 +1209,16 @@ def forward(self, if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: # NOTE: When in torchair graph, it has been padded in model_runner_v1 if not self.torchair_graph_enabled or is_prefill: - max_num_tokens_across_dp = get_forward_context( - ).max_tokens_across_dp - if num_tokens < max_num_tokens_across_dp: - hidden_states = nn.functional.pad( - hidden_states, - (0, 0, 0, max_num_tokens_across_dp - num_tokens)) - router_logits = nn.functional.pad( - router_logits, - (0, 0, 0, max_num_tokens_across_dp - num_tokens)) + attn_metadata = get_forward_context().attn_metadata + if attn_metadata is not None: + max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp + if num_tokens < max_num_tokens_across_dp: + hidden_states = nn.functional.pad( + hidden_states, + (0, 0, 0, max_num_tokens_across_dp - num_tokens)) + router_logits = nn.functional.pad( + router_logits, + (0, 0, 0, max_num_tokens_across_dp - num_tokens)) hidden_states = get_dp_group().all_gather(hidden_states, 0) router_logits = get_dp_group().all_gather(router_logits, 0) @@ -1191,14 +1242,17 @@ def forward(self, log2phy=self.log2phy, global_redundant_expert_num=self.global_redundant_expert_num, shared_experts=shared_experts if self.torchair_graph_enabled - and self.enable_multistream_moe and not is_prefill else None, + and self.enable_multistream_moe and not is_prefill else None, + token_dispatcher=self.token_dispatcher ) if shared_experts: if isinstance(e_hidden_states, tuple): e_hidden_states, shared_hidden_states = e_hidden_states - if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: + if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather + and fused_moe_state != FusedMoEState.AllGatherEP + and not replace_allreduce): dist.all_gather(list(chunk_hidden_states), e_hidden_states, self.tp_group) final_hidden_states = torch.cat(chunk_hidden_states, dim=0) @@ -1216,7 +1270,8 @@ def 
forward(self, else: final_hidden_states = e_hidden_states - if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather: + if tp_size > 1 and (fused_moe_state == FusedMoEState.AllGather + or fused_moe_state == FusedMoEState.AllGatherEP): final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) @@ -1228,12 +1283,12 @@ def forward(self, # ----------------------------------------- TBO-related -------------------------------------------- def _forward_ms_fused_moe_comp( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - is_prefill: bool, - real_top_k, - enable_force_load_balance: bool = False, + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_prefill: bool, + real_top_k, + enable_force_load_balance: bool = False, ): hidden_states = self.quant_method.apply( layer=self, diff --git a/vllm_ascend/ops/moe_dispatcher/__init__.py b/vllm_ascend/ops/moe_dispatcher/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/vllm_ascend/ops/moe_dispatcher/moe_utils.py b/vllm_ascend/ops/moe_dispatcher/moe_utils.py new file mode 100644 index 0000000000..6cffe4ac5f --- /dev/null +++ b/vllm_ascend/ops/moe_dispatcher/moe_utils.py @@ -0,0 +1,379 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +import math +from typing import Optional + +import torch +import torch_npu + + +def group_limited_topk( + scores: torch.Tensor, + topk: int, + num_tokens: int, + num_experts: int, + num_groups: int, + group_topk: int, +): + """Perform top-k routing on a subset of expert groups. + + When using group-limited routing: + 1. Experts are divided into 'moe_router_num_groups' equal-sized groups + 2. For each token, 'moe_router_group_topk' groups are selected based on routing scores + (specifically, the sum of top-2 expert scores within each group) + 3. From these selected groups, 'moe_router_topk' individual experts are chosen + + Two common use cases: + - Device-limited routing: Set 'moe_router_num_groups' equal to expert parallel size (EP) + to limit each token to experts on a subset of devices + (See DeepSeek-V2: https://arxiv.org/pdf/2405.04434) + + - Node-limited routing: Set 'moe_router_num_groups' equal to number of nodes in EP group + to limit each token to experts on a subset of nodes + (See DeepSeek-V3: https://arxiv.org/pdf/2412.19437) + + Args: + scores (torch.Tensor): Softmax scores generated by the router. + topk (int): The number of experts to select for each token. + num_tokens (int): The number of tokens. + num_experts (int): The number of experts. + num_groups (int): Number of groups for routed experts. + group_topk (int): Number of groups selected for each token. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Probs and indices tensor. 
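+
+    Example:
+        With num_experts=256, num_groups=8 and group_topk=4, each token first
+        selects 4 groups of 32 experts and then picks its top-k experts from
+        the resulting 128 candidates.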
+ """ + # Organize the experts into groups + # Select groups based on sum of top-(num_groups/group_topk) routing scores within each group + group_scores = (scores.view(num_tokens, + num_groups, -1).topk(num_groups // group_topk, + dim=-1)[0].sum(dim=-1)) + group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + + # Mask the experts based on selection groups + score_mask = (group_mask.unsqueeze(-1).expand( + num_tokens, num_groups, + num_experts // num_groups).reshape(num_tokens, -1)) + + masked_scores = scores.masked_fill(~score_mask.bool(), float('-inf')) + probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1) + + return probs, top_indices + + +def topk_softmax_with_capacity( + logits: torch.Tensor, + topk: int, + capacity_factor: Optional[float] = None, + pad_to_capacity: bool = False, + drop_policy: str = "probs", + use_pre_softmax: bool = False, + num_groups: Optional[int] = None, + group_topk: Optional[int] = None, + scaling_factor: Optional[float] = None, + deterministic_mode: bool = False, + score_function: str = "sigmoid", + expert_bias: Optional[torch.Tensor] = None, +): + """Apply capacity and padding to the top-k selection. + Args: + logits (torch.Tensor): Logits tensor. + topk (int): The number of experts to select for each token. + capacity_factor (float): The capacity factor of each expert. Will drop tokens if the number + of tokens exceeds the capacity. + pad_to_capacity (bool): Whether to need padding in token drop mode. The probs for padded + tokens will be 0. + drop_policy (str): The policy to drop tokens. Can be either "prob" or "position". + If "prob", the tokens with the lowest probabilities will be dropped. + If "position", tokens at the end of each batch will be dropped. + use_pre_softmax (bool): Whether to apply softmax before top-k selection. + num_groups (int): Number of groups for routed experts. + group_topk (int): Number of selected groups for each token. + scaling_factor (float): Scaling factor of routing score in top-k selection. + deterministic_mode (bool): Deprecated. + score_function (str): The score function to use. Can be either "softmax" or "sigmoid". + expert_bias (torch.Tensor): The bias added to logits for expert routing. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing + the routing probabilities for each token to each expert. + - routing_map (torch.Tensor): A mask tensor of shape [num_tokens, num_experts] + indicating which experts were selected for each token. True values represent + the selected experts. + - tokens_per_expert (torch.Tensor): A tensor of shape [num_experts] containing + the number of local tokens assigned to each expert before dropping and padding. + """ + assert logits.dim( + ) == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}." 
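+    # Two scoring paths follow: score_function == "softmax" computes top-k in
+    # torch (optionally applying softmax before the selection), while "sigmoid"
+    # delegates to the fused torch_npu.npu_moe_gating_top_k kernel with
+    # group-limited selection.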
+ num_tokens, num_experts = logits.shape + + def compute_topk(scores, topk, num_groups=None, group_topk=None): + if group_topk: + return group_limited_topk( + scores=scores, + topk=topk, + num_tokens=num_tokens, + num_experts=num_experts, + num_groups=num_groups, + group_topk=group_topk, + ) + else: + return torch.topk(scores, k=topk, dim=1) + + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, + dtype=torch.float32).type_as(logits) + probs, top_indices = compute_topk(scores, topk, num_groups, + group_topk) + else: + scores, top_indices = compute_topk(logits, topk, num_groups, + group_topk) + probs = torch.softmax(scores, dim=-1, + dtype=torch.float32).type_as(logits) + if scaling_factor: + probs = probs * scaling_factor + elif score_function == "sigmoid": + probs, top_indices, _ = torch_npu.npu_moe_gating_top_k( + logits, + k=topk, # topk当前写8 + bias=expert_bias, + k_group=group_topk, # fix: 4 + group_count=num_groups, # fix 8 + group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + # out_flag=False, # 第三个输出是否输出 + # y2_flag=False, # old api; 第三个输出是否输出 + routed_scaling_factor=scaling_factor, + eps=float(1e-20)) + else: + raise ValueError(f"Invalid score_function: {score_function}") + + # Try using element-wise operations instead of scatter? + topk_masked_gates = torch.zeros_like(logits).scatter( + 1, top_indices.type(torch.int64), probs) + topk_map = torch.zeros_like(logits).int().scatter( + 1, top_indices.type(torch.int64), 1).bool() + tokens_per_expert = topk_map.sum(dim=0) + + if capacity_factor is None: + # TopK without capacity + return topk_masked_gates, topk_map, tokens_per_expert, top_indices + else: + # TopK with capacity + expert_capacity = get_capacity(num_tokens=num_tokens * topk, + num_experts=num_experts, + capacity_factor=capacity_factor) + + # Maskout exceeded tokens + if drop_policy == "probs": + _, capacity_indices = torch.topk(topk_masked_gates, + k=expert_capacity, + dim=0, + sorted=False) + capacity_mask = torch.zeros_like(logits).scatter( + 0, capacity_indices, 1).bool() + elif drop_policy == "position": + _, capacity_indices = torch.topk(topk_map.int(), + k=expert_capacity, + dim=0, + sorted=False) + capacity_mask = torch.zeros_like(logits).scatter( + 0, capacity_indices, 1).bool() + else: + raise ValueError(f"Invalid drop_policy: {drop_policy}") + + if pad_to_capacity: + final_map = capacity_mask + final_probs = topk_masked_gates * final_map + else: + # Get exceed mask and maskout exceeded probs and indices + final_map = torch.logical_and(topk_map, capacity_mask) + final_probs = topk_masked_gates * final_map + return final_probs, final_map, tokens_per_expert, top_indices + + +def get_capacity(num_tokens: int, + num_experts: int, + capacity_factor: float, + min_capacity=None): + """ + Calculate the capacity of each expert. + + Args: + num_tokens (int): num of the input tokens. + num_experts (int): num of the experts. + capacity_factor (float): Capacity factor. + min_capacity (int, optional): Minimum capacity. Defaults to None. + + Returns: + Tensor: Capacity of each expert. 
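+
+    Example:
+        num_tokens=1024, num_experts=64 and capacity_factor=1.25 give
+        ceil((1024 / 64) * 1.25) = 20 slots per expert.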
+ """ + capacity = math.ceil((num_tokens / num_experts) * capacity_factor) + if min_capacity is not None and capacity < min_capacity: + capacity = min_capacity + return capacity + + +def permute( + tokens, + routing_map, + num_out_tokens: Optional[int] = None, + fused: bool = False, + drop_and_pad: bool = False, +): + """Permute the tokens and probs based on the mask. + Tokens with the same designated expert will be grouped together. + The shape of mask is [tokens, num_experts], it indicates which experts were selected + by each token. + + When drop_and_pad=True, in routing_map, the number of non-zeros in each column equals to + expert capacity. This function exploits this feature to use ops that support cuda graph. + + Args: + tokens (torch.Tensor): The input token tensor, [num_tokens, hidden]. + routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts]. + num_out_tokens (int, optional): The number of output tokens. If None, it's set to + the number of input tokens. + fused (bool, optional): Whether use the fused permute function. + drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop + and pads the number of tokens to the expert capacity. + If set to true, routing_map has a fixed number of non-zeros + in each column. + """ + + num_tokens, hidden = tokens.shape + num_experts = routing_map.shape[1] + if drop_and_pad and (num_out_tokens is not None): + capacity = num_out_tokens // num_experts + assert not routing_map.requires_grad + # mask [num_tokens, num_experts] -> [num_experts, num_tokens] + routing_map = routing_map.to(dtype=torch.int8).T.contiguous() + # use argsort to put indices of all non-zeros in the beginning of list + # and keep the first `capacity` number of indices + sorted_indices = routing_map.argsort( + dim=-1, descending=True, stable=True)[:, :capacity].contiguous() + # flatten from [num_experts, capacity] to 1D + sorted_indices = sorted_indices.view(-1) + else: + # mask [num_tokens, num_experts] -> [num_experts, num_tokens] + routing_map = routing_map.bool().T.contiguous() + + # Create a dense expert-to-token mapping from the sparse token-to-expert mapping + token_indices = (torch.arange( + num_tokens, + device=routing_map.device).unsqueeze(0).expand(num_experts, -1)) + sorted_indices = token_indices.masked_select(routing_map) + + # use the mapping to permute the tokens + permuted_input = tokens.index_select(0, sorted_indices) + + return permuted_input, sorted_indices + + +def unpermute( + permuted_tokens: torch.Tensor, + sorted_indices: torch.Tensor, + restore_shape: torch.Size, + probs: torch.Tensor = None, + routing_map: torch.Tensor = None, + fused: bool = False, + drop_and_pad: bool = False, +): + """ + Restore the original order of tokens after permutation. If probs are provided, it + will also apply them to the tokens before restoring the order. + + When drop_and_pad=True, the tensors will have the following properties: + - In routing_map, the number of non-zeros in each column equals to expert capacity + - The size of sorted_indices equals to num_experts * capacity, each split of `capacity` + contains the indices of tokens routed to an expert. + This function exploits these features to use ops that support cuda graph. + + Args: + permuted_tokens (torch.Tensor): The permuted token tensor. + sorted_indices (torch.Tensor): The indices used to sort the tokens. + restore_shape (torch.Size): The shape of the unpermuted tensor. 
+ probs (torch.Tensor, optional): The unpermuted probs tensor, + routing_map (torch.Tensor, optional): Token to expert mapping, shape + [num_tokens, num_experts]. + fused (bool, optional): Whether use the fused unpermute function. + drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop + and pads the number of tokens to the expert capacity. + + Returns: + torch.Tensor: The tokens restored to their original order. + """ + + _, hidden = restore_shape + input_dtype = permuted_tokens.dtype + + if probs is not None: + assert routing_map is not None, "Mask must be provided to permute the probs." + if drop_and_pad: + num_experts = routing_map.size(1) + num_permuted_tokens = sorted_indices.size(0) + capacity = num_permuted_tokens // num_experts + num_unpermuted_tokens = probs.size(0) + + # [num_unpermuted_tokens, num_experts] -> num_experts * num_unpermuted_tokens + probs_T_1D = probs.T.contiguous().view(-1) + + # get 1D indices of the probs selected by routing_map + indices_dim0 = torch.arange( + num_experts, device=routing_map.device).unsqueeze(-1) + indices_dim1 = sorted_indices.view(num_experts, capacity) + indices_1D = (indices_dim0 * num_unpermuted_tokens + + indices_dim1).view(-1) + + # get probs from indices + permuted_probs = probs_T_1D.index_select(0, indices_1D) + else: + permuted_probs = probs.T.contiguous().masked_select( + routing_map.T.contiguous()) + # Here may promote permuted_tokens to higher precision (fp32/fp64) if probs is in + # higher precision due to moe_router_dtype being enabled. This can lead to + # additional GPU memory usage. Use --moe-permute-fusion flag to avoid this extra memory + # allocation. + permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1) + + # Create an output tensor filled with zeros + output_tokens = torch.zeros(restore_shape, + dtype=permuted_tokens.dtype, + device=permuted_tokens.device) + # Scatter add the permuted_input back to the original positions + output_tokens.scatter_add_(0, + sorted_indices.unsqueeze(1).expand(-1, hidden), + permuted_tokens) + return output_tokens.to(dtype=input_dtype) + + +def sort_chunks_by_idxs(input: torch.Tensor, + split_sizes: torch.Tensor, + sorted_idxs: torch.Tensor, + fused: bool = False): + """Split and sort the input tensor based on the split_sizes and sorted indices.""" + if input.shape[0] == 0: + return input + + input = torch.split(input, split_sizes.tolist(), dim=0) + output = torch.cat([input[i] for i in sorted_idxs.tolist()], dim=0) + return output diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py new file mode 100644 index 0000000000..60dd4f1be2 --- /dev/null +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -0,0 +1,696 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +import torch_npu + +from vllm.distributed.parallel_state import get_ep_group +from vllm_ascend.distributed.tensor_parallel import ( + all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp, + all_to_all_sp2hp, gather_from_sequence_parallel_region, + reduce_scatter_last_dim_to_tensor_parallel_region) +from vllm_ascend.ops.comm_utils import async_all_to_all +from vllm_ascend.ops.moe_dispatcher.moe_utils import ( + get_capacity, permute, sort_chunks_by_idxs, topk_softmax_with_capacity, + unpermute) + +""" We use the following notation throughout this file: + H: hidden size + B: micro batch size + S: sequence length + TP: tensor model parallel size + EP: expert model parallel size + num_local_tokens: S/TP*B + num_global_tokens: num_local_tokens*TP*EP +""" + + +class MoeDispatcherConfig: + + def __init__(self): + self.num_local_experts: int = 0 + self.num_moe_experts: int = 0 + self.moe_pad_expert_input_to_capacity: bool = False + self.moe_expert_capacity_factor: Optional[float] = None + self.moe_router_topk: int = 2 + self.moe_grouped_gemm: bool = False + self.group_topk: int = 0 + self.num_groups: int = 1 + self.expert_bias: torch.Tensor = None + self.scaling_factor: Optional[float] = None + self.is_fused: bool = True + + def set_num_local_experts(self, num_local_experts): + self.num_local_experts = num_local_experts + return self + + def set_num_moe_experts(self, num_moe_experts): + self.num_moe_experts = num_moe_experts + return self + + def set_moe_pad_expert_input_to_capacity(self, + moe_pad_expert_input_to_capacity): + self.moe_pad_expert_input_to_capacity = moe_pad_expert_input_to_capacity + return self + + def set_moe_expert_capacity_factor(self, moe_expert_capacity_factor): + self.moe_expert_capacity_factor = moe_expert_capacity_factor + return self + + def set_moe_router_topk(self, moe_router_topk): + self.moe_router_topk = moe_router_topk + return self + + def set_moe_grouped_gemm(self, moe_grouped_gemm): + self.moe_grouped_gemm = moe_grouped_gemm + return self + + def set_group_topk(self, group_topk): + self.group_topk = group_topk + return self + + def set_num_groups(self, num_groups): + self.num_groups = num_groups + return self + + def set_expert_bias(self, expert_bias): + self.expert_bias = expert_bias + return self + + def set_scaling_factor(self, scaling_factor): + self.scaling_factor = scaling_factor + return self + + def set_is_fused(self, is_fused): + self.is_fused = is_fused + return self + + def build(self): + return self + + +class MoEDispatcher: + + def __init__(self, config: MoeDispatcherConfig) -> None: + """ + Initialize the MoE Token Dispatcher. 
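+
+        Args:
+            config (MoeDispatcherConfig): Dispatcher configuration produced by
+                MoeDispatcherConfig.build().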
+ """ + self.config = config + self.shared_experts = None + + def set_shared_experts(self, shared_experts): + self.shared_experts = shared_experts + + @property + def ep_group(self): + """Get expert model parallel group.""" + return get_ep_group().device_group + + @property + def ep_rank(self): + return get_ep_group().rank_in_group + + @property + def ep_size(self): + return get_ep_group().world_size + + @property + def tp_ep_group(self): + """Get expert tensor and model parallel group.""" + return None + + @property + def tp_ep_size(self): + return 1 + + +class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher): + overlap_stream = None + + """ + The implementation of the AlltoAll-based token dispatcher, which handles token + dispatching on the sequence level instead of token level. The core of this implementation + lies in each device dispatching on the entire sequence, with the hidden state being partitioned. + + """ + + def __init__(self, config: MoeDispatcherConfig): + """ + Initialize the AlltoAllSeq token dispatcher. + + Args: + config (MoeDispatcherConfig): Configuration for the transformer model. + """ + super().__init__(config) + self.num_local_experts = config.num_local_experts + self.config = config + # use MOEAlltoAllSEQTokenDispatcher to init + + self.hidden_shape = None + self.num_input_tokens = None + self.num_experts = config.num_moe_experts + assert self.num_local_experts > 0, "Expected at least one expert" + if self.num_local_experts > 1: + self.expert_ids_per_ep_rank = torch.tensor( + [i % self.num_local_experts for i in range(self.num_experts)], + dtype=torch.int32, + device=torch.npu.current_device(), + ) + + local_expert_indices_offset = ( + self.ep_rank * self.num_local_experts + ) + + self.local_expert_indices = [ + local_expert_indices_offset + i for i in range(self.num_local_experts) + ] + assert ( + len(self.local_expert_indices) == self.num_local_experts + ), "Invalid local expert indices" + for i in range(len(self.local_expert_indices) - 1): + assert ( + self.local_expert_indices[i] == self.local_expert_indices[i + 1] - 1 + ), "local_expert_indices must be continous" + self.probs = None + self.input_splits = None + self.output_splits = None + self.routing_map = None + self.hidden_shape_before_permute = None + + # [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent + # to each local expert by all ranks. + self.num_global_tokens_per_local_expert_cpu = None + self.num_global_tokens_per_local_expert = None + input_chunk_idxs = torch.arange(self.num_experts) + # [num_local_experts, ep_size]. Sort the input chunks by local experts. + self.sort_input_by_local_experts = input_chunk_idxs.reshape( + -1, self.num_local_experts + ).T.ravel() + # [ep_size, num_local_experts]. Restore the output chunks by local experts. + self.restore_output_by_local_experts = input_chunk_idxs.reshape( + self.num_local_experts, -1 + ).T.ravel().to(torch.device("cpu"), non_blocking=True) + + # Token drop and padding. + # We need to keep track of the token num if we drop tokens without padding them. + self.num_out_tokens = None + # Drop and pad the input to capacity. + self.drop_and_pad = self.config.moe_pad_expert_input_to_capacity + if self.drop_and_pad: + assert self.config.moe_expert_capacity_factor is not None + self.capacity = None + + # A cuda stream synchronization is needed in self.token_permutation() + # in some cases, because there are several non-blocking DtoH data + # transfers called in self.preprocess(). 
The synchronization happens + # at different points based on MoE settings as late as possible. + # Valid sync points are "before_permutation_1", "before_ep_alltoall", + # "before_finish", and "no_sync". + self.cuda_sync_point = "no_sync" + + # cached intermediate tensors. + self.cached_permutated_local_input_tokens = None + self.cached_global_input_tokens = None + self.cached_shared_expert_output = None + self.tokens_per_expert = None + + if MoEAlltoAllSeqOverLapDispatcher.overlap_stream is None: + MoEAlltoAllSeqOverLapDispatcher.overlap_stream = torch.npu.Stream() + + self.overlap_stream = MoEAlltoAllSeqOverLapDispatcher.overlap_stream + + def preprocess(self, indices: torch.Tensor, with_sync=True) -> torch.Tensor: + """ + Preprocess routing map for AlltoAll communication and token permutation. + This method computes the number of tokens assigned to each expert based on + the routing map. It also initializes the necessary data structures for + AlltoAll communication, such as input and output splits, and the mapping + between global tokens and local experts. + + Args: + routing_map (torch.Tensor): The mapping of tokens to experts, with shape + [num_tokens, num_experts]. + + Returns: + torch.Tensor: Tensor containing the number of tokens assigned to local expert. + """ + num_local_tokens_per_expert = torch.histc( + indices, bins=self.num_experts, min=0, max=self.num_experts + ) + + # num_local_tokens_per_expert: [num_experts] + + ep_size = self.ep_size + if self.drop_and_pad: + # Drop and pad the input to capacity. + num_tokens = indices.numel() + self.capacity = get_capacity( + num_tokens=num_tokens, + num_experts=self.num_experts, + capacity_factor=self.config.moe_expert_capacity_factor, + ) + self.num_out_tokens = self.capacity * self.num_experts + num_tokens_per_local_expert = torch.full( + (self.num_local_experts,), self.capacity * self.ep_size, dtype=torch.long + ) + self.num_global_tokens_per_local_expert_cpu = torch.full( + (self.num_experts * self.tp_ep_size,), self.capacity, dtype=torch.long + ) + return num_tokens_per_local_expert + elif self.config.moe_expert_capacity_factor is not None: + # Token drop but no pad. A synchronization is needed before the first + # permutation to get the `num_out_tokens` CPU value. + self.num_out_tokens = num_local_tokens_per_expert.sum().to( + torch.device("cpu"), non_blocking=True + ) + self.cuda_sync_point = "before_permutation_1" + else: + # Dropless + self.num_out_tokens = indices.numel() + if self.ep_size > 1 or self.num_local_experts > 1: + # Token dropless and enable ep. A synchronization is needed before expert parallel + # AlltoAll communication to get the `input_splits` and `output_splits` CPU values. + self.cuda_sync_point = "before_ep_alltoall" + else: + # Token dropless and no ep. A synchronization is needed to get the + # `tokens_per_expert` CPU value. + self.cuda_sync_point = "before_finish" + + if ep_size > 1: + # =================================================== + # Calculate input_splits, output_splits for alltoall-v. 
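+            # input_splits[i]:  number of local tokens this rank sends to EP rank i
+            # output_splits[i]: number of tokens this rank receives from EP rank i
+            #                   for its local experts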
+ # =================================================== + self.input_splits = ( + num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts) + .sum(axis=1) + .to(torch.device("cpu"), non_blocking=True) + .numpy() + ) + num_global_tokens_per_expert = gather_from_sequence_parallel_region( + num_local_tokens_per_expert, group=self.ep_group + ).reshape(ep_size, self.num_experts) + self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[0]: + self.local_expert_indices[-1] + 1] + self.output_splits = ( + self.num_global_tokens_per_local_expert.sum(axis=-1) + .to(torch.device("cpu"), non_blocking=True) + .numpy() + ) + num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(axis=0) + # =================================================== + # num_global_tokens_per_expert: [ep_size, num_experts] + # num_global_tokens_per_local_expert: [ep_size, num_local_experts] + # num_tokens_per_local_expert: [num_local_experts] + # =================================================== + else: + self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape( + -1, self.num_experts + ) + num_tokens_per_local_expert = num_local_tokens_per_expert + + if self.num_local_experts > 1 and with_sync: + self.cuda_sync_point = "no_sync" + self.global_input_tokens_local_experts_indices = torch.repeat_interleave( + self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel() + ) + + # self.num_global_tokens_per_local_expert_cpu = ( + # self.num_global_tokens_per_local_expert.view(-1, self.num_local_experts).to( + # torch.device("cpu"), non_blocking=True + # ) + # ) + # if not hasattr(self, "comm_stream"): + # self.comm_stream = torch.npu.Stream() + # self.comm_stream.wait_stream(torch.npu.current_stream()) + + return num_tokens_per_local_expert + + def routing(self, probs): + seq_length, bsz = probs.shape[:2] + probs = probs.view(-1, self.config.num_moe_experts) + + scores, routing_map, _, top_indices = topk_softmax_with_capacity( + probs, + self.config.moe_router_topk, + capacity_factor=self.config.moe_expert_capacity_factor, + pad_to_capacity=self.config.moe_pad_expert_input_to_capacity, + group_topk=self.config.group_topk, + num_groups=self.config.num_groups, + expert_bias=self.config.expert_bias, + scaling_factor=self.config.scaling_factor + ) + self.top_indices = top_indices + return scores, routing_map + + def preprocess_overlap(self, routing_map): + num_tokens_per_local_expert = self.preprocess(routing_map) + self.num_global_tokens_per_local_expert = self.num_global_tokens_per_local_expert + self.input_splits = self.input_splits + self.output_splits = self.output_splits + self.num_out_tokens = self.num_out_tokens + self.num_global_tokens_per_local_expert_cpu = self.num_global_tokens_per_local_expert_cpu + return num_tokens_per_local_expert + + def token_permutation( + self, + hidden_states: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + ): + """ + Dispatch tokens to local experts using AlltoAllSeq communication. + + Args: + hidden_states (torch.Tensor): Input token embeddings. + probs (torch.Tensor): Probs of tokens assigned to experts. + Shape: [num_tokens, num_experts]. + routing_map (torch.Tensor): Mapping of tokens assigned to experts. + Shape: [num_tokens, num_experts]. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - Permuted token embeddings for local experts. + - Number of tokens per expert. 
+ """ + self.hidden_shape = hidden_states.shape + self.probs = probs + self.routing_map = routing_map + self.top_indices = routing_map + assert probs.dim() == 2, "Expected 2D tensor for probs" + assert routing_map.dim() == 2, "Expected 2D tensor for routing map" + + # Permutation 1: input to AlltoAll input + def alltoall_token_permutation1(hidden_states, routing_map): + hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) + tokens_per_expert = self.preprocess(routing_map) + if self.tp_ep_size > 1: + hidden_states = all_to_all_sp2hp(hidden_states, group=self.tp_ep_group) + self.hidden_shape_before_permute = hidden_states.shape + + if self.cuda_sync_point == "before_permutation_1": + torch.npu.current_stream().synchronize() + if not self.config.is_fused: + permutated_local_input_tokens, reversed_local_input_permutation_mapping = permute( + hidden_states, + routing_map, + num_out_tokens=self.num_out_tokens, + ) + else: + permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( + tokens=hidden_states, + indices=self.top_indices, + num_out_tokens=self.num_out_tokens, + ) + return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert + + permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = alltoall_token_permutation1( + hidden_states, routing_map) + self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping + # permute 1 + + ep_group = self.ep_group + + # Perform expert parallel AlltoAll communication + if self.cuda_sync_point == "before_ep_alltoall": + torch.npu.current_stream().synchronize() + _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all( + permutated_local_input_tokens, + self.output_splits, + self.input_splits, + ep_group, + ) + + # shared experts compute + if self.shared_experts is not None: + (share_experts_output), *_ = self.shared_experts(hidden_states) + else: + share_experts_output = None + + permute1_ep_all_to_all_handle.wait() + permutated_local_input_tokens.untyped_storage().resize_(0) + + def alltoall_token_permutation2(global_input_tokens): + # Permutation 2: Sort tokens by local expert. + if self.num_local_experts > 1: + global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( + global_input_tokens, + self.global_input_tokens_local_experts_indices + ) + + # Perform tensor parallel AllGather on the hidden dimension to obtain the input tokens. 
+ # global_input_tokens: [SEQL, H/TP] -> [SEQL, H] + if self.tp_ep_size > 1 and self.config.moe_grouped_gemm: + global_input_tokens = all_gather_last_dim_from_tensor_parallel_region( + global_input_tokens, self.tp_ep_group + ) + if self.cuda_sync_point == "before_finish": + torch.npu.current_stream().synchronize() + + return global_input_tokens + + # token premute2 input + global_input_tokens = alltoall_token_permutation2(global_input_tokens) + + return share_experts_output, global_input_tokens, tokens_per_expert + + def preprocess_and_permtute1( + self, + hidden_states: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + shared_experts=None, + shared_experts_input: torch.Tensor = None + ): + self.hidden_shape = hidden_states.shape + self.probs = probs + self.top_indices = routing_map + assert probs.dim() == 2, "Expected 2D tensor for probs" + assert routing_map.dim() == 2, "Expected 2D tensor for routing map" + + hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) + tokens_per_expert = self.preprocess(routing_map, with_sync=False) + self.hidden_shape_before_permute = hidden_states.shape + + if self.cuda_sync_point == "before_permutation_1": + torch.npu.current_stream().synchronize() + + event = torch.npu.current_stream().record_event() + self.perm1_finish_event = torch.npu.Event() + with torch.npu.stream(self.overlap_stream): + self.overlap_stream.wait_event(event) + + if shared_experts is not None: + shared_output = shared_experts(shared_experts_input) + self.cached_shared_expert_output = shared_output + + if not self.config.is_fused: + hidden_states, self.reversed_local_input_permutation_mapping = permute( + hidden_states, + routing_map, + num_out_tokens=self.num_out_tokens, + ) + else: + hidden_states, self.reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( + tokens=hidden_states, + indices=self.top_indices, + num_out_tokens=self.num_out_tokens, + ) + + self.perm1_finish_event.record() + + # repeat interleve will launch a sync on current_stream. 
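+        # (with a tensor `repeats` argument the total output length is only known
+        # after a device-to-host copy, which is what forces the synchronization
+        # noted above)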
+ if self.num_local_experts > 1: + self.cuda_sync_point = "no_sync" + self.global_input_tokens_local_experts_indices = torch.repeat_interleave( + self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel() + ) + + self.cached_permutated_local_input_tokens = hidden_states + self.tokens_per_expert = tokens_per_expert + + def dispatch_alltoall(self): + ep_group = self.ep_group + + # Perform expert parallel AlltoAll communication + if self.cuda_sync_point == "before_ep_alltoall": + torch.npu.current_stream().synchronize() + + torch.npu.current_stream().wait_event(self.perm1_finish_event) + self.perm1_finish_event = None + _, self.cached_global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all( + self.cached_permutated_local_input_tokens, + self.output_splits, + self.input_splits, + ep_group, + ) + permute1_ep_all_to_all_handle.wait() + self.cached_permutated_local_input_tokens.untyped_storage().resize_(0) + self.cached_permutated_local_input_tokens = None + + def permute2(self): + global_input_tokens = self.cached_global_input_tokens + if self.num_local_experts > 1: + global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( + self.cached_global_input_tokens, + self.global_input_tokens_local_experts_indices + ) + self.cached_global_input_tokens.untyped_storage().resize_(0) + self.cached_global_input_tokens = None + + return global_input_tokens, self.tokens_per_expert + + def unpermute1( + self, + hidden_states: torch.Tensor + ): + # Unpermutation 2: expert output to AlltoAll input + if hidden_states.shape[0] > 0 and self.num_local_experts > 1: + hidden_states = torch_npu.npu_moe_token_unpermute( + hidden_states, + self.reversed_global_input_permutation_mapping + ) + self.cached_global_output_tokens = hidden_states + self.reversed_global_input_permutation_mapping = None + + def combine_alltoall(self): + ep_group = self.ep_group + # Perform expert parallel AlltoAll communication + # hidden_states: [SEQL, H] -> [SEQL, H/TP] + _, self.cached_local_output_tokens, handle = async_all_to_all( + self.cached_global_output_tokens, + self.input_splits, + self.output_splits, + ep_group + ) + handle.wait() + self.cached_global_output_tokens.untyped_storage().resize_(0) + self.cached_global_output_tokens = None + self.input_splits = None + self.output_splits = None + + def unpermute2(self): + output = torch_npu.npu_moe_token_unpermute( + permuted_tokens=self.cached_local_output_tokens, + sorted_indices=self.reversed_local_input_permutation_mapping.to(torch.int32), + probs=self.probs, + restore_shape=self.hidden_shape_before_permute + ) + + output = output.view(self.hidden_shape) + + self.probs = None + self.reversed_local_input_permutation_mapping = None + self.cached_local_output_tokens.untyped_storage().resize_(0) + self.cached_local_output_tokens = None + + return output + + def token_unpermutation( + self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None + ): + """ + Reverse the token permutation to restore the original order. + + Args: + hidden_states (torch.Tensor): Output from local experts. + bias (torch.Tensor, optional): Bias tensor (not supported). + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - Unpermuted token embeddings in the original order. + - None (bias is not supported). 
+ """ + + def alltoall_token_unpermutation1(hidden_states): + assert bias is None, "Bias is not supported in MoEAlltoAllSeqTokenDispatcher" + # Perform tensor parallel Reduce-Scatter + # hidden_states: [SEQL, H] -> [SEQL, H/TP] + if self.tp_ep_size > 1: + hidden_states = reduce_scatter_last_dim_to_tensor_parallel_region(hidden_states, group=self.tp_ep_group) + + # Unpermutation 2: expert output to AlltoAll input + if hidden_states.shape[0] > 0 and self.num_local_experts > 1: + hidden_states = torch_npu.npu_moe_token_unpermute( + hidden_states, + self.reversed_global_input_permutation_mapping + ) + # hidden_states = sort_chunks_by_idxs( + # hidden_states, + # self.num_global_tokens_per_local_expert_cpu.T.ravel(), + # self.restore_output_by_local_experts, + # ) + return hidden_states + + hidden_states = alltoall_token_unpermutation1(hidden_states) + + ep_group = self.ep_group + # Perform expert parallel AlltoAll communication + # hidden_states: [SEQL, H] -> [SEQL, H/TP] + _, permutated_local_input_tokens, handle = async_all_to_all( + hidden_states, + self.input_splits, + self.output_splits, + ep_group + ) + handle.wait() + hidden_states.untyped_storage().resize_(0) + + def alltoall_token_unpermutation2(permutated_local_input_tokens): + # Unpermutation 1: AlltoAll output to output + if self.config.is_fused: + # permuted_probs = (self.probs.T.contiguous().masked_select(self.routing_map.T.contiguous()) + # .view(-1, self.config.moe_router_topk)) + output = torch_npu.npu_moe_token_unpermute( + permuted_tokens=permutated_local_input_tokens, + sorted_indices=self.reversed_local_input_permutation_mapping.to(torch.int32), + probs=self.probs, + restore_shape=self.hidden_shape_before_permute + ) + else: + output = unpermute( + permutated_local_input_tokens, + self.reversed_local_input_permutation_mapping, + probs=self.probs, + restore_shape=self.hidden_shape_before_permute, + routing_map=self.routing_map, + ) + + # Perform tensor parallel AlltoAll communication + # output: [S*B, H/TP] -> [S*B/TP, H] + if self.tp_ep_size > 1: + output = all_to_all_hp2sp(output, self.tp_ep_group) + + # Reshape the output tensor + output = output.view(self.hidden_shape) + return output + + output = alltoall_token_unpermutation2(permutated_local_input_tokens) + + self.input_splits = None + self.output_splits = None + self.num_global_tokens_per_local_expert = None + self.num_global_tokens_per_local_expert_cpu = None + + return output, None From 6d7b5b4a098c7ac5e0db4818666c0acf4911dbd2 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 1 Jul 2025 09:58:42 +0800 Subject: [PATCH 02/60] [Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo Signed-off-by: weijinqian_v1 --- vllm_ascend/multistream/ms_split.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index a41bb7bccb..6ca0a4b436 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -247,7 +247,7 @@ def model_input_split_v1_mla_attn( return [attention_metadata_pre, attention_metadata_post] - def model_input_split_v1_attn( +def model_input_split_v1_attn( attn_metadata: AscendMetadata, _metadata_cls, ms_split_config: MSAttentionMetadataSplitConfig, From 6a8e1a962ae685518cf54d3b77fbfa4f76c94a01 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 1 Jul 2025 10:13:39 +0800 Subject: [PATCH 03/60] [Feature]Moe alltoallv communication optimization for unquantized RL training sence & 
alltoallv support dpo Signed-off-by: weijinqian_v1 --- vllm_ascend/models/__init__.py | 6 +-- vllm_ascend/models/qwen3_dbo.py | 65 +++++++++++++++++++-------------- 2 files changed, 40 insertions(+), 31 deletions(-) diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index c0e8c5be54..abf531d370 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -53,6 +53,6 @@ def register_model(): "DeepseekV3ForCausalLM", "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM") - ModelRegistry.register_model( - "Qwen3MoeForCausalLM", - "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM") + ModelRegistry.register_model( + "Qwen3MoeForCausalLM", + "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM") diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py index 042f4dc400..52054a3e98 100644 --- a/vllm_ascend/models/qwen3_dbo.py +++ b/vllm_ascend/models/qwen3_dbo.py @@ -1,30 +1,3 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# # Adapted from -# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py -# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py -# """Inference-only DeepseekV2/DeepseekV3 model.""" - from collections.abc import Iterable from typing import Any, Optional, Union, List from types import SimpleNamespace @@ -42,9 +15,12 @@ from vllm.distributed import tensor_model_parallel_all_reduce, get_tensor_model_parallel_world_size, get_tp_group, \ get_pp_group from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.models.utils import (make_empty_intermediate_tensors_factory, make_layers) +from vllm.model_executor.models.utils import (make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.sequence import IntermediateTensors +from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm_ascend.multistream.context import ( advance_step_multistream_layer_context, get_multistream_comm_context, @@ -509,3 +485,36 @@ def _forward_ms_layers( return hidden_states, residual +class CustomQwen3MoeForCausalLMDBO(Qwen3MoeForCausalLM): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = CustomQwen3DBOMoEModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + From 4805c5ad8e5338406c5994c0076ecfbcccf0f622 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 1 Jul 2025 10:23:53 +0800 Subject: [PATCH 04/60] [Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo Signed-off-by: weijinqian_v1 --- vllm_ascend/models/qwen3_dbo.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py index 52054a3e98..ce6189d17f 100644 --- a/vllm_ascend/models/qwen3_dbo.py +++ b/vllm_ascend/models/qwen3_dbo.py @@ -401,6 +401,7 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + graph_enable: Optional[bool] = True ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -517,4 +518,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: 
Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + graph_enable: Optional[bool] = True + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, graph_enable) + return hidden_states + From d68ce0761ff05282e3a7eff9d983144af214dc0b Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 1 Jul 2025 10:29:47 +0800 Subject: [PATCH 05/60] [Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo Signed-off-by: weijinqian_v1 --- vllm_ascend/models/qwen3_dbo.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py index ce6189d17f..de449444c7 100644 --- a/vllm_ascend/models/qwen3_dbo.py +++ b/vllm_ascend/models/qwen3_dbo.py @@ -401,7 +401,6 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - graph_enable: Optional[bool] = True ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -527,7 +526,7 @@ def forward( graph_enable: Optional[bool] = True ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds, graph_enable) + inputs_embeds) return hidden_states From 0aff6937f739582ca3605fa4ddffb3215f484c1d Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 1 Jul 2025 11:13:56 +0800 Subject: [PATCH 06/60] [Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo Signed-off-by: weijinqian_v1 --- tests/singlecard/test_offline_inference.py | 21 ++ vllm_ascend/models/qwen3_dbo.py | 2 + vllm_ascend/ops/fused_moe.py | 270 ++++++++++----------- 3 files changed, 152 insertions(+), 141 deletions(-) diff --git a/tests/singlecard/test_offline_inference.py b/tests/singlecard/test_offline_inference.py index cd65a24969..de958ea7e6 100644 --- a/tests/singlecard/test_offline_inference.py +++ b/tests/singlecard/test_offline_inference.py @@ -131,3 +131,24 @@ def test_models_topk() -> None: enforce_eager=True, gpu_memory_utilization=0.7) as vllm_model: vllm_model.generate(example_prompts, sampling_params) + + +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MOE_ALL2ALLV": "1", "VLLM_ASCEND_ENABLE_DBO": "1"}) +def test_models_topk() -> None: + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = SamplingParams(max_tokens=5, + temperature=0.0, + top_k=50, + top_p=0.9) + + with VllmRunner("Qwen/Qwen2.5-0.5B-Instruct", + max_model_len=8192, + dtype="float16", + enforce_eager=True, + gpu_memory_utilization=0.7) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py index de449444c7..e6eeba78b0 100644 --- a/vllm_ascend/models/qwen3_dbo.py +++ b/vllm_ascend/models/qwen3_dbo.py @@ -6,6 +6,7 @@ import torch_npu from torch import nn from transformers import PretrainedConfig +from vllm.compilation.decorators import support_torch_compile from vllm.model_executor.models.qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeModel from vllm.config import CacheConfig, VllmConfig @@ -352,6 +353,7 @@ def discard_tensor(tensor): return hidden_states, residual +@support_torch_compile class 
CustomQwen3DBOMoEModel(Qwen3MoeModel): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 4ef1b1030a..115e1c880a 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -60,11 +60,11 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, if original_total_elements == 0: output_len = ep_size * max_row_per_ep_rank - topk_ids_pad = torch.full((output_len,), + topk_ids_pad = torch.full((output_len, ), expert_num, dtype=original_dtype, device=device) - unpad_indices = torch.full((original_total_elements,), + unpad_indices = torch.full((original_total_elements, ), -1, dtype=torch.long, device=device) @@ -95,13 +95,13 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, is_kept_mask, indices_in_rec_cond_list_for_all, torch.tensor(-1, device=device, dtype=torch.long)) output_len = ep_size * max_row_per_ep_rank - topk_ids_pad = torch.full((output_len,), + topk_ids_pad = torch.full((output_len, ), expert_num, dtype=original_dtype, device=device) if topk_ids.shape[0] > 0: all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx - temp_pad_buffer = torch.full((output_len + 1,), + temp_pad_buffer = torch.full((output_len + 1, ), expert_num, dtype=original_dtype, device=device) @@ -116,16 +116,16 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, def fused_experts_with_mc2( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - moe_all_to_all_group_name: Optional[str] = None, - shared_experts: Optional[Any] = None, - is_torchair: bool = False, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + moe_all_to_all_group_name: Optional[str] = None, + shared_experts: Optional[Any] = None, + is_torchair: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: quant_mode = 0 ep_group = get_ep_group() @@ -171,7 +171,7 @@ def fused_experts_with_mc2( output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) # comm_stream.wait_stream(torch.npu.current_stream()) expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[ - 0:5] + 0:5] if shared_experts is not None: with npu_stream_switch("moe_secondary", 0): @@ -303,14 +303,14 @@ def apply_mlp(hidden_states: torch.Tensor, # currently expert parallelism implemented with all2all # is under-optimized. def fused_experts_with_all2all( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - ep_group: GroupCoordinator = None, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, ): original_shape = hidden_states.shape if len(original_shape) == 3: @@ -440,16 +440,16 @@ def fused_experts_with_all2all( # currently expert parallelism implemented with all2all # is under-optimized. 
def fused_experts_with_all2all_buffer( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - max_model_len: int, - global_batch_size: int, - expert_map: torch.Tensor = None, - ep_group: GroupCoordinator = None, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + max_model_len: int, + global_batch_size: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, ): original_shape = hidden_states.shape if len(original_shape) == 3: @@ -470,8 +470,7 @@ def fused_experts_with_all2all_buffer( expert_idx=topk_ids, active_num=num_tokens) - max_row_per_ep_rank = ( - -(-global_batch_size // ep_group.world_size) * max_model_len * + max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) * max_model_len * get_dp_group().world_size // ep_group.world_size + 1) * top_k * 2 expert_idx_buffer_scatter, unpad_indices = process_topk_ids( expanded_expert_idx, global_num_experts, ep_group.world_size, @@ -505,7 +504,7 @@ def fused_experts_with_all2all_buffer( group=ep_group.device_group) mask = expert_idx_buffer_gather != global_num_experts local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( - global_num_experts // ep_group.world_size) + global_num_experts // ep_group.world_size) hidden_states = hidden_states_buffer_gather[mask] idx_type = local_expert_idx.dtype sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float()) @@ -580,14 +579,14 @@ def fused_experts_with_all2allv(token_dispatcher, probs, routing_map, hidden_sta def fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - apply_router_weight_on_input: bool = False, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + apply_router_weight_on_input: bool = False, ) -> torch.Tensor: """ Fused experts with top-k routing. 
@@ -631,7 +630,7 @@ def fused_experts( ), "`topk_weights` should be in shape (num_tokens, topk)" _, topk = topk_weights.shape assert ( - topk == 1 + topk == 1 ), "Only support topk=1 when `apply_router_weight_on_input` is True" hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) @@ -640,7 +639,7 @@ def fused_experts( token_indices = (torch.arange(num_tokens, device=device, dtype=torch.int64).unsqueeze(1).expand( - -1, top_k).reshape(-1)) + -1, top_k).reshape(-1)) # Flatten token-to-expert mappings and map to local experts weights_flat = topk_weights.view(-1) @@ -680,7 +679,7 @@ def fused_experts( row_idx_len, dtype=torch.int32, device=device).view(top_k, -1).permute( - 1, 0).contiguous()) + 1, 0).contiguous()) sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, @@ -754,9 +753,9 @@ def fused_experts( def native_grouped_topk( - topk_weights: torch.Tensor, - num_expert_group: Optional[int], - topk_group: Optional[int], + topk_weights: torch.Tensor, + num_expert_group: Optional[int], + topk_group: Optional[int], ): topk_group = 0 if topk_group is None else topk_group num_expert_group = 0 if num_expert_group is None else num_expert_group @@ -779,16 +778,16 @@ def native_grouped_topk( def select_experts( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Select top-k experts based on router logits. 
@@ -898,30 +897,30 @@ def process_weights_after_loading(self, layer): self).process_weights_after_loading(layer) layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight( layer.w13_weight.data), - requires_grad=False) + requires_grad=False) layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight( layer.w2_weight.data), - requires_grad=False) + requires_grad=False) def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - is_prefill: bool = False, - enable_force_load_balance: bool = False, - shared_experts: Optional[Any] = None, - **kwargs, + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + is_prefill: bool = False, + enable_force_load_balance: bool = False, + shared_experts: Optional[Any] = None, + **kwargs, ) -> torch.Tensor: # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern @@ -1015,32 +1014,33 @@ def apply( class AscendFusedMoE(FusedMoE): + # The moe_counter parameter is required during the initialization of EPLB # to identify the current layer index within the MOE model. 
moe_counter = -1 def __init__( - self, - num_experts: int, # Global number of experts - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = False, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, - ep_size: Optional[int] = None, - dp_size: Optional[int] = None, - prefix: str = "", - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", - apply_router_weight_on_input: bool = False, + self, + num_experts: int, # Global number of experts + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + ep_size: Optional[int] = None, + dp_size: Optional[int] = None, + prefix: str = "", + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, ): # TODO: This could not initialize FusedMoE baseclass, # fixme and make __init__() of AscendFusedMoE more clear @@ -1089,19 +1089,17 @@ def __init__( expert_load_balancer = ExpertLoadBalancer(expert_map_path, self.global_num_experts) self.local_num_experts, self.expert_map = \ - expert_load_balancer.get_rank_placement_map( - self.moe_instance_id, - get_ep_group().rank_in_group) + expert_load_balancer.get_rank_placement_map( + self.moe_instance_id, + self.ep_rank) self.log2phy = expert_load_balancer.get_rank_log2phy_map( - self.moe_instance_id, - get_ep_group().rank_in_group) + self.moe_instance_id, self.ep_rank) self.global_redundant_expert_num = \ - expert_load_balancer.get_global_redundant_expert_num() + expert_load_balancer.get_global_redundant_expert_num() else: # Create a tensor of size num_experts filled with -1 self.local_num_experts, self.expert_map = determine_expert_map( - self.ep_size, - get_ep_group().rank_in_group, self.global_num_experts) + self.ep_size, self.ep_rank, self.global_num_experts) self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_multistream_moe = \ @@ -1135,7 +1133,7 @@ def __init__( "num_experts": local_num_experts, "hidden_size": hidden_size, "intermediate_size_per_partition": - self.intermediate_size_per_partition, + self.intermediate_size_per_partition, "params_dtype": params_dtype, "weight_loader": self.weight_loader, } @@ -1170,8 +1168,7 @@ def forward(self, is_prefill: bool, enable_force_load_balance: bool = False, top_k: Optional[int] = None, - shared_experts: Optional[Any] = None, - replace_allreduce: bool = False): + shared_experts: Optional[Any] = None): assert self.quant_method is not None if top_k: @@ -1180,18 +1177,14 @@ def forward(self, real_top_k = self.top_k num_tokens, hidden_size = hidden_states.shape - is_deepseek_v3_r1 = self.global_num_experts == 256 - fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size, - is_prefill, is_deepseek_v3_r1) + fused_moe_state = get_forward_context().fused_moe_state if shared_experts: if not 
self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: shared_hidden_states = shared_experts(hidden_states) tp_size = get_tensor_model_parallel_world_size() - if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather - and fused_moe_state != FusedMoEState.AllGatherEP - and not replace_allreduce): + if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: if num_tokens < tp_size: hidden_states = nn.functional.pad( hidden_states, (0, 0, 0, tp_size - num_tokens)) @@ -1209,16 +1202,15 @@ def forward(self, if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: # NOTE: When in torchair graph, it has been padded in model_runner_v1 if not self.torchair_graph_enabled or is_prefill: - attn_metadata = get_forward_context().attn_metadata - if attn_metadata is not None: - max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp - if num_tokens < max_num_tokens_across_dp: - hidden_states = nn.functional.pad( - hidden_states, - (0, 0, 0, max_num_tokens_across_dp - num_tokens)) - router_logits = nn.functional.pad( - router_logits, - (0, 0, 0, max_num_tokens_across_dp - num_tokens)) + max_num_tokens_across_dp = get_forward_context( + ).max_tokens_across_dp + if num_tokens < max_num_tokens_across_dp: + hidden_states = nn.functional.pad( + hidden_states, + (0, 0, 0, max_num_tokens_across_dp - num_tokens)) + router_logits = nn.functional.pad( + router_logits, + (0, 0, 0, max_num_tokens_across_dp - num_tokens)) hidden_states = get_dp_group().all_gather(hidden_states, 0) router_logits = get_dp_group().all_gather(router_logits, 0) @@ -1242,17 +1234,14 @@ def forward(self, log2phy=self.log2phy, global_redundant_expert_num=self.global_redundant_expert_num, shared_experts=shared_experts if self.torchair_graph_enabled - and self.enable_multistream_moe and not is_prefill else None, - token_dispatcher=self.token_dispatcher + and self.enable_multistream_moe and not is_prefill else None, ) if shared_experts: if isinstance(e_hidden_states, tuple): e_hidden_states, shared_hidden_states = e_hidden_states - if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather - and fused_moe_state != FusedMoEState.AllGatherEP - and not replace_allreduce): + if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: dist.all_gather(list(chunk_hidden_states), e_hidden_states, self.tp_group) final_hidden_states = torch.cat(chunk_hidden_states, dim=0) @@ -1270,8 +1259,7 @@ def forward(self, else: final_hidden_states = e_hidden_states - if tp_size > 1 and (fused_moe_state == FusedMoEState.AllGather - or fused_moe_state == FusedMoEState.AllGatherEP): + if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) @@ -1283,12 +1271,12 @@ def forward(self, # ----------------------------------------- TBO-related -------------------------------------------- def _forward_ms_fused_moe_comp( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - is_prefill: bool, - real_top_k, - enable_force_load_balance: bool = False, + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_prefill: bool, + real_top_k, + enable_force_load_balance: bool = False, ): hidden_states = self.quant_method.apply( layer=self, From f6ab19ed624b9bdf91bfcfc6ba7691f8906a665b Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 1 Jul 2025 12:57:59 +0800 Subject: [PATCH 07/60] [Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo Signed-off-by: weijinqian_v1 
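
Usage note: this patch renames the toggle from VLLM_ASCEND_ENABLE_MOE_ALL2ALLV to
VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ. Below is a minimal sketch of how the flag pair is
expected to be exercised, mirroring the updated test in
tests/singlecard/test_offline_inference.py; the VllmRunner import path, test name, and
the model/sampling values are illustrative assumptions and are not introduced by this
patch:

    import os
    from unittest.mock import patch

    from vllm import SamplingParams

    from tests.conftest import VllmRunner  # assumed helper location


    @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1",
                             "VLLM_ASCEND_ENABLE_DBO": "1"})
    def test_all2all_seq_smoke() -> None:
        prompts = ["Hello, my name is"]
        sampling_params = SamplingParams(max_tokens=5, temperature=0.0)
        # The flags in vllm_ascend/envs.py are read lazily via os.getenv, so
        # patching the environment before the model is constructed should be
        # enough to route MoE through the all2all-seq dispatcher.
        with VllmRunner("Qwen/Qwen2.5-0.5B-Instruct",
                        max_model_len=8192,
                        dtype="float16",
                        enforce_eager=True,
                        gpu_memory_utilization=0.7) as vllm_model:
            vllm_model.generate(prompts, sampling_params)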
--- tests/singlecard/test_offline_inference.py | 2 +- vllm_ascend/envs.py | 8 ++++---- vllm_ascend/models/deepseek_dbo.py | 5 ++--- vllm_ascend/models/qwen3_dbo.py | 1 - vllm_ascend/ops/fused_moe.py | 6 ++---- 5 files changed, 9 insertions(+), 13 deletions(-) diff --git a/tests/singlecard/test_offline_inference.py b/tests/singlecard/test_offline_inference.py index de958ea7e6..8ca2a2fb9e 100644 --- a/tests/singlecard/test_offline_inference.py +++ b/tests/singlecard/test_offline_inference.py @@ -133,7 +133,7 @@ def test_models_topk() -> None: vllm_model.generate(example_prompts, sampling_params) -@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MOE_ALL2ALLV": "1", "VLLM_ASCEND_ENABLE_DBO": "1"}) +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1", "VLLM_ASCEND_ENABLE_DBO": "1"}) def test_models_topk() -> None: example_prompts = [ "Hello, my name is", diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index eefc78d647..8af5bdd783 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -137,11 +137,11 @@ # and the mla_pa will be the default path of deepseek decode path. "VLLM_ASCEND_MLA_PA": lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0)), - # VLLM_ASCEND_ENABLE_MOE_ALL2ALLV: + # VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: # 0: default, normal init. - # 1: enable moe all2allv. - "VLLM_ASCEND_ENABLE_MOE_ALL2ALLV": - lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALLV', '0'))), + # 1: enable moe all2all seq. + "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": + lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ', '0'))), } # end-env-vars-definition diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 45ccd84c50..daa512c25d 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -82,7 +82,6 @@ from vllm_ascend.utils import dispose_tensor VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO -VLLM_ASCEND_ENABLE_MOE_ALL2ALLV: bool = envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALLV class CustomDeepseekDBOMLP(CustomDeepseekV2MLP): @@ -172,7 +171,7 @@ def __init__( top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - reduce_results=True if not VLLM_ASCEND_ENABLE_MOE_ALL2ALLV else False, + reduce_results=True if not envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ else False, renormalize=config.norm_topk_prob, quant_config=quant_config, use_grouped_topk=True, @@ -1168,7 +1167,7 @@ def _forward_ms_layers(self, for i in range(moe_start_layer, self.end_layer): layer = self.layers[i] ms_layer_forward_func = layer._forward_ms_layer - if VLLM_ASCEND_ENABLE_MOE_ALL2ALLV: + if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: # ms_layer_forward_func = layer._forward_ms_layer_alltoallv ms_layer_forward_func = layer._forward_ms_layer_alltoallv_finegrained # print("get_called......") diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py index e6eeba78b0..2e770d456e 100644 --- a/vllm_ascend/models/qwen3_dbo.py +++ b/vllm_ascend/models/qwen3_dbo.py @@ -37,7 +37,6 @@ import vllm_ascend.envs as envs_ascend VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO -VLLM_ASCEND_ENABLE_MOE_ALL2ALLV: bool = envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALLV class Qwen3MoeDecoderLayerDBO(Qwen3MoeDecoderLayer): diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 115e1c880a..eceba44cb3 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -47,8 +47,6 @@ MoEAlltoAllSeqOverLapDispatcher, 
MoeDispatcherConfig) VLLM_ASCEND_MOE_ALL2ALL_BUFFER: bool = envs_ascend.VLLM_ASCEND_MOE_ALL2ALL_BUFFER -VLLM_ASCEND_ENABLE_MOE_ALL2ALLV: bool = envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALLV -VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, @@ -1147,7 +1145,7 @@ def __init__( self.tp_group = get_tp_group().device_group self.quant_method.create_weights(layer=self, **moe_quant_params) self.token_dispatcher = None - if VLLM_ASCEND_ENABLE_MOE_ALL2ALLV and isinstance( + if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance( self.quant_method, AscendUnquantizedFusedMoEMethod): moe_dispatcher_config = ( MoeDispatcherConfig().set_num_moe_experts(self.global_num_experts) @@ -1158,7 +1156,7 @@ def __init__( .set_expert_bias(e_score_correction_bias) .set_scaling_factor(1.0).build()) self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher(moe_dispatcher_config) - if VLLM_ASCEND_ENABLE_DBO: + if envs_ascend.VLLM_ASCEND_ENABLE_DBO: token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher(moe_dispatcher_config) self.token_dispatchers = [self.token_dispatcher, token_dispatcher1] From a94c094a3f0643faef54f68234f36864e42d1aaf Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 1 Jul 2025 13:18:15 +0800 Subject: [PATCH 08/60] [Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo Signed-off-by: weijinqian_v1 --- vllm_ascend/ascend_forward_context.py | 5 +++++ vllm_ascend/models/deepseek_dbo.py | 6 ++++-- vllm_ascend/ops/fused_moe.py | 7 ++++--- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 80216b9044..50198acfa0 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -7,17 +7,22 @@ from vllm.distributed import get_dp_group from vllm.forward_context import get_forward_context, set_forward_context +import vllm_ascend.envs as envs_ascend + class FusedMoEState(Enum): AllGather = 0 All2All = 1 MC2 = 2 + All2AllSeq = 3 # TODO(zzzzwwjj): add soc_version to choose branch def get_fused_moe_state(ep_size: int, with_prefill: bool): if ep_size == 1: return FusedMoEState.AllGather + elif with_prefill and envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: + return FusedMoEState.All2AllSeq # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. 
elif ep_size < 16 or with_prefill: return FusedMoEState.All2All diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index daa512c25d..e1fac8da8e 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -66,6 +66,7 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP from vllm_ascend.multistream.base import MSEventKey @@ -171,7 +172,7 @@ def __init__( top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - reduce_results=True if not envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ else False, + reduce_results=True, renormalize=config.norm_topk_prob, quant_config=quant_config, use_grouped_topk=True, @@ -1160,6 +1161,7 @@ def _forward_ms_layers(self, if moe_start_layer == self.end_layer: return hidden_states, residual + fused_moe_state = get_forward_context().fused_moe_state attn_metadata, [positions, hidden_states, residual] = self.ms_pre_layer( [positions, hidden_states, residual], ) @@ -1167,7 +1169,7 @@ def _forward_ms_layers(self, for i in range(moe_start_layer, self.end_layer): layer = self.layers[i] ms_layer_forward_func = layer._forward_ms_layer - if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: + if fused_moe_state == FusedMoEState.All2AllSeq: # ms_layer_forward_func = layer._forward_ms_layer_alltoallv ms_layer_forward_func = layer._forward_ms_layer_alltoallv_finegrained # print("get_called......") diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index eceba44cb3..0755c44f6c 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -958,7 +958,6 @@ def apply( topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) fused_moe_state = get_forward_context().fused_moe_state - use_alltoallv = 'token_dispatcher' in kwargs and kwargs.get('token_dispatcher') is not None if fused_moe_state == FusedMoEState.MC2: return fused_experts_with_mc2( @@ -992,7 +991,7 @@ def apply( global_batch_size=self.global_batch_size, expert_map=expert_map, ep_group=get_ep_group()) - elif use_alltoallv and is_prefill: + elif fused_moe_state == FusedMoEState.All2AllSeq is not None and is_prefill: token_dispatcher = kwargs.get('token_dispatcher') return fused_experts_with_all2allv(token_dispatcher=token_dispatcher, probs=topk_weights, @@ -1145,8 +1144,10 @@ def __init__( self.tp_group = get_tp_group().device_group self.quant_method.create_weights(layer=self, **moe_quant_params) self.token_dispatcher = None - if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance( + fused_moe_state = get_forward_context().fused_moe_state + if fused_moe_state == FusedMoEState.All2AllSeq and isinstance( self.quant_method, AscendUnquantizedFusedMoEMethod): + self.reduce_results = False moe_dispatcher_config = ( MoeDispatcherConfig().set_num_moe_experts(self.global_num_experts) .set_num_local_experts(self.local_num_experts) From 91570d8b814ec9656387c0ee9cff11ec87941519 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 1 Jul 2025 13:22:06 +0800 Subject: [PATCH 09/60] [Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo Signed-off-by: weijinqian_v1 --- vllm_ascend/ops/fused_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff 
--git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 0755c44f6c..b15d38b8a6 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1144,8 +1144,7 @@ def __init__( self.tp_group = get_tp_group().device_group self.quant_method.create_weights(layer=self, **moe_quant_params) self.token_dispatcher = None - fused_moe_state = get_forward_context().fused_moe_state - if fused_moe_state == FusedMoEState.All2AllSeq and isinstance( + if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance( self.quant_method, AscendUnquantizedFusedMoEMethod): self.reduce_results = False moe_dispatcher_config = ( From e7c0d2d5bf15ded0ca6fff81fac87b948e335e5b Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 1 Jul 2025 13:26:14 +0800 Subject: [PATCH 10/60] [Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo Signed-off-by: weijinqian_v1 --- vllm_ascend/ops/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index b15d38b8a6..e1d86f7766 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -991,7 +991,7 @@ def apply( global_batch_size=self.global_batch_size, expert_map=expert_map, ep_group=get_ep_group()) - elif fused_moe_state == FusedMoEState.All2AllSeq is not None and is_prefill: + elif fused_moe_state == FusedMoEState.All2AllSeq and is_prefill: token_dispatcher = kwargs.get('token_dispatcher') return fused_experts_with_all2allv(token_dispatcher=token_dispatcher, probs=topk_weights, From 47439e82b7d8220de7a76821a886afd933583f43 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 1 Jul 2025 13:32:51 +0800 Subject: [PATCH 11/60] [Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo Signed-off-by: weijinqian_v1 --- vllm_ascend/ops/fused_moe.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index e1d86f7766..cc0e324d85 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1233,6 +1233,7 @@ def forward(self, global_redundant_expert_num=self.global_redundant_expert_num, shared_experts=shared_experts if self.torchair_graph_enabled and self.enable_multistream_moe and not is_prefill else None, + token_dispatcher=self.token_dispatcher ) if shared_experts: From cf3f1c803b6b35b6f258f70e6337b264183ec4f0 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 1 Jul 2025 13:36:45 +0800 Subject: [PATCH 12/60] [Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo Signed-off-by: weijinqian_v1 --- vllm_ascend/ops/fused_moe.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index cc0e324d85..a277ff5a8b 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -565,10 +565,8 @@ def fused_experts_with_all2allv(token_dispatcher, probs, routing_map, hidden_sta (share_experts_output, dispatched_input, tokens_per_expert) = token_dispatcher.token_permutation( hidden_states, probs, routing_map ) - hidden_states_wrapper = [dispatched_input] - del dispatched_input - expert_output = apply_mlp(hidden_states_wrapper, + expert_output = apply_mlp(hidden_states, w1, w2, tokens_per_expert) From a4126f3a57f1fa61d56430db3f3069849b7368b0 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 1 Jul 2025 14:03:19 +0800 Subject: [PATCH 
13/60] [Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo Signed-off-by: weijinqian_v1 --- vllm_ascend/ops/moe_dispatcher/token_dispatcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index 60dd4f1be2..f3e2599a24 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -190,7 +190,7 @@ def __init__(self, config: MoeDispatcherConfig): for i in range(len(self.local_expert_indices) - 1): assert ( self.local_expert_indices[i] == self.local_expert_indices[i + 1] - 1 - ), "local_expert_indices must be continous" + ), "local_expert_indices must be continuous" self.probs = None self.input_splits = None self.output_splits = None From 807aaf05ff4fd4013b6a7898f0df0b8159e10520 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 1 Jul 2025 19:28:48 +0800 Subject: [PATCH 14/60] [Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo Signed-off-by: weijinqian_v1 --- vllm_ascend/models/deepseek_dbo.py | 4 ++-- vllm_ascend/worker/model_runner_v1.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index e1fac8da8e..7d4de51471 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -172,7 +172,7 @@ def __init__( top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - reduce_results=True, + reduce_results=False, renormalize=config.norm_topk_prob, quant_config=quant_config, use_grouped_topk=True, @@ -190,7 +190,7 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=True, + reduce_results=True if not envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ else False, prefix=f"{prefix}.shared_experts", ) CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 32b12508db..c09f1a52c6 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -603,7 +603,7 @@ def _check_dbo_is_valid(self, query_lens: torch.Tensor, ]: return False # considering the case that one dp rank may enable dbo while others may not - if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO: + if not envs_ascend.VLLM_ASCEND_ENABLE_DBO: return False # TODO: remove it if token-level microbatch is enabled [token_index, From 6f6efc1f9e63835c9c8c51012a54897e177ec33f Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 2 Jul 2025 11:25:47 +0800 Subject: [PATCH 15/60] [Feature]Moe alltoallv communication optimization for unquantized RL training sence & alltoallv support dpo Signed-off-by: weijinqian_v1 --- vllm_ascend/ops/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index a277ff5a8b..7c36317ba4 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -566,7 +566,7 @@ def fused_experts_with_all2allv(token_dispatcher, probs, routing_map, hidden_sta hidden_states, probs, routing_map ) - expert_output = apply_mlp(hidden_states, + expert_output = apply_mlp(dispatched_input, w1, w2, tokens_per_expert) From 
5411ed62bff01a80685c21ccd3cbed1eadc1af37 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 8 Jul 2025 11:29:54 +0800 Subject: [PATCH 16/60] add st:qwen3 Signed-off-by: weijinqian_v1 --- tests/multicard/test_qwen3_moe.py | 71 ++++++++++++++++++++++ tests/singlecard/test_offline_inference.py | 20 ------ 2 files changed, 71 insertions(+), 20 deletions(-) create mode 100644 tests/multicard/test_qwen3_moe.py diff --git a/tests/multicard/test_qwen3_moe.py b/tests/multicard/test_qwen3_moe.py new file mode 100644 index 0000000000..391cc48424 --- /dev/null +++ b/tests/multicard/test_qwen3_moe.py @@ -0,0 +1,71 @@ + +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +""" +Compare the outputs of vLLM with and without aclgraph. +Run `pytest tests/multicard/test_data_parallel.py`. +""" + +import os +import subprocess +import sys +from unittest.mock import patch + +import pytest + +MODELS = ["vllm-ascend/Qwen3-30B-A3B-Puring"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", [32]) +@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3", "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1", "VLLM_ASCEND_ENABLE_DBO": "1"}) +def test_qwen3_moe_inference(model, max_tokens): + script = "examples/offline_data_parallel.py" + + env = os.environ.copy() + + cmd = [ + sys.executable, + script, + "--model", + model, + "--dp-size", + "2", + "--tp-size", + "2", + "--node-size", + "1", + "--node-rank", + "0", + "--trust-remote-code", + "--enforce-eager", + ] + + print(f"Running subprocess: {' '.join(cmd)}") + proc = subprocess.run(cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + timeout=600) + output = proc.stdout.decode() + + print(output) + + assert "DP rank 0 needs to process" in output + assert "DP rank 1 needs to process" in output + assert "Generated text:" in output + assert proc.returncode == 0 \ No newline at end of file diff --git a/tests/singlecard/test_offline_inference.py b/tests/singlecard/test_offline_inference.py index 8ca2a2fb9e..09f29f5c3a 100644 --- a/tests/singlecard/test_offline_inference.py +++ b/tests/singlecard/test_offline_inference.py @@ -132,23 +132,3 @@ def test_models_topk() -> None: gpu_memory_utilization=0.7) as vllm_model: vllm_model.generate(example_prompts, sampling_params) - -@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1", "VLLM_ASCEND_ENABLE_DBO": "1"}) -def test_models_topk() -> None: - example_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - sampling_params = SamplingParams(max_tokens=5, - temperature=0.0, - top_k=50, - top_p=0.9) - - with VllmRunner("Qwen/Qwen2.5-0.5B-Instruct", - max_model_len=8192, - dtype="float16", - enforce_eager=True, - gpu_memory_utilization=0.7) as vllm_model: - vllm_model.generate(example_prompts, sampling_params) From 
3f887690b42b648db776fab3b22c4c9f4535fccc Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 8 Jul 2025 11:59:01 +0800 Subject: [PATCH 17/60] add st for moe token dispatcher Signed-off-by: weijinqian_v1 --- tests/ut/moe_util.py | 217 +++++++++++++++++++++++++++++++++++ tests/ut/token_dispatcher.py | 205 +++++++++++++++++++++++++++++++++ 2 files changed, 422 insertions(+) create mode 100644 tests/ut/moe_util.py create mode 100644 tests/ut/token_dispatcher.py diff --git a/tests/ut/moe_util.py b/tests/ut/moe_util.py new file mode 100644 index 0000000000..ec745773ec --- /dev/null +++ b/tests/ut/moe_util.py @@ -0,0 +1,217 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +import torch +import pytest +import math + +from vllm_ascend.ops.moe_dispatcher.moe_utils import permute, get_capacity, topk_softmax_with_capacity, \ + group_limited_topk, unpermute, sort_chunks_by_idxs + + +class TestMoeUtils: + + @pytest.fixture + def setup(self): + self.num_tokens = 16 + self.num_experts = 4 + self.hidden_size = 8 + self.topk = 2 + self.capacity_factor = 1.0 + self.group_topk = 2 + self.num_groups = 2 + self.scaling_factor = 1.0 + + def test_group_limited_topk(self, setup): + # Test group-limited topk routing + scores = torch.randn(self.num_tokens, self.num_experts) + probs, indices = group_limited_topk( + scores, + topk=self.topk, + num_tokens=self.num_tokens, + num_experts=self.num_experts, + num_groups=self.num_groups, + group_topk=self.group_topk + ) + + assert probs.shape == (self.num_tokens, self.topk) + assert indices.shape == (self.num_tokens, self.topk) + assert torch.all(indices < self.num_experts) + + def test_topk_softmax_with_capacity(self, setup): + # Test topk softmax with capacity + logits = torch.randn(self.num_tokens, self.num_experts) + + # Test without capacity + probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity( + logits, + topk=self.topk + ) + assert probs.shape == (self.num_tokens, self.num_experts) + assert routing_map.shape == (self.num_tokens, self.num_experts) + assert tokens_per_expert.shape == (self.num_experts,) + + # Test with capacity + probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity( + logits, + topk=self.topk, + capacity_factor=self.capacity_factor, + pad_to_capacity=True + ) + expert_capacity = get_capacity( + num_tokens=self.num_tokens * self.topk, + num_experts=self.num_experts, + capacity_factor=self.capacity_factor + ) + assert tokens_per_expert.max() <= expert_capacity + + # Test with group routing + probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity( + logits, + topk=self.topk, + num_groups=self.num_groups, + group_topk=self.group_topk + ) + assert probs.shape == (self.num_tokens, self.num_experts) + + def test_get_capacity(self, setup): + # Test capacity calculation + capacity = get_capacity( + num_tokens=self.num_tokens, + num_experts=self.num_experts, + capacity_factor=self.capacity_factor + ) + expected = math.ceil((self.num_tokens / self.num_experts) * self.capacity_factor) + assert capacity == expected + + # Test with min capacity + min_capacity = 5 + capacity = get_capacity( + num_tokens=self.num_tokens, + num_experts=self.num_experts, + capacity_factor=self.capacity_factor, + min_capacity=min_capacity + ) + assert capacity == min_capacity + + def test_permute(self, setup): + # Test token permutation + tokens = 
torch.randn(self.num_tokens, self.hidden_size) + routing_map = torch.randint(0, 2, (self.num_tokens, self.num_experts)).bool() + + # Basic permutation + permuted_tokens, sorted_indices = permute(tokens, routing_map) + assert permuted_tokens.shape[0] == routing_map.sum() + assert sorted_indices.shape[0] == routing_map.sum() + + # With drop and pad + capacity = get_capacity( + num_tokens=self.num_tokens * self.topk, + num_experts=self.num_experts, + capacity_factor=self.capacity_factor + ) + num_out_tokens = capacity * self.num_experts + permuted_tokens, sorted_indices = permute( + tokens, + routing_map, + num_out_tokens=num_out_tokens, + drop_and_pad=True + ) + assert permuted_tokens.shape[0] == num_out_tokens + assert sorted_indices.shape[0] == num_out_tokens + + def test_unpermute(self, setup): + # Test token unpermutation + tokens = torch.randn(self.num_tokens, self.hidden_size) + routing_map = torch.randint(0, 2, (self.num_tokens, self.num_experts)).bool() + probs = torch.rand(self.num_tokens, self.num_experts) + + # First permute + permuted_tokens, sorted_indices = permute(tokens, routing_map) + + # Then unpermute + restored_tokens = unpermute( + permuted_tokens, + sorted_indices, + tokens.shape, + probs=probs, + routing_map=routing_map + ) + assert restored_tokens.shape == tokens.shape + + # With drop and pad + capacity = get_capacity( + num_tokens=self.num_tokens * self.topk, + num_experts=self.num_experts, + capacity_factor=self.capacity_factor + ) + num_out_tokens = capacity * self.num_experts + permuted_tokens, sorted_indices = permute( + tokens, + routing_map, + num_out_tokens=num_out_tokens, + drop_and_pad=True + ) + restored_tokens = unpermute( + permuted_tokens, + sorted_indices, + tokens.shape, + probs=probs, + routing_map=routing_map, + drop_and_pad=True + ) + assert restored_tokens.shape == tokens.shape + + def test_sort_chunks_by_idxs(self, setup): + # Test chunk sorting + input_tensor = torch.randn(10, self.hidden_size) + split_sizes = torch.tensor([3, 2, 5]) + sorted_idxs = torch.tensor([2, 0, 1]) + + output = sort_chunks_by_idxs(input_tensor, split_sizes, sorted_idxs) + assert output.shape == input_tensor.shape + + # Verify the order is correct + expected = torch.cat([input_tensor[5:], input_tensor[0: 3], input_tensor[3: 5]]) + assert torch.allclose(output, expected) \ + \ + @ pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + + def test_score_functions(self, setup, score_function): + # Test different score functions + logits = torch.randn(self.num_tokens, self.num_experts) + expert_bias = torch.randn(self.num_experts) + + probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity( + logits, + topk=self.topk, + score_function=score_function, + expert_bias=expert_bias + ) + assert probs.shape == (self.num_tokens, self.num_experts) + assert routing_map.shape == (self.num_tokens, self.num_experts) + assert tokens_per_expert.shape == (self.num_experts,) + + def test_edge_cases(self, setup): + # Test empty input + empty_logits = torch.randn(0, self.num_experts) + with pytest.raises(AssertionError): + topk_softmax_with_capacity(empty_logits, topk=self.topk) + + # Test invalid score function + logits = torch.randn(self.num_tokens, self.num_experts) + with pytest.raises(ValueError): + topk_softmax_with_capacity( + logits, + topk=self.topk, + score_function="invalid" + ) + + # Test invalid drop policy + with pytest.raises(ValueError): + topk_softmax_with_capacity( + logits, + topk=self.topk, + capacity_factor=1.0, + drop_policy="invalid" + ) 
\ No newline at end of file diff --git a/tests/ut/token_dispatcher.py b/tests/ut/token_dispatcher.py new file mode 100644 index 0000000000..06f365c5f2 --- /dev/null +++ b/tests/ut/token_dispatcher.py @@ -0,0 +1,205 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + +import torch +import pytest +from unittest.mock import MagicMock + +from vllm_ascend.ops.moe_dispatcher.moe_utils import get_capacity +from vllm_ascend.ops.moe_dispatcher.token_dispatcher import MoeDispatcherConfig, MoEAlltoAllSeqOverLapDispatcher + + +class TestMoEAlltoAllSeqOverLapDispatcher: + + @pytest.fixture + def config(self): + config = MoeDispatcherConfig() + config.set_num_local_experts(2) + config.set_num_moe_experts(4) + config.set_moe_pad_expert_input_to_capacity(False) + config.set_moe_expert_capacity_factor(None) + config.set_moe_router_topk(2) + config.set_moe_grouped_gemm(False) + config.set_group_topk(0) + config.set_num_groups(1) + config.set_is_fused(True) + return config.build() + + @pytest.fixture + def mock_ep_group(self, mocker): + mock_group = MagicMock() + mock_group.rank_in_group = 0 + mock_group.world_size = 2 + mock_group.device_group = "mock_group" + mocker.patch('vllm_ascend.distributed.tensor_parallel.get_ep_group', return_value=mock_group) + return mock_group + + @pytest.fixture + def dispatcher(self, config, mock_ep_group): + return MoEAlltoAllSeqOverLapDispatcher(config) + + def test_initialization(self, dispatcher, config): + assert dispatcher.num_local_experts == config.num_local_experts + assert dispatcher.num_experts == config.num_moe_experts + assert dispatcher.local_expert_indices == [0, 1] + assert dispatcher.ep_rank == 0 + assert dispatcher.ep_size == 2 + assert dispatcher.overlap_stream is not None + + def test_routing(self, dispatcher): + probs = torch.randn(4, 4) # 4 tokens, 4 experts + scores, routing_map = dispatcher.routing(probs) + assert scores.shape == (4, 2) # topk=2 + assert routing_map.shape == (4, 2) + + def test_preprocess(self, dispatcher): + routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) + num_tokens_per_local_expert = dispatcher.preprocess(routing_map) + assert num_tokens_per_local_expert.shape == (2,) + + def test_token_permutation(self, dispatcher): + hidden_states = torch.randn(4, 8) # 4 tokens, hidden_size=8 + probs = torch.randn(4, 4) # 4 tokens, 4 experts + routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) + + shared_output, global_input, tokens_per_expert = dispatcher.token_permutation( + hidden_states, probs, routing_map + ) + + assert shared_output is None + assert global_input.shape[1] == 8 # hidden size preserved + assert tokens_per_expert.shape == (2,) + + def test_token_unpermutation(self, dispatcher): + # First do permutation to setup state + hidden_states = torch.randn(4, 8) + probs = torch.randn(4, 4) + routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) + _, global_input, _ = dispatcher.token_permutation(hidden_states, probs, routing_map) + + # Now test unpermutation + expert_output = torch.randn_like(global_input) + output, bias = dispatcher.token_unpermutation(expert_output) + + assert output.shape == hidden_states.shape + assert bias is None + + def test_preprocess_and_permute1(self, dispatcher): + hidden_states = torch.randn(4, 8) + probs = torch.randn(4, 4) + routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], 
dtype=torch.long) + + dispatcher.preprocess_and_permtute1(hidden_states, probs, routing_map) + + assert dispatcher.cached_permutated_local_input_tokens is not None + assert dispatcher.tokens_per_expert is not None + + def test_dispatch_alltoall(self, dispatcher): + # Setup with preprocess_and_permute1 + hidden_states = torch.randn(4, 8) + probs = torch.randn(4, 4) + routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) + dispatcher.preprocess_and_permtute1(hidden_states, probs, routing_map) + + dispatcher.dispatch_alltoall() + + assert dispatcher.cached_global_input_tokens is not None + assert dispatcher.cached_permutated_local_input_tokens is None + + def test_permute2(self, dispatcher): + # Setup chain + hidden_states = torch.randn(4, 8) + probs = torch.randn(4, 4) + routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) + dispatcher.preprocess_and_permtute1(hidden_states, probs, routing_map) + dispatcher.dispatch_alltoall() + + global_input, tokens_per_expert = dispatcher.permute2() + + assert global_input is not None + assert tokens_per_expert.shape == (2,) + + def test_unpermute1(self, dispatcher): + # Setup chain + hidden_states = torch.randn(4, 8) + probs = torch.randn(4, 4) + routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) + dispatcher.preprocess_and_permtute1(hidden_states, probs, routing_map) + dispatcher.dispatch_alltoall() + global_input, _ = dispatcher.permute2() + + dispatcher.unpermute1(global_input) + + assert dispatcher.cached_global_output_tokens is not None + + def test_combine_alltoall(self, dispatcher): + # Setup chain + hidden_states = torch.randn(4, 8) + probs = torch.randn(4, 4) + routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) + dispatcher.preprocess_and_permtute1(hidden_states, probs, routing_map) + dispatcher.dispatch_alltoall() + global_input, _ = dispatcher.permute2() + dispatcher.unpermute1(global_input) + + dispatcher.combine_alltoall() + + assert dispatcher.cached_local_output_tokens is not None + assert dispatcher.cached_global_output_tokens is None + + def test_unpermute2(self, dispatcher): + # Setup chain + hidden_states = torch.randn(4, 8) + probs = torch.randn(4, 4) + routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) + dispatcher.preprocess_and_permtute1(hidden_states, probs, routing_map) + dispatcher.dispatch_alltoall() + global_input, _ = dispatcher.permute2() + dispatcher.unpermute1(global_input) + dispatcher.combine_alltoall() + + output = dispatcher.unpermute2() + + assert output.shape == hidden_states.shape + assert dispatcher.cached_local_output_tokens is None + + @pytest.mark.parametrize("capacity_factor", [1.0, 1.5, 2.0]) + def test_with_capacity_factor(self, config, capacity_factor): + config.set_moe_pad_expert_input_to_capacity(True) + config.set_moe_expert_capacity_factor(capacity_factor) + dispatcher = MoEAlltoAllSeqOverLapDispatcher(config) + + hidden_states = torch.randn(4, 8) + probs = torch.randn(4, 4) + routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) + + shared_output, global_input, tokens_per_expert = dispatcher.token_permutation( + hidden_states, probs, routing_map + ) + + # Check capacity was calculated correctly + num_tokens = hidden_states.shape[0] + expected_capacity = get_capacity( + num_tokens=num_tokens, + num_experts=dispatcher.num_experts, + capacity_factor=capacity_factor, + ) + assert dispatcher.capacity == expected_capacity + + def 
test_shared_experts(self, dispatcher): + mock_shared_experts = MagicMock() + mock_shared_experts.return_value = (torch.randn(4, 8),) + dispatcher.set_shared_experts(mock_shared_experts) + + hidden_states = torch.randn(4, 8) + probs = torch.randn(4, 4) + routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) + + shared_output, _, _ = dispatcher.token_permutation( + hidden_states, probs, routing_map + ) + + assert shared_output is not None + assert shared_output.shape == hidden_states.shape + mock_shared_experts.assert_called_once() From 854c149f42d8b24edad0fc869e7b5f07db90fb54 Mon Sep 17 00:00:00 2001 From: yangkai Date: Tue, 8 Jul 2025 21:13:28 +0800 Subject: [PATCH 18/60] fix bug Signed-off-by: weijinqian_v1 --- vllm_ascend/ascend_forward_context.py | 2 +- vllm_ascend/models/qwen3_dbo.py | 62 ++++++----------- vllm_ascend/models/qwen3_moe.py | 99 +++++++++++++++++++++++++++ vllm_ascend/multistream/ms_split.py | 4 +- vllm_ascend/ops/fused_moe.py | 2 +- 5 files changed, 125 insertions(+), 44 deletions(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 040d761301..1c47351b81 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -21,7 +21,7 @@ class FusedMoEState(Enum): def get_fused_moe_state(ep_size: int, with_prefill: bool): if ep_size == 1: return FusedMoEState.AllGather - elif with_prefill and envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: + elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: return FusedMoEState.All2AllSeq # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. elif ep_size < 16 or with_prefill: diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py index 2e770d456e..2f2760f559 100644 --- a/vllm_ascend/models/qwen3_dbo.py +++ b/vllm_ascend/models/qwen3_dbo.py @@ -6,7 +6,6 @@ import torch_npu from torch import nn from transformers import PretrainedConfig -from vllm.compilation.decorators import support_torch_compile from vllm.model_executor.models.qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeModel from vllm.config import CacheConfig, VllmConfig @@ -22,6 +21,7 @@ from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.compilation.decorators import support_torch_compile from vllm_ascend.multistream.context import ( advance_step_multistream_layer_context, get_multistream_comm_context, @@ -35,6 +35,7 @@ from vllm_ascend.ops.fused_moe import AscendFusedMoE, select_experts, apply_mlp from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region import vllm_ascend.envs as envs_ascend +from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO @@ -197,7 +198,7 @@ def _forward_op_grouped_mlp( self, dispatched_input, tokens_per_expert ): return apply_mlp( - [dispatched_input], + dispatched_input, self.mlp.experts.w13_weight, self.mlp.experts.w2_weight, tokens_per_expert @@ -207,8 +208,7 @@ def _forward_combine_comm( self, hidden_states, microbatch_id, num_tokens, chunked_hidden_states_sizes ): token_dispatcher = self.mlp.experts.token_dispatchers[microbatch_id] - token_dispatcher.combine_alltoall() - final_hidden_states = token_dispatcher.unpermute2() + final_hidden_states, _ = token_dispatcher.token_unpermutation(hidden_states) if hasattr(self.mlp, 
'routed_scaling_factor'): final_hidden_states = final_hidden_states * self.mlp.routed_scaling_factor @@ -267,13 +267,10 @@ def discard_tensor(tensor): # communication in the previous layer, and the attn computation of microbatch 2 # can be overlapped with the attn communication of microbatch 1 for i in range(num_micro_batchs): - # wait last layer moe finishing communication - ms_metadata.try_wait_event(layer_index - 1, i, - MSEventKey.MOE_AFTER_COMM) - forward_context = get_forward_context() layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( ) + ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH) forward_context.attn_metadata = attn_metadata[i] # input layernorm @@ -309,36 +306,25 @@ def discard_tensor(tensor): with torch.npu.stream(dispatch_context.comm_stream): dispatch_context.comm_stream.wait_event(dispatch_context.before_comm_event) token_dispatchers[i].dispatch_alltoall() + dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2() dispatch_context.after_comm_event.record() - if has_shared_expert: - token_dispatchers[i].cached_shared_expert_output = tensor_model_parallel_all_reduce( - token_dispatchers[i].cached_shared_expert_output - ) - ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMM_FINISH].record() - # print_with_sync('begin experts...', torch.distributed.get_rank()) # block 4 : Router Experts Computation # block 5 : Token Combine Communication for i in range(num_micro_batchs): - ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_AFTER_COMM) discard_tensor(hidden_states[i]) - - dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2() router_expert_output[i] = self._forward_op_grouped_mlp(dispatched_input[i], tokens_per_expert[i]) discard_tensor(dispatched_input[i]) - token_dispatchers[i].unpermute1(router_expert_output[i]) - if router_expert_output[i].shape[0] > 0 and token_dispatchers[i].num_local_experts > 1: - discard_tensor(router_expert_output[i]) # Launch Combine Comm in a New Stream. 
combine_context = MultiStreamStepMetadata( comm_stream=ms_metadata.communicate_stream, before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_BEFORE_COMM], + MSEventKey.FFN_COM_FINISH], after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_AFTER_COMM], + MSEventKey.FFN_AR_FINISH], ) combine_context.before_comm_event.record() ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH) @@ -347,7 +333,7 @@ def discard_tensor(tensor): hidden_states[i] = self._forward_combine_comm( router_expert_output[i], i, num_tokens[i], chunked_hidden_states_sizes[i] ) - combine_context.after_comm_event.record() + ms_metadata.ms_events[layer_index][i][MSEventKey.FFN_AR_FINISH] = combine_context.comm_stream.record_event() return hidden_states, residual @@ -443,11 +429,10 @@ def forward( def can_run_ms(self): attn_metadata = get_forward_context().attn_metadata # enable prefill overlap - with_prefill = getattr(attn_metadata, "with_prefill_across_dp", False) + with_prefill = get_forward_context().with_prefill if attn_metadata is None or not with_prefill or not attn_metadata.enable_dbo_across_dp: return False - # if torch.distributed.get_rank() == 0: - # print(attn_metadata) + return True def _forward_ms_layers( @@ -465,9 +450,7 @@ def _forward_ms_layers( attn_metadata, [positions, hidden_states, residual] = self.ms_pre_layer( [positions, hidden_states, residual], ) - # if torch.distributed.get_rank() == 0: - # print(attn_metadata[0], attn_metadata[1]) - # exit() + num_micro_batch = len(attn_metadata) # the rest layers for i in range(moe_start_layer, self.end_layer): layer = self.layers[i] @@ -481,6 +464,11 @@ def _forward_ms_layers( ) advance_step_multistream_layer_context() + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context() + for i in range(num_micro_batch): + ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH) + + [hidden_states, residual] = self.ms_post_layer([hidden_states, residual], ) return hidden_states, residual @@ -517,17 +505,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + + def forward(self, *args, **kwargs): + if "graph_enable" in kwargs: + kwargs.pop('graph_enable') + return super().forward(*args, **kwargs) - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - graph_enable: Optional[bool] = True - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return hidden_states diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index 8ff1b52a7a..1dc328342b 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -15,10 +15,26 @@ # limitations under the License. # Adapted from vllm/model_executor/models/qwen3_moe.py # This file is a part of the vllm-ascend project. 
+from typing import Optional +import torch +import vllm +from torch import nn +from transformers import PretrainedConfig +from vllm.attention import AttentionMetadata +from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group +from vllm.distributed.parallel_state import get_dp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM +from vllm.distributed.parallel_state import get_ep_group +from vllm.forward_context import get_forward_context +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ops.fused_moe import AscendFusedMoE + class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): packed_modules_mapping = { "qkv_proj": [ @@ -33,3 +49,86 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], } + + +class AscendQwen3MoeSparseMoeBlock(nn.Module): + top_k: int + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_multistream_moe = \ + ascend_config.torchair_graph_config.enable_multistream_moe + + self.gate = ReplicatedLinear(config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = AscendFusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts") + + self.top_k = config.num_experts_per_tok + + self.dp_size = get_dp_group().world_size + + self.tp_group = get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group + self.ep_group = get_ep_group() + + self.params_dtype = torch.get_default_dtype() + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + # TODO: need a better flag to indicate whether in profile run or not. 
+ if attn_metadata is None: + # for profile run + is_prefill = True + enable_force_load_balance = True + else: + is_prefill = get_forward_context().with_prefill + enable_force_load_balance = False + # if hasattr(attn_metadata, 'with_prefill_across_dp'): + # is_prefill = attn_metadata.with_prefill_across_dp + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=self.top_k, + enable_force_load_balance=enable_force_load_balance, + shared_experts=None) + + return hidden_states + + +vllm.model_executor.models.qwen3_moe.Qwen3MoeSparseMoeBlock = AscendQwen3MoeSparseMoeBlock \ No newline at end of file diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index a99bd90aa9..69d078a47d 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -324,13 +324,13 @@ def model_input_split_v1_attn( query_start_loc=query_start_loc_pre, query_lens=query_lens_pre, seq_lens=seq_lens_pre, + seq_lens_list=seq_lens_pre.tolist(), max_query_len=max_query_len_pre, slot_mapping=slot_mapping_pre, is_only_prefill=is_only_prefill_pre, attn_state=attn_state_pre, attn_mask=attn_mask_pre, num_input_tokens=token_index, - with_prefill_across_dp=attn_metadata.with_prefill_across_dp, enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, ) @@ -340,13 +340,13 @@ def model_input_split_v1_attn( query_start_loc=query_start_loc_post, query_lens=query_lens_post, seq_lens=seq_lens_post, + seq_lens_list=seq_lens_post.tolist(), max_query_len=max_query_len_post, slot_mapping=slot_mapping_post, is_only_prefill=is_only_prefill_post, attn_state=attn_state_post, attn_mask=attn_mask_post, num_input_tokens=attn_metadata.num_input_tokens - token_index, - with_prefill_across_dp=attn_metadata.with_prefill_across_dp, enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, ) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 1b5660594b..687bb432c5 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -989,7 +989,7 @@ def apply( global_batch_size=self.global_batch_size, expert_map=expert_map, ep_group=get_ep_group()) - elif fused_moe_state == FusedMoEState.All2AllSeq and is_prefill: + elif fused_moe_state == FusedMoEState.All2AllSeq: token_dispatcher = kwargs.get('token_dispatcher') return fused_experts_with_all2allv(token_dispatcher=token_dispatcher, probs=topk_weights, From d0bd006f12833809c24726b623ce727cd274e188 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 8 Jul 2025 16:52:08 +0800 Subject: [PATCH 19/60] add st for moe token dispatcher Signed-off-by: weijinqian_v1 --- tests/ut/{moe_util.py => test_moe_util.py} | 65 ++----- tests/ut/test_token_dispatcher.py | 56 ++++++ tests/ut/token_dispatcher.py | 205 --------------------- 3 files changed, 74 insertions(+), 252 deletions(-) rename tests/ut/{moe_util.py => test_moe_util.py} (78%) create mode 100644 tests/ut/test_token_dispatcher.py delete mode 100644 tests/ut/token_dispatcher.py diff --git a/tests/ut/moe_util.py b/tests/ut/test_moe_util.py similarity index 78% rename from tests/ut/moe_util.py rename to tests/ut/test_moe_util.py index ec745773ec..9da4fb16b9 100644 --- a/tests/ut/moe_util.py +++ b/tests/ut/test_moe_util.py @@ -4,9 +4,9 @@ import torch import pytest import math +import vllm_ascend.patch.worker.patch_common.patch_utils -from vllm_ascend.ops.moe_dispatcher.moe_utils import permute, 
get_capacity, topk_softmax_with_capacity, \ - group_limited_topk, unpermute, sort_chunks_by_idxs +from vllm_ascend.ops.moe_dispatcher.moe_utils import permute, get_capacity, topk_softmax_with_capacity, group_limited_topk, unpermute, sort_chunks_by_idxs class TestMoeUtils: @@ -22,6 +22,7 @@ def setup(self): self.num_groups = 2 self.scaling_factor = 1.0 + def test_group_limited_topk(self, setup): # Test group-limited topk routing scores = torch.randn(self.num_tokens, self.num_experts) @@ -38,42 +39,33 @@ def test_group_limited_topk(self, setup): assert indices.shape == (self.num_tokens, self.topk) assert torch.all(indices < self.num_experts) - def test_topk_softmax_with_capacity(self, setup): + + @pytest.mark.parametrize("score_function", ["softmax"]) + def test_topk_softmax_with_capacity(self, setup, score_function): # Test topk softmax with capacity logits = torch.randn(self.num_tokens, self.num_experts) # Test without capacity probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity( logits, - topk=self.topk + topk=self.topk, + score_function=score_function ) assert probs.shape == (self.num_tokens, self.num_experts) assert routing_map.shape == (self.num_tokens, self.num_experts) assert tokens_per_expert.shape == (self.num_experts,) - # Test with capacity - probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity( - logits, - topk=self.topk, - capacity_factor=self.capacity_factor, - pad_to_capacity=True - ) - expert_capacity = get_capacity( - num_tokens=self.num_tokens * self.topk, - num_experts=self.num_experts, - capacity_factor=self.capacity_factor - ) - assert tokens_per_expert.max() <= expert_capacity - # Test with group routing probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity( logits, topk=self.topk, num_groups=self.num_groups, - group_topk=self.group_topk + group_topk=self.group_topk, + score_function=score_function ) assert probs.shape == (self.num_tokens, self.num_experts) + def test_get_capacity(self, setup): # Test capacity calculation capacity = get_capacity( @@ -94,6 +86,7 @@ def test_get_capacity(self, setup): ) assert capacity == min_capacity + def test_permute(self, setup): # Test token permutation tokens = torch.randn(self.num_tokens, self.hidden_size) @@ -120,6 +113,7 @@ def test_permute(self, setup): assert permuted_tokens.shape[0] == num_out_tokens assert sorted_indices.shape[0] == num_out_tokens + def test_unpermute(self, setup): # Test token unpermutation tokens = torch.randn(self.num_tokens, self.hidden_size) @@ -162,6 +156,7 @@ def test_unpermute(self, setup): ) assert restored_tokens.shape == tokens.shape + def test_sort_chunks_by_idxs(self, setup): # Test chunk sorting input_tensor = torch.randn(10, self.hidden_size) @@ -173,10 +168,10 @@ def test_sort_chunks_by_idxs(self, setup): # Verify the order is correct expected = torch.cat([input_tensor[5:], input_tensor[0: 3], input_tensor[3: 5]]) - assert torch.allclose(output, expected) \ - \ - @ pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) + assert torch.allclose(output, expected) + + @pytest.mark.parametrize("score_function", ["softmax"]) def test_score_functions(self, setup, score_function): # Test different score functions logits = torch.randn(self.num_tokens, self.num_experts) @@ -190,28 +185,4 @@ def test_score_functions(self, setup, score_function): ) assert probs.shape == (self.num_tokens, self.num_experts) assert routing_map.shape == (self.num_tokens, self.num_experts) - assert tokens_per_expert.shape == 
(self.num_experts,) - - def test_edge_cases(self, setup): - # Test empty input - empty_logits = torch.randn(0, self.num_experts) - with pytest.raises(AssertionError): - topk_softmax_with_capacity(empty_logits, topk=self.topk) - - # Test invalid score function - logits = torch.randn(self.num_tokens, self.num_experts) - with pytest.raises(ValueError): - topk_softmax_with_capacity( - logits, - topk=self.topk, - score_function="invalid" - ) - - # Test invalid drop policy - with pytest.raises(ValueError): - topk_softmax_with_capacity( - logits, - topk=self.topk, - capacity_factor=1.0, - drop_policy="invalid" - ) \ No newline at end of file + assert tokens_per_expert.shape == (self.num_experts,) \ No newline at end of file diff --git a/tests/ut/test_token_dispatcher.py b/tests/ut/test_token_dispatcher.py new file mode 100644 index 0000000000..b389eb430f --- /dev/null +++ b/tests/ut/test_token_dispatcher.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. + +import torch +import pytest +from pytest_mock import MockerFixture +import vllm_ascend.patch.worker.patch_common.patch_utils +from vllm_ascend.utils import adapt_patch # noqa E402 + +from vllm_ascend.ops.moe_dispatcher.token_dispatcher import MoeDispatcherConfig, MoEAlltoAllSeqOverLapDispatcher + +adapt_patch(True) + +class TestMoEAlltoAllSeqOverLapDispatcher: + + @pytest.fixture + def config(self): + config = MoeDispatcherConfig() + config.set_num_local_experts(2) + config.set_num_moe_experts(4) + config.set_moe_pad_expert_input_to_capacity(False) + config.set_moe_expert_capacity_factor(None) + config.set_moe_router_topk(2) + config.set_moe_grouped_gemm(False) + config.set_group_topk(0) + config.set_num_groups(1) + config.set_is_fused(False) + return config.build() + + def mock_ep_group(self, mocker): + mock_group = mocker.MagicMock() + mock_group.rank_in_group = 0 + mock_group.world_size = 2 + mock_group.device_group = "mock_group" + return mock_group + + @pytest.fixture + def dispatcher(self, config, mocker: MockerFixture): + mocker.patch("vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group", + return_value=self.mock_ep_group(mocker)) + return MoEAlltoAllSeqOverLapDispatcher(config) + + def test_initialization(self, dispatcher, config): + assert dispatcher.num_local_experts == config.num_local_experts + assert dispatcher.num_experts == config.num_moe_experts + assert dispatcher.local_expert_indices == [0, 1] + assert dispatcher.ep_rank == 0 + assert dispatcher.ep_size == 2 + assert dispatcher.overlap_stream is not None + + def test_routing(self, dispatcher): + probs = torch.randn(4, 4) # 4 tokens, 4 experts + scores, routing_map = dispatcher.routing(probs) + assert scores.shape == (4, 4) # topk=2 + assert routing_map.shape == (4, 4) \ No newline at end of file diff --git a/tests/ut/token_dispatcher.py b/tests/ut/token_dispatcher.py deleted file mode 100644 index 06f365c5f2..0000000000 --- a/tests/ut/token_dispatcher.py +++ /dev/null @@ -1,205 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. 
- -import torch -import pytest -from unittest.mock import MagicMock - -from vllm_ascend.ops.moe_dispatcher.moe_utils import get_capacity -from vllm_ascend.ops.moe_dispatcher.token_dispatcher import MoeDispatcherConfig, MoEAlltoAllSeqOverLapDispatcher - - -class TestMoEAlltoAllSeqOverLapDispatcher: - - @pytest.fixture - def config(self): - config = MoeDispatcherConfig() - config.set_num_local_experts(2) - config.set_num_moe_experts(4) - config.set_moe_pad_expert_input_to_capacity(False) - config.set_moe_expert_capacity_factor(None) - config.set_moe_router_topk(2) - config.set_moe_grouped_gemm(False) - config.set_group_topk(0) - config.set_num_groups(1) - config.set_is_fused(True) - return config.build() - - @pytest.fixture - def mock_ep_group(self, mocker): - mock_group = MagicMock() - mock_group.rank_in_group = 0 - mock_group.world_size = 2 - mock_group.device_group = "mock_group" - mocker.patch('vllm_ascend.distributed.tensor_parallel.get_ep_group', return_value=mock_group) - return mock_group - - @pytest.fixture - def dispatcher(self, config, mock_ep_group): - return MoEAlltoAllSeqOverLapDispatcher(config) - - def test_initialization(self, dispatcher, config): - assert dispatcher.num_local_experts == config.num_local_experts - assert dispatcher.num_experts == config.num_moe_experts - assert dispatcher.local_expert_indices == [0, 1] - assert dispatcher.ep_rank == 0 - assert dispatcher.ep_size == 2 - assert dispatcher.overlap_stream is not None - - def test_routing(self, dispatcher): - probs = torch.randn(4, 4) # 4 tokens, 4 experts - scores, routing_map = dispatcher.routing(probs) - assert scores.shape == (4, 2) # topk=2 - assert routing_map.shape == (4, 2) - - def test_preprocess(self, dispatcher): - routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) - num_tokens_per_local_expert = dispatcher.preprocess(routing_map) - assert num_tokens_per_local_expert.shape == (2,) - - def test_token_permutation(self, dispatcher): - hidden_states = torch.randn(4, 8) # 4 tokens, hidden_size=8 - probs = torch.randn(4, 4) # 4 tokens, 4 experts - routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) - - shared_output, global_input, tokens_per_expert = dispatcher.token_permutation( - hidden_states, probs, routing_map - ) - - assert shared_output is None - assert global_input.shape[1] == 8 # hidden size preserved - assert tokens_per_expert.shape == (2,) - - def test_token_unpermutation(self, dispatcher): - # First do permutation to setup state - hidden_states = torch.randn(4, 8) - probs = torch.randn(4, 4) - routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) - _, global_input, _ = dispatcher.token_permutation(hidden_states, probs, routing_map) - - # Now test unpermutation - expert_output = torch.randn_like(global_input) - output, bias = dispatcher.token_unpermutation(expert_output) - - assert output.shape == hidden_states.shape - assert bias is None - - def test_preprocess_and_permute1(self, dispatcher): - hidden_states = torch.randn(4, 8) - probs = torch.randn(4, 4) - routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) - - dispatcher.preprocess_and_permtute1(hidden_states, probs, routing_map) - - assert dispatcher.cached_permutated_local_input_tokens is not None - assert dispatcher.tokens_per_expert is not None - - def test_dispatch_alltoall(self, dispatcher): - # Setup with preprocess_and_permute1 - hidden_states = torch.randn(4, 8) - probs = torch.randn(4, 4) - routing_map = torch.tensor([[0, 1], [1, 
2], [2, 3], [0, 1]], dtype=torch.long) - dispatcher.preprocess_and_permtute1(hidden_states, probs, routing_map) - - dispatcher.dispatch_alltoall() - - assert dispatcher.cached_global_input_tokens is not None - assert dispatcher.cached_permutated_local_input_tokens is None - - def test_permute2(self, dispatcher): - # Setup chain - hidden_states = torch.randn(4, 8) - probs = torch.randn(4, 4) - routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) - dispatcher.preprocess_and_permtute1(hidden_states, probs, routing_map) - dispatcher.dispatch_alltoall() - - global_input, tokens_per_expert = dispatcher.permute2() - - assert global_input is not None - assert tokens_per_expert.shape == (2,) - - def test_unpermute1(self, dispatcher): - # Setup chain - hidden_states = torch.randn(4, 8) - probs = torch.randn(4, 4) - routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) - dispatcher.preprocess_and_permtute1(hidden_states, probs, routing_map) - dispatcher.dispatch_alltoall() - global_input, _ = dispatcher.permute2() - - dispatcher.unpermute1(global_input) - - assert dispatcher.cached_global_output_tokens is not None - - def test_combine_alltoall(self, dispatcher): - # Setup chain - hidden_states = torch.randn(4, 8) - probs = torch.randn(4, 4) - routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) - dispatcher.preprocess_and_permtute1(hidden_states, probs, routing_map) - dispatcher.dispatch_alltoall() - global_input, _ = dispatcher.permute2() - dispatcher.unpermute1(global_input) - - dispatcher.combine_alltoall() - - assert dispatcher.cached_local_output_tokens is not None - assert dispatcher.cached_global_output_tokens is None - - def test_unpermute2(self, dispatcher): - # Setup chain - hidden_states = torch.randn(4, 8) - probs = torch.randn(4, 4) - routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) - dispatcher.preprocess_and_permtute1(hidden_states, probs, routing_map) - dispatcher.dispatch_alltoall() - global_input, _ = dispatcher.permute2() - dispatcher.unpermute1(global_input) - dispatcher.combine_alltoall() - - output = dispatcher.unpermute2() - - assert output.shape == hidden_states.shape - assert dispatcher.cached_local_output_tokens is None - - @pytest.mark.parametrize("capacity_factor", [1.0, 1.5, 2.0]) - def test_with_capacity_factor(self, config, capacity_factor): - config.set_moe_pad_expert_input_to_capacity(True) - config.set_moe_expert_capacity_factor(capacity_factor) - dispatcher = MoEAlltoAllSeqOverLapDispatcher(config) - - hidden_states = torch.randn(4, 8) - probs = torch.randn(4, 4) - routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) - - shared_output, global_input, tokens_per_expert = dispatcher.token_permutation( - hidden_states, probs, routing_map - ) - - # Check capacity was calculated correctly - num_tokens = hidden_states.shape[0] - expected_capacity = get_capacity( - num_tokens=num_tokens, - num_experts=dispatcher.num_experts, - capacity_factor=capacity_factor, - ) - assert dispatcher.capacity == expected_capacity - - def test_shared_experts(self, dispatcher): - mock_shared_experts = MagicMock() - mock_shared_experts.return_value = (torch.randn(4, 8),) - dispatcher.set_shared_experts(mock_shared_experts) - - hidden_states = torch.randn(4, 8) - probs = torch.randn(4, 4) - routing_map = torch.tensor([[0, 1], [1, 2], [2, 3], [0, 1]], dtype=torch.long) - - shared_output, _, _ = dispatcher.token_permutation( - hidden_states, probs, routing_map - ) - 
- assert shared_output is not None - assert shared_output.shape == hidden_states.shape - mock_shared_experts.assert_called_once() From 49e97712a74f5031d20098299dc9577fa1753d74 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 8 Jul 2025 17:30:03 +0800 Subject: [PATCH 20/60] add moe_block: AscendSparseMoeBlock Signed-off-by: weijinqian_v1 --- vllm_ascend/models/__init__.py | 5 + vllm_ascend/models/moe_block.py | 117 ++++++++++++++++++ .../ops/moe_dispatcher/token_dispatcher.py | 7 +- 3 files changed, 128 insertions(+), 1 deletion(-) create mode 100644 vllm_ascend/models/moe_block.py diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index abf531d370..c144a81a01 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -11,6 +11,7 @@ def register_model(): from .qwen2_5_vl import \ AscendQwen2_5_VLForConditionalGeneration # noqa: F401 from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401 + from .moe_block import AscendSparseMoeBlock ModelRegistry.register_model( "DeepSeekMTPModel", @@ -20,6 +21,10 @@ def register_model(): "Qwen2VLForConditionalGeneration", "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration") + ModelRegistry.register_model( + "Qwen3MoeSparseMoeBlock", + "vllm_ascend.models.moe_block:AscendSparseMoeBlock") + if envs.USE_OPTIMIZED_MODEL: ModelRegistry.register_model( "Qwen2_5_VLForConditionalGeneration", diff --git a/vllm_ascend/models/moe_block.py b/vllm_ascend/models/moe_block.py new file mode 100644 index 0000000000..654457650c --- /dev/null +++ b/vllm_ascend/models/moe_block.py @@ -0,0 +1,117 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+ +from typing import Optional + +import torch +from torch import nn +from vllm.attention import AttentionMetadata +from vllm.distributed import (get_tensor_model_parallel_world_size, + get_tp_group) +from vllm.distributed.parallel_state import get_dp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.linear import ReplicatedLinear + +from vllm_ascend.ascend_config import get_ascend_config +from vllm.distributed.parallel_state import get_ep_group +from vllm_ascend.ops.fused_moe import AscendFusedMoE + +from transformers import PretrainedConfig +from vllm.model_executor.layers.quantization import QuantizationConfig + + +class AscendSparseMoeBlock(nn.Module): + + top_k: int + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_multistream_moe = \ + ascend_config.torchair_graph_config.enable_multistream_moe + + self.gate = ReplicatedLinear(config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = AscendFusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts") + + self.top_k = config.num_experts_per_tok + + self.dp_size = get_dp_group().world_size + + self.tp_group = get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group + self.ep_group = get_ep_group() + + self.params_dtype = torch.get_default_dtype() + + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. 
+ is_prefill = True + if attn_metadata is None: + # for profile run + is_prefill = True + enable_force_load_balance = True + else: + # is_prefill = attn_metadata.num_prefills > 0 is_prefill or + enable_force_load_balance = False + if hasattr(attn_metadata, 'with_prefill_across_dp'): + is_prefill = attn_metadata.with_prefill_across_dp + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=self.top_k, + enable_force_load_balance=enable_force_load_balance, + shared_experts=None, + ) + + return hidden_states \ No newline at end of file diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index f3e2599a24..ae73a759e7 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -348,6 +348,10 @@ def preprocess(self, indices: torch.Tensor, with_sync=True) -> torch.Tensor: def routing(self, probs): seq_length, bsz = probs.shape[:2] probs = probs.view(-1, self.config.num_moe_experts) + if self.config.is_fused: + score_function = "sigmoid" + else: + score_function = "softmax" scores, routing_map, _, top_indices = topk_softmax_with_capacity( probs, @@ -357,7 +361,8 @@ def routing(self, probs): group_topk=self.config.group_topk, num_groups=self.config.num_groups, expert_bias=self.config.expert_bias, - scaling_factor=self.config.scaling_factor + scaling_factor=self.config.scaling_factor, + score_function=score_function ) self.top_indices = top_indices return scores, routing_map From a9bccf85594c3ad8bc2a4ff59c83658032f623d7 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 8 Jul 2025 19:05:32 +0800 Subject: [PATCH 21/60] add moe_block: AscendSparseMoeBlock Signed-off-by: weijinqian_v1 --- vllm_ascend/models/moe_block.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/models/moe_block.py b/vllm_ascend/models/moe_block.py index 654457650c..338b39c062 100644 --- a/vllm_ascend/models/moe_block.py +++ b/vllm_ascend/models/moe_block.py @@ -18,6 +18,8 @@ from typing import Optional import torch +import vllm.model_executor.models.qwen3_moe as qwen3 + from torch import nn from vllm.attention import AttentionMetadata from vllm.distributed import (get_tensor_model_parallel_world_size, @@ -91,13 +93,12 @@ def forward( attn_metadata = get_forward_context().attn_metadata # when profile runs, force experts to load balanced tokens # to avoid high memory consumption on a single rank. 
- is_prefill = True if attn_metadata is None: # for profile run is_prefill = True enable_force_load_balance = True else: - # is_prefill = attn_metadata.num_prefills > 0 is_prefill or + is_prefill = False enable_force_load_balance = False if hasattr(attn_metadata, 'with_prefill_across_dp'): is_prefill = attn_metadata.with_prefill_across_dp @@ -114,4 +115,6 @@ def forward( shared_experts=None, ) - return hidden_states \ No newline at end of file + return hidden_states + +qwen3.Qwen3MoeSparseMoeBlock = AscendSparseMoeBlock From e31a7dfd62b055a674d89e69ae924a0dcafbfabf Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Tue, 8 Jul 2025 19:09:46 +0800 Subject: [PATCH 22/60] add moe_block: AscendSparseMoeBlock Signed-off-by: weijinqian_v1 --- vllm_ascend/models/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index c144a81a01..276c604f4f 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -21,10 +21,6 @@ def register_model(): "Qwen2VLForConditionalGeneration", "vllm_ascend.models.qwen2_vl:AscendQwen2VLForConditionalGeneration") - ModelRegistry.register_model( - "Qwen3MoeSparseMoeBlock", - "vllm_ascend.models.moe_block:AscendSparseMoeBlock") - if envs.USE_OPTIMIZED_MODEL: ModelRegistry.register_model( "Qwen2_5_VLForConditionalGeneration", From 0a22312d77958472e3c8195d886a312b7a29fea6 Mon Sep 17 00:00:00 2001 From: whx <56632993+whx-sjtu@users.noreply.github.com> Date: Tue, 8 Jul 2025 14:16:11 +0800 Subject: [PATCH 23/60] [0.9.1][Perf] Optimize the number of rope-related index selections in deepseek. (#1614) This PR avoids performing index selection of sin/cos cache every layer in deepseek. Signed-off-by: whx-sjtu <2952154980@qq.com> Signed-off-by: weijinqian_v1 --- vllm_ascend/attention/mla_v1.py | 76 +++++++++++++++++++-------- vllm_ascend/worker/model_runner_v1.py | 2 + 2 files changed, 56 insertions(+), 22 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 816d93c028..5e1be9fad9 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -81,6 +81,8 @@ class ChunkedContextMetadata: max_query_len: int max_seq_lens: int chunked_context: Optional[ChunkedContextMetadata] = None + sin: torch.Tensor = None + cos: torch.Tensor = None @dataclass @@ -94,6 +96,8 @@ class AscendMLADecodeMetadata: seq_lens_list: list[int] actual_seq_q_lens: Optional[list[int]] = None attn_mask: Optional[torch.Tensor] = None + sin: torch.Tensor = None + cos: torch.Tensor = None @dataclass @@ -205,6 +209,9 @@ def __init__(self, ) ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim + self.cos_cache = None + self.sin_cache = None def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: @@ -336,6 +343,18 @@ def build_torchair_graph_dummy( else: attn_state = AscendAttentionState.DecodeOnly num_decode_tokens = 1 + sin = torch.ones(num_reqs, + 1, + 1, + self.rope_dim, + dtype=self.runner.dtype, + device=device) + cos = torch.ones(num_reqs, + 1, + 1, + self.rope_dim, + dtype=self.runner.dtype, + device=device) decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, block_table=block_table, @@ -344,7 +363,8 @@ def build_torchair_graph_dummy( max_seq_lens=1, attn_mask=self.runner.spec_attn_mask, actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs], - ) + 
sin=sin, + cos=cos) return self.metadata_cls( # type: ignore num_input_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens, @@ -396,6 +416,16 @@ def build( max_query_len = query_lens.max().item() max_seq_lens = seq_lens.max().item() query_start_loc = common_attn_metadata.query_start_loc + if self.cos_cache is None: + self.cos_cache = self.runner.get_model( + ).model.layers[0].self_attn.rotary_emb.cos_cached + self.sin_cache = self.runner.get_model( + ).model.layers[0].self_attn.rotary_emb.sin_cached + if self.cos_cache.dtype != self.runner.dtype: # type: ignore + self.cos_cache = self.cos_cache.to( # type: ignore + self.runner.dtype) # type: ignore + self.sin_cache = self.sin_cache.to( # type: ignore + self.runner.dtype) # type: ignore prefill_metadata = None chunked_context_metadata = None @@ -442,18 +472,26 @@ def build( chunk_seq_lens=chunk_seq_lens, workspace=self.chunked_prefill_workspace, ) - + prefill_input_positions = input_positions[tokens_start:] + cos = self.cos_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + prefill_input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) prefill_metadata = AscendMLAPrefillMetadata( attn_mask=self.runner.attn_mask, query_lens=query_lens[tokens_start:], seq_lens=seq_lens, context_lens=seq_lens[tokens_start:], - input_positions=input_positions[tokens_start:], + input_positions=prefill_input_positions, block_table=block_table[reqs_start:, ...], max_query_len=max_query_len, max_seq_lens=max_seq_lens, query_start_loc=prefill_query_start_loc, chunked_context=chunked_context_metadata, + sin=sin, + cos=cos, ) decode_metadata = None @@ -498,8 +536,15 @@ def build( actual_seq_q_lens = query_start_loc[1:].tolist( ) + self.runner.actual_seq_q_lens[num_reqs:num_reqs + num_reqs_pad_size] + cos = self.cos_cache[ + input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[ + input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) else: seq_lens_list = seq_lens.tolist() + cos, sin = None, None decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, @@ -509,7 +554,8 @@ def build( max_seq_lens=max_seq_lens, attn_mask=self.runner.spec_attn_mask, actual_seq_q_lens=actual_seq_q_lens, - ) + sin=sin, + cos=cos) return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, @@ -1101,15 +1147,8 @@ def forward( decode_k_nope = None assert attn_metadata.decode is not None if self.running_in_graph: - seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor - cos = self.rotary_emb.cos_cached[:seq_len].to( - dtype=decode_hs_or_q_c.dtype) - sin = self.rotary_emb.sin_cached[:seq_len].to( - dtype=decode_hs_or_q_c.dtype) - cos = cos[attn_metadata.decode.input_positions] - sin = sin[attn_metadata.decode.input_positions] - cos = cos[:, None, None, :] - sin = sin[:, None, None, :] + cos = attn_metadata.decode.cos + sin = attn_metadata.decode.sin # Without explicitly controlling the order, IndexByTensor operations # would be placed after `matmul W_KV_T` hindering the overlapping of # KvRmsNormRopeCache and SingleRope. 
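The hunks above show the heart of this optimization: the per-layer slicing of `rotary_emb.cos_cached` / `sin_cached` followed by indexing with `input_positions` is removed, and each layer instead reads the `cos` / `sin` tensors that the metadata builder prepared once per step. A minimal standalone sketch of that idea follows; the helper name and the call site are illustrative only, not the actual vllm-ascend code:

```python
import torch


def gather_rope_cos_sin(cos_cache: torch.Tensor, sin_cache: torch.Tensor,
                        positions: torch.Tensor, dtype: torch.dtype):
    # One index selection per forward pass; the result uses the
    # [num_tokens, 1, 1, rope_dim] layout seen in the decode path above.
    cos = cos_cache.to(dtype)[positions].unsqueeze(1).unsqueeze(2)
    sin = sin_cache.to(dtype)[positions].unsqueeze(1).unsqueeze(2)
    return cos, sin


# Built once by the attention metadata builder and stashed on the metadata,
# so every decoder layer reuses the same tensors instead of repeating the
# cache slicing and index selection per layer (shapes are examples).
cos_cache = torch.randn(4096, 64)   # [max_positions, qk_rope_head_dim]
sin_cache = torch.randn(4096, 64)
positions = torch.tensor([0, 1, 2, 3])
cos, sin = gather_rope_cos_sin(cos_cache, sin_cache, positions, torch.float16)
```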
@@ -1144,15 +1183,8 @@ def forward( prefill_q_nope = prefill_q[..., :self.qk_nope_head_dim] if self.torchair_graph_enabled: num_tokens = prefill_hs_or_q_c.shape[0] - seq_len = self.rotary_emb.max_position_embeddings * self.rotary_emb.scaling_factor - cos = self.rotary_emb.cos_cached[:seq_len].to( - dtype=prefill_q_pe.dtype) - sin = self.rotary_emb.sin_cached[:seq_len].to( - dtype=prefill_q_pe.dtype) - cos = cos[attn_metadata.prefill.input_positions] - sin = sin[attn_metadata.prefill.input_positions] - cos = cos[:, None, None, :] - sin = sin[:, None, None, :] + cos = attn_metadata.prefill.cos + sin = attn_metadata.prefill.sin prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) prefill_k_pe, prefill_k_nope = self.exec_kv_prefill( diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 3bb4f59a0b..7d366d9563 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1666,6 +1666,8 @@ def _dummy_run( attn_metadata.decode.block_table) torch._dynamo.mark_static( attn_metadata.decode.input_positions) + torch._dynamo.mark_static(attn_metadata.decode.sin) + torch._dynamo.mark_static(attn_metadata.decode.cos) torch._dynamo.mark_static(attn_metadata.slot_mapping) for kv in self.kv_caches: assert isinstance( From ee1dd493d61a2e25ebcb89cbadfc0013c2c1ee91 Mon Sep 17 00:00:00 2001 From: xuyexiong Date: Tue, 8 Jul 2025 18:46:11 +0800 Subject: [PATCH 24/60] [BUGFIX] FIX mtp accuraccy when temperture is not 0 (#1632) ### What this PR does / why we need it? 1. [BUGFIX] FIX mtp accuraccy when temperture is not 0 2. [BUGFIX] FIX mtp when multi DP is enabled ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? vllm-ascend/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py Signed-off-by: xuyexiong Signed-off-by: weijinqian_v1 --- .../spec_decode_v1/test_v1_mtp_correctness.py | 6 ++++-- vllm_ascend/attention/mla_v1.py | 10 ++++++---- vllm_ascend/sample/rejection_sampler.py | 2 +- vllm_ascend/worker/model_runner_v1.py | 16 +++++++++++++--- vllm_ascend/worker/mtp_proposer_v1.py | 9 ++------- 5 files changed, 26 insertions(+), 17 deletions(-) diff --git a/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py b/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py index 3b5e1986f2..68736c0928 100644 --- a/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py +++ b/tests/long_term/spec_decode_v1/test_v1_mtp_correctness.py @@ -114,7 +114,8 @@ def test_mtp_torchair_correctness( enforce_eager=False, additional_config={ "torchair_graph_config": { - "enabled": True + "enabled": True, + "graph_batch_size": [256] }, "ascend_scheduler_config": { "enabled": True @@ -132,7 +133,8 @@ def test_mtp_torchair_correctness( }, additional_config={ "torchair_graph_config": { - "enabled": True + "enabled": True, + "graph_batch_size": [256] }, "ascend_scheduler_config": { "enabled": True diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 5e1be9fad9..98644ec13d 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -324,7 +324,7 @@ def build_torchair_graph_dummy( num_reqs, block_table) num_tokens = num_reqs * self.runner.decode_token_per_req seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device) - seq_lens_list = seq_lens.tolist() + seq_lens_list = [0] * num_reqs input_positions = torch.zeros(num_tokens, dtype=torch.int32, device=device).long() @@ -497,7 +497,7 @@ def build( decode_metadata = None use_torchair_graph = 
num_token_pad_size != -1 if self._num_decodes > 0: - actual_seq_q_lens = None + actual_seq_q_lens = query_start_loc[1:].tolist() max_seq_lens = seq_lens[:self._num_decodes].max().item() seq_lens = seq_lens[:self._num_decode_tokens] input_positions = input_positions[:self._num_decode_tokens] @@ -1014,11 +1014,13 @@ def _forward_decode( self.qk_rope_head_dim) input_layout = "BNSD" - # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim] if attn_metadata.attn_state == AscendAttentionState.SpecDecoding: assert num_tokens % self.spec_token_num == 0 + if self.enable_kv_nz: + input_layout = "TND_NTD" + else: + input_layout = "TND" # [bs * q_seq_len, num_heads_per_rank, dim] - input_layout = "TND" q_nope = q_nope.view(num_tokens, self.num_heads, -1) q_pe = q_pe.view(num_tokens, self.num_heads, -1) sparse_mode = 3 diff --git a/vllm_ascend/sample/rejection_sampler.py b/vllm_ascend/sample/rejection_sampler.py index 384787be01..c738410d0c 100644 --- a/vllm_ascend/sample/rejection_sampler.py +++ b/vllm_ascend/sample/rejection_sampler.py @@ -432,7 +432,7 @@ def sample_recovered_tokens_pytorch( if IS_NGRAM: draft_token_id = draft_token_ids[token_idx] - orig_prob = target_probs[token_idx, draft_token_id] + orig_prob = target_probs[token_idx, draft_token_id].item() target_probs[token_idx, draft_token_id] = 0 prob = target_probs[token_idx].clone() else: diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 7d366d9563..1305ac10d7 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1698,7 +1698,7 @@ def _dummy_run( **model_kwargs) if self.speculative_config and self.speculative_config.method == "deepseek_mtp": assert isinstance(self.drafter, MtpProposer) - self.drafter.dummy_run(num_reqs, with_prefill=with_prefill) + self.drafter.dummy_run(num_reqs) return hidden_states @contextmanager @@ -2163,6 +2163,12 @@ def check_torchair_graph_batch_sizes(self): if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs: self.torchair_graph_batch_sizes.append(self.max_num_reqs) + # we need to make sure that we can deal with max_num_req when `self.decode_token_per_req` is not 1 + self.torchair_graph_batch_sizes = [ + graph_batch_size * self.decode_token_per_req + for graph_batch_size in self.torchair_graph_batch_sizes + ] + # NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size` tp_size = self.parallel_config.tensor_parallel_size if self.parallel_config.enable_expert_parallel: @@ -2170,9 +2176,13 @@ def check_torchair_graph_batch_sizes(self): for graph_batch_size in self.torchair_graph_batch_sizes: cur_graph_batch_size = (graph_batch_size + tp_size - 1) // tp_size * tp_size - # `graph_batch_size` need to be divisible by `self.decode_token_per_req` - cur_graph_batch_size = cur_graph_batch_size * self.decode_token_per_req if cur_graph_batch_size not in new_graph_batch_sizes and \ cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens: new_graph_batch_sizes.append(cur_graph_batch_size) + elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \ + and self.decode_token_per_req > 1: + logger.warning( + f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens", + f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size." 
+ ) self.torchair_graph_batch_sizes = new_graph_batch_sizes diff --git a/vllm_ascend/worker/mtp_proposer_v1.py b/vllm_ascend/worker/mtp_proposer_v1.py index 04a7d617b5..f5c1c5b388 100644 --- a/vllm_ascend/worker/mtp_proposer_v1.py +++ b/vllm_ascend/worker/mtp_proposer_v1.py @@ -308,14 +308,9 @@ def load_model(self) -> None: def dummy_run( self, num_tokens: int, - with_prefill: bool = False, ) -> None: - if self.runner.torchair_graph_enabled and not with_prefill: - attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( - num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True) - else: - attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( - num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True) + attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy( + num_reqs=num_tokens, num_actual_tokens=1, is_mtp_model=True) with set_ascend_forward_context(None, self.vllm_config, num_tokens=num_tokens): From eef10934f34cb03278d8cd1cf6f6815f866d1ad0 Mon Sep 17 00:00:00 2001 From: weiguihua2 Date: Tue, 8 Jul 2025 18:46:54 +0800 Subject: [PATCH 25/60] add mc2 mask (#1642) ### What this PR does / why we need it? add mc2 mask ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? --------- Signed-off-by: weiguihua2 --- vllm_ascend/attention/mla_v1.py | 18 +++++++++++++-- vllm_ascend/ops/fused_moe.py | 28 +++++++++++++++++++++++- vllm_ascend/quantization/w8a8_dynamic.py | 16 +++++++++++++- vllm_ascend/worker/model_runner_v1.py | 2 ++ 4 files changed, 60 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 98644ec13d..07dea2dd10 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -11,6 +11,7 @@ from vllm.config import get_current_vllm_config from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) +from vllm.platforms import current_platform from vllm.utils import cdiv, round_down from vllm_ascend import envs @@ -98,6 +99,7 @@ class AscendMLADecodeMetadata: attn_mask: Optional[torch.Tensor] = None sin: torch.Tensor = None cos: torch.Tensor = None + mc2_mask: Optional[torch.Tensor] = None @dataclass @@ -213,6 +215,13 @@ def __init__(self, self.cos_cache = None self.sin_cache = None + def generate_activate_mask(self, actual_seqs_num, batch_size): + mc2_mask = torch.zeros(batch_size, + dtype=torch.bool, + device=current_platform.device_type) + mc2_mask[:actual_seqs_num].fill_(True) + return mc2_mask + def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: # We now want to reorder the batch so that the "decode" requests are at @@ -355,6 +364,7 @@ def build_torchair_graph_dummy( self.rope_dim, dtype=self.runner.dtype, device=device) + mc2_mask = self.generate_activate_mask(num_actual_tokens, num_reqs) decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, block_table=block_table, @@ -364,7 +374,8 @@ def build_torchair_graph_dummy( attn_mask=self.runner.spec_attn_mask, actual_seq_q_lens=self.runner.actual_seq_q_lens[:num_reqs], sin=sin, - cos=cos) + cos=cos, + mc2_mask=mc2_mask) return self.metadata_cls( # type: ignore num_input_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens, @@ -545,6 +556,8 @@ def build( else: seq_lens_list = seq_lens.tolist() cos, sin = None, None + mc2_mask = self.generate_activate_mask( + num_actual_tokens, num_reqs + num_reqs_pad_size) decode_metadata = AscendMLADecodeMetadata( 
input_positions=input_positions, @@ -555,7 +568,8 @@ def build( attn_mask=self.runner.spec_attn_mask, actual_seq_q_lens=actual_seq_q_lens, sin=sin, - cos=cos) + cos=cos, + mc2_mask=mc2_mask) return self.metadata_cls( # type: ignore num_actual_tokens=num_actual_tokens, diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 687bb432c5..00a03dd914 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -124,6 +124,7 @@ def fused_experts_with_mc2( moe_all_to_all_group_name: Optional[str] = None, shared_experts: Optional[Any] = None, is_torchair: bool = False, + mc2_mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: quant_mode = 0 ep_group = get_ep_group() @@ -140,6 +141,9 @@ def fused_experts_with_mc2( need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 or is_torchair) + # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine + a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 + moe_expert_num = len(expert_map) kwargs_mc2 = { "x": hidden_states, @@ -163,6 +167,10 @@ def fused_experts_with_mc2( "tp_world_size": 1, "tp_rank_id": 0, }) + if a3_need_extra_args: + stage1_kwargs.update({ + "x_active_mask": mc2_mask, + }) kwargs_mc2.update(stage1_kwargs) @@ -232,6 +240,10 @@ def fused_experts_with_mc2( "tp_world_size": 1, "tp_rank_id": 0, }) + if a3_need_extra_args: + stage3_kwargs.update({ + "x_active_mask": mc2_mask, + }) kwargs_mc2.update(stage3_kwargs) hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) @@ -958,6 +970,7 @@ def apply( fused_moe_state = get_forward_context().fused_moe_state if fused_moe_state == FusedMoEState.MC2: + mc2_mask = kwargs.get("mc2_mask", None) return fused_experts_with_mc2( hidden_states=x, w1=layer.w13_weight, @@ -968,7 +981,8 @@ def apply( expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name, shared_experts=shared_experts, - is_torchair=self.torchair_graph_enabled) + is_torchair=self.torchair_graph_enabled, + mc2_mask=mc2_mask) elif fused_moe_state == FusedMoEState.AllGather: return fused_experts(hidden_states=x, w1=layer.w13_weight, @@ -1194,6 +1208,9 @@ def forward(self, if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: shared_hidden_states = shared_experts(hidden_states) + attn_metadata = get_forward_context().attn_metadata + mc2_mask = attn_metadata.decode.mc2_mask if attn_metadata is not None and attn_metadata.decode is not None else None + tp_size = get_tensor_model_parallel_world_size() if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: if num_tokens < tp_size: @@ -1201,6 +1218,9 @@ def forward(self, hidden_states, (0, 0, 0, tp_size - num_tokens)) router_logits = nn.functional.pad( router_logits, (0, 0, 0, tp_size - num_tokens)) + if mc2_mask is not None: + mc2_mask = nn.functional.pad(mc2_mask, + (0, tp_size - num_tokens)) chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0) @@ -1210,6 +1230,11 @@ def forward(self, tp_rank = get_tensor_model_parallel_rank() hidden_states = chunk_hidden_states[tp_rank] router_logits = chunk_router_logits[tp_rank] + + if mc2_mask is not None: + chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0) + mc2_mask = chunk_mc2_mask[tp_rank] + if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: # NOTE: When in torchair graph, it has been padded in model_runner_v1 if not self.torchair_graph_enabled or is_prefill: @@ -1248,6 +1273,7 @@ def forward(self, and 
self.enable_multistream_moe and not is_prefill else None, quantized_x_for_share=quantized_x_for_share, dynamic_scale_for_share=dynamic_scale_for_share, + mc2_mask=mc2_mask, token_dispatcher=self.token_dispatcher ) diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index a9938c14f2..3561675095 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -215,6 +215,7 @@ def fused_experts_with_mc2( w2_scale_bias: torch.Tensor = None, quantized_x_for_share: Optional[Any] = None, dynamic_scale_for_share: Optional[Any] = None, + mc2_mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if log2phy: topk_ids = log2phy[topk_ids] @@ -233,6 +234,9 @@ def fused_experts_with_mc2( need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 or is_torchair) + # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine + a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 + if (expert_map is not None): moe_expert_num = len(expert_map) + global_redundant_expert_num else: @@ -260,6 +264,10 @@ def fused_experts_with_mc2( "tp_world_size": 1, "tp_rank_id": 0, }) + if a3_need_extra_args: + stage1_kwargs.update({ + "x_active_mask": mc2_mask, + }) kwargs_mc2.update(stage1_kwargs) output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) @@ -310,6 +318,10 @@ def fused_experts_with_mc2( "tp_world_size": 1, "tp_rank_id": 0, }) + if a3_need_extra_args: + stage3_kwargs.update({ + "x_active_mask": mc2_mask, + }) kwargs_mc2.update(stage3_kwargs) hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) @@ -791,6 +803,7 @@ def apply( topk_weights = topk_weights.to(x.dtype) if fused_moe_state == FusedMoEState.MC2: + mc2_mask = kwargs.get("mc2_mask", None) return fused_experts_with_mc2( hidden_states=x, w1=layer.w13_weight, @@ -807,7 +820,8 @@ def apply( shared_experts=shared_experts, is_torchair=self.torchair_graph_enabled, quantized_x_for_share=shared_gate_up, - dynamic_scale_for_share=shared_dequant_scale) + dynamic_scale_for_share=shared_dequant_scale, + mc2_mask=mc2_mask) elif fused_moe_state == FusedMoEState.AllGather: return fused_experts(hidden_states=x, w1=layer.w13_weight, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 1305ac10d7..6644f47a70 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1668,6 +1668,8 @@ def _dummy_run( attn_metadata.decode.input_positions) torch._dynamo.mark_static(attn_metadata.decode.sin) torch._dynamo.mark_static(attn_metadata.decode.cos) + torch._dynamo.mark_static( + attn_metadata.decode.mc2_mask) torch._dynamo.mark_static(attn_metadata.slot_mapping) for kv in self.kv_caches: assert isinstance( From eb54e225c49b23befce1548c47ed7d2899861372 Mon Sep 17 00:00:00 2001 From: songshanhu07 <1763685535@qq.com> Date: Tue, 8 Jul 2025 19:44:45 +0800 Subject: [PATCH 26/60] [cherry-pick] static EPLB fix bug, add unit test to v0.9.1-dev (#1667) ### What this PR does / why we need it? [cherry-pick master-> 0.9.1-dev](https://github.com/vllm-project/vllm-ascend/pull/1186) 1.add static EPLB unit test 2.fix bug: Tensor cannot be directly judged by if statements ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? Run the unit test. 
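For reference, a minimal illustration (not part of this patch) of the bug fixed
here: a multi-element `torch.Tensor` cannot be used directly as an `if`
condition, so optional tensors such as `log2phy` have to be compared against
`None`. The values below are hypothetical, mirroring the rank-0 log2phy map
used in the unit test.

```python
import torch

log2phy = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0], dtype=torch.int32)  # hypothetical log2phy map
topk_ids = torch.tensor([[0, 5], [3, 7]])  # hypothetical routed expert ids

try:
    if log2phy:  # old check: truthiness of a multi-element tensor is ambiguous
        topk_ids = log2phy[topk_ids]
except RuntimeError as err:
    print(f"direct truthiness check raises: {err}")

if log2phy is not None:  # the comparison used by this fix
    topk_ids = log2phy[topk_ids]  # map logical expert ids to physical ids
```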
Signed-off-by: songshanhu07 <1763685535@qq.com> Signed-off-by: weijinqian_v1 --- tests/ut/ops/test_expert_load_balancer.py | 147 ++++++++++++++++++++++ vllm_ascend/quantization/w8a8_dynamic.py | 4 +- 2 files changed, 149 insertions(+), 2 deletions(-) create mode 100644 tests/ut/ops/test_expert_load_balancer.py diff --git a/tests/ut/ops/test_expert_load_balancer.py b/tests/ut/ops/test_expert_load_balancer.py new file mode 100644 index 0000000000..3b7a69ddd4 --- /dev/null +++ b/tests/ut/ops/test_expert_load_balancer.py @@ -0,0 +1,147 @@ +# fused moe ops test will hit the infer_schema error, we need add the patch +# here to make the test pass. +import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa + +import json +import unittest +from typing import List, TypedDict +from unittest import mock + +import torch + +from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer + + +class Device(TypedDict): + device_id: int + device_expert: List[int] + + +class Layer(TypedDict): + layer_id: int + device_count: int + device_list: List[Device] + + +class MockData(TypedDict): + moe_layer_count: int + layer_list: List[Layer] + + +MOCK_DATA: MockData = { + "moe_layer_count": + 1, + "layer_list": [{ + "layer_id": + 0, + "device_count": + 2, + "device_list": [{ + "device_id": 0, + "device_expert": [7, 2, 0, 3, 5] + }, { + "device_id": 1, + "device_expert": [6, 1, 4, 7, 2] + }] + }] +} + + +class TestExpertLoadBalancer(unittest.TestCase): + + def setUp(self): + json_file = "expert_map.json" + with open(json_file, 'w') as f: + json.dump(MOCK_DATA, f) + + self.expert_load_balancer = ExpertLoadBalancer(json_file, + global_expert_num=8) + + def test_init(self): + + self.assertIsInstance(self.expert_load_balancer.expert_map_tensor, + torch.Tensor) + self.assertEqual(self.expert_load_balancer.layers_num, + MOCK_DATA["moe_layer_count"]) + self.assertEqual(self.expert_load_balancer.ranks_num, + MOCK_DATA["layer_list"][0]["device_count"]) + + def test_generate_index_dicts(self): + tensor_2d = torch.tensor([[7, 2, 0, 3, 5], [6, 1, 4, 7, 2]]) + result = self.expert_load_balancer.generate_index_dicts(tensor_2d) + expected_result = [{ + 7: 0, + 2: 1, + 0: 2, + 3: 3, + 5: 4 + }, { + 6: 5, + 1: 6, + 4: 7, + 7: 8, + 2: 9 + }] + self.assertEqual(result, expected_result) + + def test_generate_expert_placement_map(self): + expert_placement_map = self.expert_load_balancer.generate_expert_placement_map( + ) + self.assertEqual(expert_placement_map.shape, + (self.expert_load_balancer.layers_num, + self.expert_load_balancer.ranks_num, 8)) + self.assertTrue(torch.all(expert_placement_map >= -1)) + + def test_generate_log2phy_expert_map(self): + layer_id = 0 + log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map( + layer_id) + self.assertEqual(log2phy_map.shape, + (self.expert_load_balancer.ranks_num, 8)) + self.assertTrue(torch.all(log2phy_map >= -1)) + + @mock.patch("torch_npu.npu._lazy_init") + @mock.patch("torch.npu.current_device", return_value="cpu") + def test_get_rank_placement_map(self, mock_current_device, mock_lazy_init): + layer_id = 0 + rank_id = 0 + rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map( + layer_id, rank_id) + self.assertEqual(rank_local_expert_num, 5) + expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0], + dtype=torch.int32).to( + rank_expert_map.device) + self.assertTrue(rank_expert_map.equal(expected_tensor)) + + rank_id = 1 + rank_local_expert_num, rank_expert_map = 
self.expert_load_balancer.get_rank_placement_map( + layer_id, rank_id) + expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3], + dtype=torch.int32).to( + rank_expert_map.device) + self.assertTrue(rank_expert_map.equal(expected_tensor)) + + def test_get_rank_log2phy_map(self): + layer_id = 0 + rank_id = 0 + log2phy_map = self.expert_load_balancer.get_rank_log2phy_map( + layer_id, rank_id) + expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0], + dtype=torch.int32).to( + log2phy_map.device) + self.assertTrue(log2phy_map.equal(expected_tensor)) + + rank_id = 1 + log2phy_map = self.expert_load_balancer.get_rank_log2phy_map( + layer_id, rank_id) + expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8], + dtype=torch.int32).to( + log2phy_map.device) + self.assertTrue(log2phy_map.equal(expected_tensor)) + + def test_get_global_redundant_expert_num(self): + redundant_expert_num = self.expert_load_balancer.get_global_redundant_expert_num( + ) + expected_redundant_expert_num = len(MOCK_DATA["layer_list"][0]["device_list"][0]["device_expert"]) * \ + MOCK_DATA["layer_list"][0]["device_count"] - 8 + self.assertEqual(redundant_expert_num, expected_redundant_expert_num) diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 3561675095..2dd927f4ed 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -217,7 +217,7 @@ def fused_experts_with_mc2( dynamic_scale_for_share: Optional[Any] = None, mc2_mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if log2phy: + if log2phy is not None: topk_ids = log2phy[topk_ids] quant_mode = 2 ep_group = get_ep_group() @@ -352,7 +352,7 @@ def fused_experts_with_all2all(hidden_states: torch.Tensor, global_redundant_expert_num: int = 0, w1_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None): - if log2phy: + if log2phy is not None: topk_ids = log2phy[topk_ids] original_shape = hidden_states.shape if len(original_shape) == 3: From b02ad405b95f2af6fdd14281a121265ceada3de2 Mon Sep 17 00:00:00 2001 From: yangkai Date: Tue, 8 Jul 2025 21:17:47 +0800 Subject: [PATCH 27/60] revert Signed-off-by: weijinqian_v1 --- vllm_ascend/models/qwen3_moe.py | 99 --------------------------------- 1 file changed, 99 deletions(-) diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index 1dc328342b..8ff1b52a7a 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -15,26 +15,10 @@ # limitations under the License. # Adapted from vllm/model_executor/models/qwen3_moe.py # This file is a part of the vllm-ascend project. 
-from typing import Optional -import torch -import vllm -from torch import nn -from transformers import PretrainedConfig -from vllm.attention import AttentionMetadata -from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group -from vllm.distributed.parallel_state import get_dp_group -from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.linear import ReplicatedLinear -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM -from vllm.distributed.parallel_state import get_ep_group -from vllm.forward_context import get_forward_context -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ops.fused_moe import AscendFusedMoE - class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): packed_modules_mapping = { "qkv_proj": [ @@ -49,86 +33,3 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], } - - -class AscendQwen3MoeSparseMoeBlock(nn.Module): - top_k: int - - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() - if self.tp_size > config.num_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_multistream_moe = \ - ascend_config.torchair_graph_config.enable_multistream_moe - - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") - - self.experts = AscendFusedMoE( - num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts") - - self.top_k = config.num_experts_per_tok - - self.dp_size = get_dp_group().world_size - - self.tp_group = get_tp_group().device_group - self.tp_rank = get_tp_group().rank_in_group - self.ep_group = get_ep_group() - - self.params_dtype = torch.get_default_dtype() - - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - if attn_metadata is None: - attn_metadata = get_forward_context().attn_metadata - # when profile runs, force experts to load balanced tokens - # to avoid high memory consumption on a single rank. - # TODO: need a better flag to indicate whether in profile run or not. 
- if attn_metadata is None: - # for profile run - is_prefill = True - enable_force_load_balance = True - else: - is_prefill = get_forward_context().with_prefill - enable_force_load_balance = False - # if hasattr(attn_metadata, 'with_prefill_across_dp'): - # is_prefill = attn_metadata.with_prefill_across_dp - - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - is_prefill=is_prefill, - top_k=self.top_k, - enable_force_load_balance=enable_force_load_balance, - shared_experts=None) - - return hidden_states - - -vllm.model_executor.models.qwen3_moe.Qwen3MoeSparseMoeBlock = AscendQwen3MoeSparseMoeBlock \ No newline at end of file From 66807e0a1a2ec30deea17f9e0a019f15a629b5d1 Mon Sep 17 00:00:00 2001 From: yangkai Date: Tue, 8 Jul 2025 21:18:43 +0800 Subject: [PATCH 28/60] fix bug Signed-off-by: weijinqian_v1 --- vllm_ascend/models/moe_block.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm_ascend/models/moe_block.py b/vllm_ascend/models/moe_block.py index 338b39c062..097aa4833b 100644 --- a/vllm_ascend/models/moe_block.py +++ b/vllm_ascend/models/moe_block.py @@ -98,10 +98,8 @@ def forward( is_prefill = True enable_force_load_balance = True else: - is_prefill = False + is_prefill = get_forward_context().with_prefill enable_force_load_balance = False - if hasattr(attn_metadata, 'with_prefill_across_dp'): - is_prefill = attn_metadata.with_prefill_across_dp # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) From d24758e1d3535861740b5d2b47721d7f541f12e8 Mon Sep 17 00:00:00 2001 From: duyangkai Date: Tue, 8 Jul 2025 21:35:22 +0800 Subject: [PATCH 29/60] fix a bug Signed-off-by: weijinqian_v1 --- vllm_ascend/ascend_forward_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 1c47351b81..85acf480c6 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -22,7 +22,7 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool): if ep_size == 1: return FusedMoEState.AllGather elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: - return FusedMoEState.All2AllSeq + return FusedMoEState.All2AllSeq if ep_size < 16 else FusedMoEState.MC2 # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. elif ep_size < 16 or with_prefill: return FusedMoEState.All2All From d76c4fba50ad4fc0733f593bbfe5de8b07bb21c2 Mon Sep 17 00:00:00 2001 From: duyangkai Date: Tue, 8 Jul 2025 21:36:28 +0800 Subject: [PATCH 30/60] fix a bug Signed-off-by: weijinqian_v1 --- vllm_ascend/ascend_forward_context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 85acf480c6..0ed0cbc29c 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -22,7 +22,8 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool): if ep_size == 1: return FusedMoEState.AllGather elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: - return FusedMoEState.All2AllSeq if ep_size < 16 else FusedMoEState.MC2 + # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage. + return FusedMoEState.All2AllSeq if (ep_size < 16 or with_prefill) else FusedMoEState.MC2 # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. 
elif ep_size < 16 or with_prefill:
         return FusedMoEState.All2All
 
From f883902c20eb6d94a545b45427f100d4512086fc Mon Sep 17 00:00:00 2001
From: yangkai 
Date: Wed, 9 Jul 2025 02:50:36 +0000
Subject: [PATCH 31/60] ut test

Signed-off-by: weijinqian_v1
---
 tests/ut/test_distributed_tensor_parallel.py | 114 +++++++++++++++++++
 1 file changed, 114 insertions(+)
 create mode 100644 tests/ut/test_distributed_tensor_parallel.py

diff --git a/tests/ut/test_distributed_tensor_parallel.py b/tests/ut/test_distributed_tensor_parallel.py
new file mode 100644
index 0000000000..6b05ab436f
--- /dev/null
+++ b/tests/ut/test_distributed_tensor_parallel.py
@@ -0,0 +1,114 @@
+import pytest
+import torch
+import importlib
+from unittest.mock import MagicMock, patch
+from vllm_ascend.distributed.tensor_parallel import (
+    _gather_along_first_dim, _gather_along_last_dim,
+    _reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim,
+    all_to_all_sp2hp, all_to_all_hp2sp
+)
+
+# Fixed test data
+@pytest.fixture
+def test_tensor():
+    return torch.randn(8, 16)
+
+
+@pytest.fixture
+def test_tensor_last_dim():
+    return torch.randn(8, 16, 32)
+
+
+@pytest.fixture
+def mock_group():
+    return MagicMock()
+
+
+# Mock the distributed environment
+@pytest.fixture(autouse=True)
+def mock_dist():
+    with patch("torch.distributed") as mock:
+        mock.get_world_size.return_value = 4
+        mock.get_rank.return_value = 0
+        yield mock
+
+
+class TestDistributedCommunication:
+    """Tests for the distributed communication functions."""
+
+    @pytest.mark.parametrize("world_size", [1, 4])
+    def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist, world_size):
+        """Test _gather_along_first_dim."""
+        mock_dist.get_world_size.return_value = world_size
+
+        result = _gather_along_first_dim(test_tensor, mock_group)
+
+        if world_size == 1:
+            assert torch.equal(result, test_tensor)
+        else:
+            assert result.shape == (32, 16)  # 8*4=32
+
+    def test_gather_along_first_dim_unequal_split(self, test_tensor, mock_group):
+        """Test gathering with unequal output split sizes."""
+        output_split_sizes = [5, 10, 15, 2]
+        result = _gather_along_first_dim(test_tensor, mock_group, output_split_sizes)
+        assert result.shape == (32, 16)  # 5+10+15+2=32
+
+    @pytest.mark.parametrize("world_size", [1, 4])
+    def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group, mock_dist, world_size):
+        """Test _gather_along_last_dim."""
+        mock_dist.get_world_size.return_value = world_size
+
+        result = _gather_along_last_dim(test_tensor_last_dim, mock_group)
+
+        if world_size == 1:
+            assert torch.equal(result, test_tensor_last_dim)
+        else:
+            assert result.shape == (8, 16, 32*world_size)  # 8*4=32
+
+    @pytest.mark.parametrize("input_shape,expected_shape", [
+        ((32, 16), (8, 16)),
+        ((40, 10), (10, 10)),
+    ])
+    def test_reduce_scatter_along_first_dim(self, mock_group, input_shape, expected_shape):
+        input_tensor = torch.randn(*input_shape)
+        result = _reduce_scatter_along_first_dim(input_tensor, mock_group)
+        assert result.shape == expected_shape
+
+    def test_reduce_scatter_along_last_dim(self, mock_group):
+        input_tensor = torch.randn(8, 16, 32)
+        result = _reduce_scatter_along_last_dim(input_tensor, mock_group)
+        assert result.shape == (8, 16, 8)  # 32/4=8
+
+    @pytest.mark.parametrize("func,input_shape,expected_shape", [
+        ("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32), (8, 16, 128)),
+        ("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)),
+        ("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32), (8, 16, 8)),
+        ("gather_from_sequence_parallel_region", (8, 16), (32, 16)),
+    ])
+    def test_wrapper_functions(self, mock_group, func, input_shape, expected_shape):
+        """Test the wrapper functions."""
+        mod = importlib.import_module('vllm_ascend.distributed.tensor_parallel')
+        globals = mod.__dict__
+        test_func = globals[func]
+        input_tensor = torch.randn(*input_shape)
+        result = test_func(input_tensor, mock_group)
+        assert result.shape == expected_shape
+
+
+    @pytest.mark.parametrize("input_shape,output_shape", [
+        ((8, 16), (32, 4)),  # [num_tokens/TP, H] -> [num_tokens, H/TP]
+    ])
+    def test_all_to_all_sp2hp(self, mock_group, input_shape, output_shape):
+        input_tensor = torch.randn(*input_shape)
+        result = all_to_all_sp2hp(input_tensor, mock_group)
+        assert result.shape == output_shape
+
+
+    @pytest.mark.parametrize("input_shape,output_shape", [
+        ((32, 4), (8, 16)),  # [num_tokens, H/TP] -> [num_tokens/TP, H]
+    ])
+    def test_all_to_all_hp2sp(self, mock_group, input_shape, output_shape):
+        input_tensor = torch.randn(*input_shape)
+        result = all_to_all_hp2sp(input_tensor, mock_group)
+        assert result.shape == output_shape
\ No newline at end of file

From d5656f47e1bffe94c62c077caa1bacdc015cfed5 Mon Sep 17 00:00:00 2001
From: harygo22 
Date: Wed, 9 Jul 2025 15:34:09 +0800
Subject: [PATCH 32/60] license & fix dsk dbo.

* ut test

* license & fix dsk dbo.

Signed-off-by: weijinqian_v1
---
 vllm_ascend/distributed/tensor_parallel.py  |  1 +
 vllm_ascend/models/deepseek_dbo.py          | 89 +++++--------------
 vllm_ascend/ops/moe_dispatcher/moe_utils.py |  2 +-
 .../ops/moe_dispatcher/token_dispatcher.py  |  1 +
 4 files changed, 25 insertions(+), 68 deletions(-)

diff --git a/vllm_ascend/distributed/tensor_parallel.py b/vllm_ascend/distributed/tensor_parallel.py
index 70aa820094..a9e4324554 100644
--- a/vllm_ascend/distributed/tensor_parallel.py
+++ b/vllm_ascend/distributed/tensor_parallel.py
@@ -1,3 +1,4 @@
+# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved.
 # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
 # Copyright 2023 The vLLM team. 
# diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index c2093bb190..5051b23ab3 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -55,7 +55,6 @@ from vllm.sequence import IntermediateTensors import vllm_ascend.envs as envs_ascend -from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2DecoderLayer, @@ -72,8 +71,7 @@ make_multistream_metadata_ds) from vllm_ascend.quantization.w8a8_dynamic import ( AscendW8A8DynamicLinearMethod, apply_mlp) -from vllm_ascend.ops.fused_moe import AscendFusedMoE, apply_mlp, select_experts -from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod +from vllm_ascend.ops.fused_moe import apply_mlp, select_experts from vllm_ascend.utils import dispose_tensor VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO @@ -94,7 +92,8 @@ def __init__( intermediate_size=intermediate_size, hidden_act=hidden_act, quant_config=quant_config, - prefix=prefix) + prefix=prefix, + reduce_results=reduce_results) self.is_dynamic_quant = not isinstance( self.gate_up_proj.quant_method, UnquantizedLinearMethod) and isinstance( @@ -152,19 +151,6 @@ def __init__( prefix=f"{prefix}.shared_experts", ) CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok - - self.dp_size = get_dp_group().world_size - - self.tp_group = get_tp_group().device_group - self.tp_rank = get_tp_group().rank_in_group - self.kv_consumer = None - transfer_config = get_current_vllm_config().kv_transfer_config - if transfer_config is not None: - self.kv_consumer = transfer_config.kv_role = "kv_consumer" - self.params_dtype = torch.get_default_dtype() - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.config = config def forward( @@ -196,9 +182,13 @@ def forward( enable_force_load_balance=enable_force_load_balance, shared_experts=self.shared_experts) + shared_experts_hidden = experts_hidden_states[1] + if not (self.shared_experts.down_proj.reduce_results and self.shared_experts.down_proj.tp_size > 1): + shared_experts_hidden = tensor_model_parallel_all_reduce(shared_experts_hidden) + hidden_states = ( experts_hidden_states[0] * self.routed_scaling_factor + - experts_hidden_states[1]) + shared_experts_hidden) return hidden_states @@ -225,18 +215,10 @@ def _forward_op_gating( ) -> torch.Tensor: if attn_metadata is None: attn_metadata = get_forward_context().attn_metadata - # when profile runs, force experts to load balanced tokens - # to avoid high memory consumption on a single rank. - # TODO: need a better flag to indicate whether in profile run or not. - if attn_metadata is None: - # for profile run - self.is_prefill = True - self.enable_force_load_balance = True - else: - is_prefill = attn_metadata.num_prefills > 0 - self.enable_force_load_balance = False - if hasattr(attn_metadata, 'with_prefill_across_dp'): - self.is_prefill = is_prefill or attn_metadata.with_prefill_across_dp + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + # TODO: need a better flag to indicate whether in profile run or not. 
+ enable_force_load_balance = get_forward_context().in_profile_run num_tokens, hidden_dim = hidden_states.shape @@ -291,17 +273,11 @@ def _forward_op_gating( # this is a naive implementation for experts load balance so as # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. - if self.enable_force_load_balance: + if enable_force_load_balance: topk_ids = torch.randint_like(topk_ids, 0, self.config.n_routed_experts) return topk_weights, topk_ids, local_hidden_states, chunked_hidden_states_sizes - def _forward_dispatch_comm( - self, hidden_states, topk_weights, topk_ids, microbatch_id - ): - token_dispatcher = self.experts.token_dispatchers[microbatch_id] - _, hidden_states, tokens_per_expert = token_dispatcher.token_permutation(hidden_states, topk_weights, topk_ids) - return hidden_states, tokens_per_expert def _forward_op_shared_experts( self, hidden_states @@ -315,7 +291,7 @@ def _forward_op_grouped_mlp( self, dispatched_input, tokens_per_expert ): return apply_mlp( - [dispatched_input], + dispatched_input, self.experts.w13_weight, self.experts.w2_weight, tokens_per_expert @@ -325,8 +301,9 @@ def _forward_combine_comm( self, hidden_states, microbatch_id, num_tokens, chunked_hidden_states_sizes ): token_dispatcher = self.experts.token_dispatchers[microbatch_id] - token_dispatcher.combine_alltoall() - final_hidden_states = token_dispatcher.unpermute2() * self.routed_scaling_factor + final_hidden_states, _ = token_dispatcher.token_unpermutation(hidden_states) + if hasattr(self, 'routed_scaling_factor'): + final_hidden_states = final_hidden_states * self.routed_scaling_factor if self.tp_size > 1: final_hidden_states = gather_from_sequence_parallel_region(final_hidden_states, self.tp_group, @@ -794,17 +771,12 @@ def _forward_ms_layer_alltoallv_finegrained( chunked_hidden_states_sizes = [None] * num_micro_batchs token_dispatchers = self.mlp.experts.token_dispatchers - def print_with_sync(*args, **kwargs): - torch.npu.synchronize() - print(*args, **kwargs) - def discard_tensor(tensor): if isinstance(tensor, torch.Tensor): tensor = [tensor] for t in tensor: t.untyped_storage().resize_(0) - # print_with_sync('begin layer...', torch.distributed.get_rank()) # block 1 : attention # block 2 : Router Gating @@ -814,12 +786,11 @@ def discard_tensor(tensor): # can be overlapped with the attn communication of microbatch 1 for i in range(num_micro_batchs): # wait last layer moe finishing communication - ms_metadata.try_wait_event(layer_index - 1, i, - MSEventKey.MOE_AFTER_COMM) forward_context = get_forward_context() layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( ) + ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH) forward_context.attn_metadata = attn_metadata[i] # input layernorm @@ -856,9 +827,10 @@ def discard_tensor(tensor): with torch.npu.stream(dispatch_context.comm_stream): dispatch_context.comm_stream.wait_event(dispatch_context.before_comm_event) token_dispatchers[i].dispatch_alltoall() + dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2() dispatch_context.after_comm_event.record() - if self.mlp.n_shared_experts: + if self.mlp.n_shared_experts and self.tp_size > 1: token_dispatchers[i].cached_shared_expert_output = tensor_model_parallel_all_reduce( token_dispatchers[i].cached_shared_expert_output ) @@ -872,20 +844,16 @@ def discard_tensor(tensor): ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_AFTER_COMM) discard_tensor(hidden_states[i]) - 
dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2() router_expert_output[i] = self.mlp._forward_op_grouped_mlp(dispatched_input[i], tokens_per_expert[i]) discard_tensor(dispatched_input[i]) - token_dispatchers[i].unpermute1(router_expert_output[i]) - if router_expert_output[i].shape[0] > 0 and token_dispatchers[i].num_local_experts > 1: - discard_tensor(router_expert_output[i]) # Launch Combine Comm in a New Stream. combine_context = MultiStreamStepMetadata( comm_stream=ms_metadata.communicate_stream, before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_BEFORE_COMM], + MSEventKey.FFN_COM_FINISH], after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_AFTER_COMM], + MSEventKey.FFN_AR_FINISH], ) combine_context.before_comm_event.record() ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH) @@ -1032,7 +1000,6 @@ def forward( if VLLM_ASCEND_ENABLE_DBO and not graph_enable and self.can_run_ms() else self.end_layer - self.start_layer) - moe_start_layer = self.start_layer + num_normal_layers for i in range(self.start_layer, min(moe_start_layer, self.end_layer)): layer = self.layers[i] @@ -1068,16 +1035,6 @@ def can_run_ms(self): return False return True - def all_can_run_ms(self): - can_run_ms_local = self.can_run_ms() - ep_group = get_ep_group().cpu_group - flag = torch.ones(1, dtype=torch.int) if can_run_ms_local else torch.zeros(1, dtype=torch.int) - torch.distributed.all_reduce(flag, group=ep_group) - if flag.item() == torch.distributed.get_world_size(ep_group): - return True - else: - return False - def _forward_ms_layers(self, positions: torch.Tensor, hidden_states: torch.Tensor, @@ -1098,9 +1055,7 @@ def _forward_ms_layers(self, layer = self.layers[i] ms_layer_forward_func = layer._forward_ms_layer if fused_moe_state == FusedMoEState.All2AllSeq: - # ms_layer_forward_func = layer._forward_ms_layer_alltoallv ms_layer_forward_func = layer._forward_ms_layer_alltoallv_finegrained - # print("get_called......") hidden_states, residual = ms_layer_forward_func( positions=positions, hidden_states=hidden_states, diff --git a/vllm_ascend/ops/moe_dispatcher/moe_utils.py b/vllm_ascend/ops/moe_dispatcher/moe_utils.py index 6cffe4ac5f..dc19f75b33 100644 --- a/vllm_ascend/ops/moe_dispatcher/moe_utils.py +++ b/vllm_ascend/ops/moe_dispatcher/moe_utils.py @@ -1,4 +1,4 @@ -# +# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index ae73a759e7..6906577778 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. # Copyright 2023 The vLLM team. # Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. 
From adf3f740f93f92cf2442f36bfc0bad5a7a8a9acc Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 9 Jul 2025 17:57:03 +0800 Subject: [PATCH 33/60] handle code clean Signed-off-by: weijinqian_v1 --- tests/multicard/test_qwen3_moe.py | 10 +- tests/singlecard/test_offline_inference.py | 1 - tests/ut/test_distributed_tensor_parallel.py | 56 ++-- tests/ut/test_moe_util.py | 116 ++++----- tests/ut/test_token_dispatcher.py | 13 +- vllm_ascend/ascend_forward_context.py | 3 +- vllm_ascend/attention/attention_v1.py | 4 +- vllm_ascend/envs.py | 2 +- vllm_ascend/models/__init__.py | 2 +- vllm_ascend/models/deepseek_dbo.py | 123 ++++----- vllm_ascend/models/moe_block.py | 10 +- vllm_ascend/models/qwen3_dbo.py | 241 ++++++++++-------- vllm_ascend/multistream/ms_split.py | 186 +++++++------- vllm_ascend/ops/fused_moe.py | 74 +++--- .../ops/moe_dispatcher/token_dispatcher.py | 176 ++++++------- 15 files changed, 517 insertions(+), 500 deletions(-) diff --git a/tests/multicard/test_qwen3_moe.py b/tests/multicard/test_qwen3_moe.py index 391cc48424..122f8024fb 100644 --- a/tests/multicard/test_qwen3_moe.py +++ b/tests/multicard/test_qwen3_moe.py @@ -1,4 +1,3 @@ - # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. # Copyright 2023 The vLLM team. # @@ -32,7 +31,12 @@ @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("max_tokens", [32]) -@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3", "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1", "VLLM_ASCEND_ENABLE_DBO": "1"}) +@patch.dict( + os.environ, { + "ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3", + "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1", + "VLLM_ASCEND_ENABLE_DBO": "1" + }) def test_qwen3_moe_inference(model, max_tokens): script = "examples/offline_data_parallel.py" @@ -68,4 +72,4 @@ def test_qwen3_moe_inference(model, max_tokens): assert "DP rank 0 needs to process" in output assert "DP rank 1 needs to process" in output assert "Generated text:" in output - assert proc.returncode == 0 \ No newline at end of file + assert proc.returncode == 0 diff --git a/tests/singlecard/test_offline_inference.py b/tests/singlecard/test_offline_inference.py index 09f29f5c3a..cd65a24969 100644 --- a/tests/singlecard/test_offline_inference.py +++ b/tests/singlecard/test_offline_inference.py @@ -131,4 +131,3 @@ def test_models_topk() -> None: enforce_eager=True, gpu_memory_utilization=0.7) as vllm_model: vllm_model.generate(example_prompts, sampling_params) - diff --git a/tests/ut/test_distributed_tensor_parallel.py b/tests/ut/test_distributed_tensor_parallel.py index 6b05ab436f..ff4b8cde64 100644 --- a/tests/ut/test_distributed_tensor_parallel.py +++ b/tests/ut/test_distributed_tensor_parallel.py @@ -1,3 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. 
import pytest import torch import importlib @@ -5,8 +8,8 @@ from vllm_ascend.distributed.tensor_parallel import ( _gather_along_first_dim, _gather_along_last_dim, _reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim, - all_to_all_sp2hp, all_to_all_hp2sp -) + all_to_all_sp2hp, all_to_all_hp2sp) + # 测试用的固定数据 @pytest.fixture @@ -37,7 +40,8 @@ class TestDistributedCommunication: """测试分布式通信函数""" @pytest.mark.parametrize("world_size", [1, 4]) - def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist, world_size): + def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist, + world_size): """测试_gather_along_first_dim""" mock_dist.get_world_size.return_value = world_size @@ -48,14 +52,17 @@ def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist, world_ else: assert result.shape == (32, 16) # 8*4=32 - def test_gather_along_first_dim_unequal_split(self, test_tensor, mock_group): + def test_gather_along_first_dim_unequal_split(self, test_tensor, + mock_group): """测试不等分分割情况""" output_split_sizes = [5, 10, 15, 2] - result = _gather_along_first_dim(test_tensor, mock_group, output_split_sizes) + result = _gather_along_first_dim(test_tensor, mock_group, + output_split_sizes) assert result.shape == (32, 16) # 5+10+15+2=32 @pytest.mark.parametrize("world_size", [1, 4]) - def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group, mock_dist, world_size): + def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group, + mock_dist, world_size): """测试_gather_along_last_dim""" mock_dist.get_world_size.return_value = world_size @@ -64,13 +71,14 @@ def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group, mock_dist if world_size == 1: assert torch.equal(result, test_tensor_last_dim) else: - assert result.shape == (8, 16, 32*world_size) # 8*4=32 + assert result.shape == (8, 16, 32 * world_size) # 8*4=32 @pytest.mark.parametrize("input_shape,expected_shape", [ ((32, 16), (8, 16)), ((40, 10), (10, 10)), ]) - def test_reduce_scatter_along_first_dim(self, mock_group, input_shape, expected_shape): + def test_reduce_scatter_along_first_dim(self, mock_group, input_shape, + expected_shape): input_tensor = torch.randn(*input_shape) result = _reduce_scatter_along_first_dim(input_tensor, mock_group) assert result.shape == expected_shape @@ -81,34 +89,40 @@ def test_reduce_scatter_along_last_dim(self, mock_group): assert result.shape == (8, 16, 8) # 32/4=8 @pytest.mark.parametrize("func,input_shape,expected_shape", [ - ("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32), (8, 16, 128)), + ("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32), + (8, 16, 128)), ("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)), - ("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32), (8, 16, 8)), + ("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32), + (8, 16, 8)), ("gather_from_sequence_parallel_region", (8, 16), (32, 16)), ]) - def test_wrapper_functions(self, mock_group, func, input_shape, expected_shape): + def test_wrapper_functions(self, mock_group, func, input_shape, + expected_shape): """测试包装函数""" - mod = importlib.import_module('vllm_ascend.distributed.tensor_parallel') + mod = importlib.import_module( + 'vllm_ascend.distributed.tensor_parallel') globals = mod.__dict__ test_func = globals[func] input_tensor = torch.randn(*input_shape) result = test_func(input_tensor, mock_group) assert result.shape == expected_shape - - @pytest.mark.parametrize("input_shape,output_shape", [ - ((8, 
16), (32, 4)), # [num_tokens/TP, H] -> [num_tokens, H/TP] - ]) + @pytest.mark.parametrize( + "input_shape,output_shape", + [ + ((8, 16), (32, 4)), # [num_tokens/TP, H] -> [num_tokens, H/TP] + ]) def test_all_to_all_sp2hp(self, mock_group, input_shape, output_shape): input_tensor = torch.randn(*input_shape) result = all_to_all_sp2hp(input_tensor, mock_group) assert result.shape == output_shape - - @pytest.mark.parametrize("input_shape,output_shape", [ - ((32, 4), (8, 16)), # [num_tokens, H/TP] -> [num_tokens/TP, H] - ]) + @pytest.mark.parametrize( + "input_shape,output_shape", + [ + ((32, 4), (8, 16)), # [num_tokens, H/TP] -> [num_tokens/TP, H] + ]) def test_all_to_all_hp2sp(self, mock_group, input_shape, output_shape): input_tensor = torch.randn(*input_shape) result = all_to_all_hp2sp(input_tensor, mock_group) - assert result.shape == output_shape \ No newline at end of file + assert result.shape == output_shape diff --git a/tests/ut/test_moe_util.py b/tests/ut/test_moe_util.py index 9da4fb16b9..c88d2071ec 100644 --- a/tests/ut/test_moe_util.py +++ b/tests/ut/test_moe_util.py @@ -4,7 +4,7 @@ import torch import pytest import math -import vllm_ascend.patch.worker.patch_common.patch_utils +import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa from vllm_ascend.ops.moe_dispatcher.moe_utils import permute, get_capacity, topk_softmax_with_capacity, group_limited_topk, unpermute, sort_chunks_by_idxs @@ -22,24 +22,20 @@ def setup(self): self.num_groups = 2 self.scaling_factor = 1.0 - def test_group_limited_topk(self, setup): # Test group-limited topk routing scores = torch.randn(self.num_tokens, self.num_experts) - probs, indices = group_limited_topk( - scores, - topk=self.topk, - num_tokens=self.num_tokens, - num_experts=self.num_experts, - num_groups=self.num_groups, - group_topk=self.group_topk - ) + probs, indices = group_limited_topk(scores, + topk=self.topk, + num_tokens=self.num_tokens, + num_experts=self.num_experts, + num_groups=self.num_groups, + group_topk=self.group_topk) assert probs.shape == (self.num_tokens, self.topk) assert indices.shape == (self.num_tokens, self.topk) assert torch.all(indices < self.num_experts) - @pytest.mark.parametrize("score_function", ["softmax"]) def test_topk_softmax_with_capacity(self, setup, score_function): # Test topk softmax with capacity @@ -47,13 +43,10 @@ def test_topk_softmax_with_capacity(self, setup, score_function): # Test without capacity probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity( - logits, - topk=self.topk, - score_function=score_function - ) + logits, topk=self.topk, score_function=score_function) assert probs.shape == (self.num_tokens, self.num_experts) assert routing_map.shape == (self.num_tokens, self.num_experts) - assert tokens_per_expert.shape == (self.num_experts,) + assert tokens_per_expert.shape == (self.num_experts, ) # Test with group routing probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity( @@ -61,36 +54,31 @@ def test_topk_softmax_with_capacity(self, setup, score_function): topk=self.topk, num_groups=self.num_groups, group_topk=self.group_topk, - score_function=score_function - ) + score_function=score_function) assert probs.shape == (self.num_tokens, self.num_experts) - def test_get_capacity(self, setup): # Test capacity calculation - capacity = get_capacity( - num_tokens=self.num_tokens, - num_experts=self.num_experts, - capacity_factor=self.capacity_factor - ) - expected = math.ceil((self.num_tokens / 
self.num_experts) * self.capacity_factor) + capacity = get_capacity(num_tokens=self.num_tokens, + num_experts=self.num_experts, + capacity_factor=self.capacity_factor) + expected = math.ceil( + (self.num_tokens / self.num_experts) * self.capacity_factor) assert capacity == expected # Test with min capacity min_capacity = 5 - capacity = get_capacity( - num_tokens=self.num_tokens, - num_experts=self.num_experts, - capacity_factor=self.capacity_factor, - min_capacity=min_capacity - ) + capacity = get_capacity(num_tokens=self.num_tokens, + num_experts=self.num_experts, + capacity_factor=self.capacity_factor, + min_capacity=min_capacity) assert capacity == min_capacity - def test_permute(self, setup): # Test token permutation tokens = torch.randn(self.num_tokens, self.hidden_size) - routing_map = torch.randint(0, 2, (self.num_tokens, self.num_experts)).bool() + routing_map = torch.randint( + 0, 2, (self.num_tokens, self.num_experts)).bool() # Basic permutation permuted_tokens, sorted_indices = permute(tokens, routing_map) @@ -98,65 +86,54 @@ def test_permute(self, setup): assert sorted_indices.shape[0] == routing_map.sum() # With drop and pad - capacity = get_capacity( - num_tokens=self.num_tokens * self.topk, - num_experts=self.num_experts, - capacity_factor=self.capacity_factor - ) + capacity = get_capacity(num_tokens=self.num_tokens * self.topk, + num_experts=self.num_experts, + capacity_factor=self.capacity_factor) num_out_tokens = capacity * self.num_experts permuted_tokens, sorted_indices = permute( tokens, routing_map, num_out_tokens=num_out_tokens, - drop_and_pad=True - ) + drop_and_pad=True) assert permuted_tokens.shape[0] == num_out_tokens assert sorted_indices.shape[0] == num_out_tokens - def test_unpermute(self, setup): # Test token unpermutation tokens = torch.randn(self.num_tokens, self.hidden_size) - routing_map = torch.randint(0, 2, (self.num_tokens, self.num_experts)).bool() + routing_map = torch.randint( + 0, 2, (self.num_tokens, self.num_experts)).bool() probs = torch.rand(self.num_tokens, self.num_experts) # First permute permuted_tokens, sorted_indices = permute(tokens, routing_map) # Then unpermute - restored_tokens = unpermute( - permuted_tokens, - sorted_indices, - tokens.shape, - probs=probs, - routing_map=routing_map - ) + restored_tokens = unpermute(permuted_tokens, + sorted_indices, + tokens.shape, + probs=probs, + routing_map=routing_map) assert restored_tokens.shape == tokens.shape # With drop and pad - capacity = get_capacity( - num_tokens=self.num_tokens * self.topk, - num_experts=self.num_experts, - capacity_factor=self.capacity_factor - ) + capacity = get_capacity(num_tokens=self.num_tokens * self.topk, + num_experts=self.num_experts, + capacity_factor=self.capacity_factor) num_out_tokens = capacity * self.num_experts permuted_tokens, sorted_indices = permute( tokens, routing_map, num_out_tokens=num_out_tokens, - drop_and_pad=True - ) - restored_tokens = unpermute( - permuted_tokens, - sorted_indices, - tokens.shape, - probs=probs, - routing_map=routing_map, - drop_and_pad=True - ) + drop_and_pad=True) + restored_tokens = unpermute(permuted_tokens, + sorted_indices, + tokens.shape, + probs=probs, + routing_map=routing_map, + drop_and_pad=True) assert restored_tokens.shape == tokens.shape - def test_sort_chunks_by_idxs(self, setup): # Test chunk sorting input_tensor = torch.randn(10, self.hidden_size) @@ -167,10 +144,10 @@ def test_sort_chunks_by_idxs(self, setup): assert output.shape == input_tensor.shape # Verify the order is correct - expected = 
torch.cat([input_tensor[5:], input_tensor[0: 3], input_tensor[3: 5]]) + expected = torch.cat( + [input_tensor[5:], input_tensor[0:3], input_tensor[3:5]]) assert torch.allclose(output, expected) - @pytest.mark.parametrize("score_function", ["softmax"]) def test_score_functions(self, setup, score_function): # Test different score functions @@ -181,8 +158,7 @@ def test_score_functions(self, setup, score_function): logits, topk=self.topk, score_function=score_function, - expert_bias=expert_bias - ) + expert_bias=expert_bias) assert probs.shape == (self.num_tokens, self.num_experts) assert routing_map.shape == (self.num_tokens, self.num_experts) - assert tokens_per_expert.shape == (self.num_experts,) \ No newline at end of file + assert tokens_per_expert.shape == (self.num_experts, ) diff --git a/tests/ut/test_token_dispatcher.py b/tests/ut/test_token_dispatcher.py index b389eb430f..a5d313cf12 100644 --- a/tests/ut/test_token_dispatcher.py +++ b/tests/ut/test_token_dispatcher.py @@ -4,14 +4,16 @@ import torch import pytest +import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa + from pytest_mock import MockerFixture -import vllm_ascend.patch.worker.patch_common.patch_utils -from vllm_ascend.utils import adapt_patch # noqa E402 +from vllm_ascend.utils import adapt_patch # noqa E402 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import MoeDispatcherConfig, MoEAlltoAllSeqOverLapDispatcher adapt_patch(True) + class TestMoEAlltoAllSeqOverLapDispatcher: @pytest.fixture @@ -37,8 +39,9 @@ def mock_ep_group(self, mocker): @pytest.fixture def dispatcher(self, config, mocker: MockerFixture): - mocker.patch("vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group", - return_value=self.mock_ep_group(mocker)) + mocker.patch( + "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group", + return_value=self.mock_ep_group(mocker)) return MoEAlltoAllSeqOverLapDispatcher(config) def test_initialization(self, dispatcher, config): @@ -53,4 +56,4 @@ def test_routing(self, dispatcher): probs = torch.randn(4, 4) # 4 tokens, 4 experts scores, routing_map = dispatcher.routing(probs) assert scores.shape == (4, 4) # topk=2 - assert routing_map.shape == (4, 4) \ No newline at end of file + assert routing_map.shape == (4, 4) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 0ed0cbc29c..965e7b7d2c 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -23,7 +23,8 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool): return FusedMoEState.AllGather elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage. - return FusedMoEState.All2AllSeq if (ep_size < 16 or with_prefill) else FusedMoEState.MC2 + return FusedMoEState.All2AllSeq if ( + ep_size < 16 or with_prefill) else FusedMoEState.MC2 # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. 
elif ep_size < 16 or with_prefill: return FusedMoEState.All2All diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index ef3848987a..f8cc4a5a54 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -142,8 +142,8 @@ class AscendMetadata: enable_dbo_across_dp: bool = False def split_metadata_for_multistream( - self, - ms_split_config: MSAttentionMetadataSplitConfig, + self, + ms_split_config: MSAttentionMetadataSplitConfig, ) -> list["AscendMetadata"]: """Split metadata for multi-stream with AscendMetadata""" from vllm_ascend.multistream.ms_split import model_input_split_v1_attn diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 8af5bdd783..06a45c9598 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -110,7 +110,7 @@ # 0: default, normal init. # 1: enable moe_all2all_buffer. "VLLM_ASCEND_MOE_ALL2ALL_BUFFER": - lambda: bool(int(os.getenv("VLLM_ASCEND_MOE_ALL2ALL_BUFFER", '0'))), + lambda: bool(int(os.getenv("VLLM_ASCEND_MOE_ALL2ALL_BUFFER", '0'))), # Some models are optimized by vllm ascend. While in some case, e.g. rlhf # training, the optimized model may not be suitable. In this case, set this # value to False to disable the optimized model. diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 276c604f4f..f3260bcc9e 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -11,7 +11,7 @@ def register_model(): from .qwen2_5_vl import \ AscendQwen2_5_VLForConditionalGeneration # noqa: F401 from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401 - from .moe_block import AscendSparseMoeBlock + from .moe_block import AscendSparseMoeBlock # noqa: F401 ModelRegistry.register_model( "DeepSeekMTPModel", diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 5051b23ab3..6f02dc2a95 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -34,7 +34,7 @@ from transformers import PretrainedConfig from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_ep_group, get_pp_group, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce) @@ -71,7 +71,7 @@ make_multistream_metadata_ds) from vllm_ascend.quantization.w8a8_dynamic import ( AscendW8A8DynamicLinearMethod, apply_mlp) -from vllm_ascend.ops.fused_moe import apply_mlp, select_experts +from vllm_ascend.ops.fused_moe import select_experts from vllm_ascend.utils import dispose_tensor VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO @@ -183,8 +183,10 @@ def forward( shared_experts=self.shared_experts) shared_experts_hidden = experts_hidden_states[1] - if not (self.shared_experts.down_proj.reduce_results and self.shared_experts.down_proj.tp_size > 1): - shared_experts_hidden = tensor_model_parallel_all_reduce(shared_experts_hidden) + if not (self.shared_experts.down_proj.reduce_results + and self.shared_experts.down_proj.tp_size > 1): + shared_experts_hidden = tensor_model_parallel_all_reduce( + shared_experts_hidden) hidden_states = ( experts_hidden_states[0] * self.routed_scaling_factor + @@ -211,8 +213,7 @@ def _forward_ms_op_gate( def _forward_op_gating( self, hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None - ) -> torch.Tensor: + attn_metadata: Optional[AttentionMetadata] = 
None) -> torch.Tensor: if attn_metadata is None: attn_metadata = get_forward_context().attn_metadata # when profile runs, force experts to load balanced tokens @@ -278,42 +279,36 @@ def _forward_op_gating( return topk_weights, topk_ids, local_hidden_states, chunked_hidden_states_sizes - - def _forward_op_shared_experts( - self, hidden_states - ): + def _forward_op_shared_experts(self, hidden_states): if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) return shared_output - def _forward_op_grouped_mlp( - self, dispatched_input, tokens_per_expert - ): - return apply_mlp( - dispatched_input, - self.experts.w13_weight, - self.experts.w2_weight, - tokens_per_expert - ) + def _forward_op_grouped_mlp(self, dispatched_input, tokens_per_expert): + from vllm_ascend.ops.fused_moe import apply_mlp + return apply_mlp(dispatched_input, self.experts.w13_weight, + self.experts.w2_weight, tokens_per_expert) - def _forward_combine_comm( - self, hidden_states, microbatch_id, num_tokens, chunked_hidden_states_sizes - ): + def _forward_combine_comm(self, hidden_states, microbatch_id, num_tokens, + chunked_hidden_states_sizes): token_dispatcher = self.experts.token_dispatchers[microbatch_id] - final_hidden_states, _ = token_dispatcher.token_unpermutation(hidden_states) + final_hidden_states, _ = token_dispatcher.token_unpermutation( + hidden_states) if hasattr(self, 'routed_scaling_factor'): final_hidden_states = final_hidden_states * self.routed_scaling_factor if self.tp_size > 1: - final_hidden_states = gather_from_sequence_parallel_region(final_hidden_states, self.tp_group, - chunked_hidden_states_sizes) + final_hidden_states = gather_from_sequence_parallel_region( + final_hidden_states, self.tp_group, + chunked_hidden_states_sizes) if num_tokens < self.tp_size: final_hidden_states = final_hidden_states[:num_tokens] if self.shared_experts is not None: final_hidden_states = final_hidden_states + token_dispatcher.cached_shared_expert_output - token_dispatcher.cached_shared_expert_output.untyped_storage().resize_(0) + token_dispatcher.cached_shared_expert_output.untyped_storage( + ).resize_(0) token_dispatcher.cached_shared_expert_output = None final_hidden_states = final_hidden_states.view(num_tokens, -1) @@ -744,13 +739,13 @@ def _forward_ms_layer( # ----------------------------------------- TBO-related -------------------------------------------- def _forward_ms_layer_alltoallv_finegrained( - self, - positions: List[torch.Tensor], - hidden_states: List[torch.Tensor], - residual: List[torch.Tensor], - attn_metadata: List[AttentionMetadata], - kv_cache: Optional[torch.Tensor] = None, - is_prefill: bool = False, + self, + positions: List[torch.Tensor], + hidden_states: List[torch.Tensor], + residual: List[torch.Tensor], + attn_metadata: List[AttentionMetadata], + kv_cache: Optional[torch.Tensor] = None, + is_prefill: bool = False, ) -> tuple[List[torch.Tensor], List[torch.Tensor]]: layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( ) @@ -763,10 +758,11 @@ def _forward_ms_layer_alltoallv_finegrained( assert attn_metadata is not None num_tokens = [None] * num_micro_batchs hidden_dims = [None] * num_micro_batchs - topk_weights, topk_ids = [None] * num_micro_batchs, [None] * num_micro_batchs + topk_weights, topk_ids = [None] * num_micro_batchs, [ + None + ] * num_micro_batchs tokens_per_expert = [None] * num_micro_batchs dispatched_input = [None] * num_micro_batchs - shared_expert_output = [None] * num_micro_batchs router_expert_output = [None] * 
num_micro_batchs chunked_hidden_states_sizes = [None] * num_micro_batchs token_dispatchers = self.mlp.experts.token_dispatchers @@ -777,7 +773,6 @@ def discard_tensor(tensor): for t in tensor: t.untyped_storage().resize_(0) - # block 1 : attention # block 2 : Router Gating # block 3 : Token DisPatch @@ -790,30 +785,35 @@ def discard_tensor(tensor): forward_context = get_forward_context() layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( ) - ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH) + ms_metadata.try_wait_event(layer_index - 1, i, + MSEventKey.FFN_AR_FINISH) forward_context.attn_metadata = attn_metadata[i] # input layernorm hidden_states[i], residual[ i] = self._forward_ms_op_input_layernorm( - hidden_states[i], residual[i]) + hidden_states[i], residual[i]) # attention and tp allreduce hidden_states[i], residual[i] = self._forward_ms_op_attn( positions[i], hidden_states[i], residual[i], kv_cache, attn_metadata[i]) # post attention layer norm - hidden_states[i], residual[i] = self._forward_ms_op_post_attn_layernorm( - hidden_states[i], residual[i] - ) + hidden_states[i], residual[ + i] = self._forward_ms_op_post_attn_layernorm( + hidden_states[i], residual[i]) num_tokens[i], hidden_dims[i] = hidden_states[i].shape # If TP is enabled, hidden_states will be chunked. - topk_weights[i], topk_ids[i], dispatched_input[i], chunked_hidden_states_sizes[ - i] = self.mlp._forward_op_gating(hidden_states[i], attn_metadata[i]) + topk_weights[i], topk_ids[i], dispatched_input[ + i], chunked_hidden_states_sizes[ + i] = self.mlp._forward_op_gating(hidden_states[i], + attn_metadata[i]) token_dispatchers[i].preprocess_and_permtute1( - dispatched_input[i], topk_weights[i], topk_ids[i], + dispatched_input[i], + topk_weights[i], + topk_ids[i], self.mlp.shared_experts, - shared_experts_input=hidden_states[i] if self.mlp.n_shared_experts else None - ) + shared_experts_input=hidden_states[i] + if self.mlp.n_shared_experts else None) # Launch DisPatch Comm in a New Stream. 
dispatch_context = MultiStreamStepMetadata( comm_stream=ms_metadata.communicate_stream, @@ -825,26 +825,31 @@ def discard_tensor(tensor): dispatch_context.before_comm_event.record() # print_with_sync(f'begin token dispatch{i}...', torch.distributed.get_rank()) with torch.npu.stream(dispatch_context.comm_stream): - dispatch_context.comm_stream.wait_event(dispatch_context.before_comm_event) + dispatch_context.comm_stream.wait_event( + dispatch_context.before_comm_event) token_dispatchers[i].dispatch_alltoall() - dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2() + dispatched_input[i], tokens_per_expert[i] = token_dispatchers[ + i].permute2() dispatch_context.after_comm_event.record() if self.mlp.n_shared_experts and self.tp_size > 1: - token_dispatchers[i].cached_shared_expert_output = tensor_model_parallel_all_reduce( - token_dispatchers[i].cached_shared_expert_output - ) - ms_metadata.ms_events[layer_index][i][MSEventKey.MOE_SE_COMM_FINISH].record() + token_dispatchers[ + i].cached_shared_expert_output = tensor_model_parallel_all_reduce( + token_dispatchers[i].cached_shared_expert_output) + ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_SE_COMM_FINISH].record() # print_with_sync('begin experts...', torch.distributed.get_rank()) # block 4 : Router Experts Computation # block 5 : Token Combine Communication for i in range(num_micro_batchs): - ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_AFTER_COMM) + ms_metadata.try_wait_event(layer_index, i, + MSEventKey.MOE_AFTER_COMM) discard_tensor(hidden_states[i]) - router_expert_output[i] = self.mlp._forward_op_grouped_mlp(dispatched_input[i], tokens_per_expert[i]) + router_expert_output[i] = self.mlp._forward_op_grouped_mlp( + dispatched_input[i], tokens_per_expert[i]) discard_tensor(dispatched_input[i]) # Launch Combine Comm in a New Stream. 
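The hunks above interleave two micro-batches by pushing the dispatch/combine all-to-all onto a separate communication stream and fencing it with events, so expert computation of one micro-batch overlaps the communication of the other. Below is a minimal sketch of that event-guarded two-stream pattern, assuming torch_npu's CUDA-style Stream/Event API; the compute bodies and tensor names are placeholders, not the real attention/dispatch kernels.

import torch
import torch_npu  # noqa: F401  # registers the torch.npu device module

def overlapped_step(x: torch.Tensor, comm_stream) -> torch.Tensor:
    # comm_stream is assumed to be a torch.npu.Stream created once and reused.
    before_comm = torch.npu.Event()
    after_comm = torch.npu.Event()

    y = x * 2                      # placeholder for attention/gating compute
    before_comm.record()           # mark: inputs of the comm phase are ready

    with torch.npu.stream(comm_stream):
        comm_stream.wait_event(before_comm)  # comm waits only on what it consumes
        y = y + 1                  # placeholder for dispatch_alltoall()/permute2()
        after_comm.record()

    # downstream compute waits on the comm event instead of a full device sync,
    # leaving the default stream free to run the other micro-batch meanwhile
    torch.npu.current_stream().wait_event(after_comm)
    return y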
@@ -856,12 +861,14 @@ def discard_tensor(tensor): MSEventKey.FFN_AR_FINISH], ) combine_context.before_comm_event.record() - ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH) + ms_metadata.try_wait_event(layer_index, i, + MSEventKey.MOE_SE_COMM_FINISH) with torch.npu.stream(combine_context.comm_stream): - combine_context.comm_stream.wait_event(combine_context.before_comm_event) + combine_context.comm_stream.wait_event( + combine_context.before_comm_event) hidden_states[i] = self.mlp._forward_combine_comm( - router_expert_output[i], i, num_tokens[i], chunked_hidden_states_sizes[i] - ) + router_expert_output[i], i, num_tokens[i], + chunked_hidden_states_sizes[i]) combine_context.after_comm_event.record() return hidden_states, residual diff --git a/vllm_ascend/models/moe_block.py b/vllm_ascend/models/moe_block.py index 097aa4833b..1129aa6716 100644 --- a/vllm_ascend/models/moe_block.py +++ b/vllm_ascend/models/moe_block.py @@ -41,10 +41,10 @@ class AscendSparseMoeBlock(nn.Module): top_k: int def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -84,7 +84,6 @@ def __init__( self.params_dtype = torch.get_default_dtype() - def forward( self, hidden_states: torch.Tensor, @@ -115,4 +114,5 @@ def forward( return hidden_states + qwen3.Qwen3MoeSparseMoeBlock = AscendSparseMoeBlock diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py index 2f2760f559..35877cb0ae 100644 --- a/vllm_ascend/models/qwen3_dbo.py +++ b/vllm_ascend/models/qwen3_dbo.py @@ -1,5 +1,27 @@ -from collections.abc import Iterable -from typing import Any, Optional, Union, List +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# # Adapted from +# """Inference-only Qwen3 model.""" +from typing import Optional, Union, List from types import SimpleNamespace import torch @@ -12,10 +34,11 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.attention import AttentionMetadata from vllm.forward_context import get_forward_context, set_forward_context -from vllm.distributed import tensor_model_parallel_all_reduce, get_tensor_model_parallel_world_size, get_tp_group, \ +from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group, \ get_pp_group from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.models.utils import (make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.model_executor.models.utils import ( + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.sequence import IntermediateTensors from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM @@ -24,43 +47,38 @@ from vllm.compilation.decorators import support_torch_compile from vllm_ascend.multistream.context import ( - advance_step_multistream_layer_context, get_multistream_comm_context, - get_multistream_layer_context, set_multistream_context) + advance_step_multistream_layer_context, get_multistream_layer_context) from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, MultiStreamPreTransformerLayer) from vllm_ascend.multistream.metadata import (MultiStreamConfig, MultiStreamStepMetadata, make_multistream_metadata_ds) -from vllm_ascend.ops.fused_moe import AscendFusedMoE, select_experts, apply_mlp +from vllm_ascend.ops.fused_moe import select_experts, apply_mlp from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region import vllm_ascend.envs as envs_ascend -from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO class Qwen3MoeDecoderLayerDBO(Qwen3MoeDecoderLayer): + def __init__( - self, - config: PretrainedConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: - super(Qwen3MoeDecoderLayerDBO, self).__init__(config, cache_config, quant_config, prefix) + super(Qwen3MoeDecoderLayerDBO, self).__init__(config, cache_config, + quant_config, prefix) self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tp_group().rank_in_group self.tp_group = get_tp_group().device_group self.dummy_vllm_config = SimpleNamespace( - parallel_config=SimpleNamespace( - data_parallel_size=1, - ), - compilation_config=SimpleNamespace( - static_forward_context=None, - ), - other_setting="value" - ) + parallel_config=SimpleNamespace(data_parallel_size=1, ), + compilation_config=SimpleNamespace(static_forward_context=None, ), + other_setting="value") self.config = config def forward(self, *args, **kwargs): @@ -68,9 +86,9 @@ def forward(self, *args, **kwargs): # should split ops in Decoder Layer def _forward_ms_op_input_layernorm( - self, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: 
residual = hidden_states @@ -81,14 +99,15 @@ def _forward_ms_op_input_layernorm( return hidden_states, residual def _forward_ms_op_attn( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None, + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - self.dummy_vllm_config.compilation_config.static_forward_context = get_forward_context().no_compile_layers + self.dummy_vllm_config.compilation_config.static_forward_context = get_forward_context( + ).no_compile_layers with set_forward_context(attn_metadata, self.dummy_vllm_config): hidden_states = self.self_attn( positions=positions, @@ -106,9 +125,9 @@ def _forward_ms_op_attn( return hidden_states, residual def _forward_ms_op_post_attn_layernorm( - self, - hidden_states: torch.Tensor, - residual: Optional[torch.Tensor], + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], ): hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) @@ -117,8 +136,7 @@ def _forward_ms_op_post_attn_layernorm( def _forward_op_gating( self, hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None - ) -> torch.Tensor: + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: if attn_metadata is None: attn_metadata = get_forward_context().attn_metadata # when profile runs, force experts to load balanced tokens @@ -146,7 +164,9 @@ def _forward_op_gating( chunk_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0) - chunked_hidden_states_sizes = [x.shape[0] for x in chunk_hidden_states] + chunked_hidden_states_sizes = [ + x.shape[0] for x in chunk_hidden_states + ] local_hidden_states = chunk_hidden_states[self.tp_rank] else: local_hidden_states = hidden_states @@ -182,8 +202,9 @@ def _forward_op_gating( num_expert_group=getattr(mlp_config, "n_group", None), custom_routing_function=None, scoring_func=getattr(mlp_config, "scoring_func", 'softmax'), - e_score_correction_bias=getattr(self.mlp.gate, "e_score_correction_bias", None) - ) + e_score_correction_bias=getattr(self.mlp.gate, + "e_score_correction_bias", + None)) topk_weights = topk_weights.to(hidden_states.dtype) # this is a naive implementation for experts load balance so as @@ -194,33 +215,29 @@ def _forward_op_gating( return topk_weights, topk_ids, local_hidden_states, chunked_hidden_states_sizes - def _forward_op_grouped_mlp( - self, dispatched_input, tokens_per_expert - ): - return apply_mlp( - dispatched_input, - self.mlp.experts.w13_weight, - self.mlp.experts.w2_weight, - tokens_per_expert - ) + def _forward_op_grouped_mlp(self, dispatched_input, tokens_per_expert): + return apply_mlp(dispatched_input, self.mlp.experts.w13_weight, + self.mlp.experts.w2_weight, tokens_per_expert) - def _forward_combine_comm( - self, hidden_states, microbatch_id, num_tokens, chunked_hidden_states_sizes - ): + def _forward_combine_comm(self, hidden_states, microbatch_id, num_tokens, + chunked_hidden_states_sizes): token_dispatcher = self.mlp.experts.token_dispatchers[microbatch_id] - final_hidden_states, _ = token_dispatcher.token_unpermutation(hidden_states) + final_hidden_states, _ = token_dispatcher.token_unpermutation( + hidden_states) if hasattr(self.mlp, 'routed_scaling_factor'): final_hidden_states = final_hidden_states * 
self.mlp.routed_scaling_factor if self.tp_size > 1: - final_hidden_states = gather_from_sequence_parallel_region(final_hidden_states, self.tp_group, - chunked_hidden_states_sizes) + final_hidden_states = gather_from_sequence_parallel_region( + final_hidden_states, self.tp_group, + chunked_hidden_states_sizes) if num_tokens < self.tp_size: final_hidden_states = final_hidden_states[:num_tokens] if hasattr(self.mlp, "shared_experts"): final_hidden_states = final_hidden_states + token_dispatcher.cached_shared_expert_output - token_dispatcher.cached_shared_expert_output.untyped_storage().resize_(0) + token_dispatcher.cached_shared_expert_output.untyped_storage( + ).resize_(0) token_dispatcher.cached_shared_expert_output = None final_hidden_states = final_hidden_states.view(num_tokens, -1) @@ -228,12 +245,12 @@ def _forward_combine_comm( return final_hidden_states def _forward_ms_layer_alltoallv_finegrained( - self, - positions: List[torch.Tensor], - hidden_states: List[torch.Tensor], - residual: List[torch.Tensor], - attn_metadata: List[AttentionMetadata], - kv_cache: Optional[torch.Tensor] = None, + self, + positions: List[torch.Tensor], + hidden_states: List[torch.Tensor], + residual: List[torch.Tensor], + attn_metadata: List[AttentionMetadata], + kv_cache: Optional[torch.Tensor] = None, ): layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( ) @@ -245,7 +262,9 @@ def _forward_ms_layer_alltoallv_finegrained( assert attn_metadata is not None num_tokens = [None] * num_micro_batchs hidden_dims = [None] * num_micro_batchs - topk_weights, topk_ids = [None] * num_micro_batchs, [None] * num_micro_batchs + topk_weights, topk_ids = [None] * num_micro_batchs, [ + None + ] * num_micro_batchs tokens_per_expert = [None] * num_micro_batchs dispatched_input = [None] * num_micro_batchs shared_expert_output = [None] * num_micro_batchs @@ -270,29 +289,33 @@ def discard_tensor(tensor): forward_context = get_forward_context() layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( ) - ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH) + ms_metadata.try_wait_event(layer_index - 1, i, + MSEventKey.FFN_AR_FINISH) forward_context.attn_metadata = attn_metadata[i] # input layernorm hidden_states[i], residual[ i] = self._forward_ms_op_input_layernorm( - hidden_states[i], residual[i]) + hidden_states[i], residual[i]) # attention and tp allreduce hidden_states[i], residual[i] = self._forward_ms_op_attn( positions[i], hidden_states[i], residual[i], kv_cache, attn_metadata[i]) # post attention layer norm - hidden_states[i], residual[i] = self._forward_ms_op_post_attn_layernorm( - hidden_states[i], residual[i] - ) + hidden_states[i], residual[ + i] = self._forward_ms_op_post_attn_layernorm( + hidden_states[i], residual[i]) num_tokens[i], hidden_dims[i] = hidden_states[i].shape # If TP is enabled, hidden_states will be chunked. - topk_weights[i], topk_ids[i], dispatched_input[i], chunked_hidden_states_sizes[i] = self._forward_op_gating( - hidden_states[i], attn_metadata[i]) + topk_weights[i], topk_ids[i], dispatched_input[ + i], chunked_hidden_states_sizes[i] = self._forward_op_gating( + hidden_states[i], attn_metadata[i]) token_dispatchers[i].preprocess_and_permtute1( - dispatched_input[i], topk_weights[i], topk_ids[i], - shared_experts=None, shared_experts_input=None - ) + dispatched_input[i], + topk_weights[i], + topk_ids[i], + shared_experts=None, + shared_experts_input=None) # Launch DisPatch Comm in a New Stream. 
dispatch_context = MultiStreamStepMetadata( comm_stream=ms_metadata.communicate_stream, @@ -304,18 +327,22 @@ def discard_tensor(tensor): dispatch_context.before_comm_event.record() # print_with_sync(f'begin token dispatch{i}...', torch.distributed.get_rank()) with torch.npu.stream(dispatch_context.comm_stream): - dispatch_context.comm_stream.wait_event(dispatch_context.before_comm_event) + dispatch_context.comm_stream.wait_event( + dispatch_context.before_comm_event) token_dispatchers[i].dispatch_alltoall() - dispatched_input[i], tokens_per_expert[i] = token_dispatchers[i].permute2() + dispatched_input[i], tokens_per_expert[i] = token_dispatchers[ + i].permute2() dispatch_context.after_comm_event.record() # print_with_sync('begin experts...', torch.distributed.get_rank()) # block 4 : Router Experts Computation # block 5 : Token Combine Communication for i in range(num_micro_batchs): - ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_AFTER_COMM) + ms_metadata.try_wait_event(layer_index, i, + MSEventKey.MOE_AFTER_COMM) discard_tensor(hidden_states[i]) - router_expert_output[i] = self._forward_op_grouped_mlp(dispatched_input[i], tokens_per_expert[i]) + router_expert_output[i] = self._forward_op_grouped_mlp( + dispatched_input[i], tokens_per_expert[i]) discard_tensor(dispatched_input[i]) # Launch Combine Comm in a New Stream. @@ -327,19 +354,25 @@ def discard_tensor(tensor): MSEventKey.FFN_AR_FINISH], ) combine_context.before_comm_event.record() - ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH) + ms_metadata.try_wait_event(layer_index, i, + MSEventKey.MOE_SE_COMM_FINISH) with torch.npu.stream(combine_context.comm_stream): - combine_context.comm_stream.wait_event(combine_context.before_comm_event) + combine_context.comm_stream.wait_event( + combine_context.before_comm_event) hidden_states[i] = self._forward_combine_comm( - router_expert_output[i], i, num_tokens[i], chunked_hidden_states_sizes[i] - ) - ms_metadata.ms_events[layer_index][i][MSEventKey.FFN_AR_FINISH] = combine_context.comm_stream.record_event() + router_expert_output[i], i, num_tokens[i], + chunked_hidden_states_sizes[i]) + ms_metadata.ms_events[layer_index][i][ + MSEventKey. 
+ FFN_AR_FINISH] = combine_context.comm_stream.record_event( + ) return hidden_states, residual @support_torch_compile class CustomQwen3DBOMoEModel(Qwen3MoeModel): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) @@ -383,11 +416,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): multistream_metadata) def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: @@ -414,8 +447,7 @@ def forward( positions=positions, hidden_states=hidden_states, residual=residual, - moe_start_layer=moe_start_layer - ) + moe_start_layer=moe_start_layer) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -436,12 +468,12 @@ def can_run_ms(self): return True def _forward_ms_layers( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - residual: torch.Tensor, - moe_start_layer: int, - kv_caches: Optional[List[torch.Tensor]] = None, + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + moe_start_layer: int, + kv_caches: Optional[List[torch.Tensor]] = None, ): if moe_start_layer == self.end_layer: @@ -449,7 +481,7 @@ def _forward_ms_layers( attn_metadata, [positions, hidden_states, residual] = self.ms_pre_layer( - [positions, hidden_states, residual], ) + [positions, hidden_states, residual], ) num_micro_batch = len(attn_metadata) # the rest layers for i in range(moe_start_layer, self.end_layer): @@ -464,10 +496,11 @@ def _forward_ms_layers( ) advance_step_multistream_layer_context() - layer_index, ms_metadata, attn_metadata = get_multistream_layer_context() + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) for i in range(num_micro_batch): - ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH) - + ms_metadata.try_wait_event(layer_index - 1, i, + MSEventKey.FFN_AR_FINISH) [hidden_states, residual] = self.ms_post_layer([hidden_states, residual], ) @@ -486,7 +519,7 @@ class CustomQwen3MoeForCausalLMDBO(Qwen3MoeForCausalLM): "up_proj", ], "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -496,7 +529,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.quant_config = quant_config self.model = CustomQwen3DBOMoEModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + prefix=maybe_prefix( + prefix, "model")) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) @@ -505,11 +539,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) - + def forward(self, *args, **kwargs): if "graph_enable" in kwargs: kwargs.pop('graph_enable') return super().forward(*args, **kwargs) - - - diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 69d078a47d..1c765d9bcd 100644 --- 
a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -249,105 +249,111 @@ def model_input_split_v1_mla_attn( def model_input_split_v1_attn( - attn_metadata: AscendMetadata, - _metadata_cls, - ms_split_config: MSAttentionMetadataSplitConfig, - ) -> List[Any]: - assert 0 < ms_split_config.num_micro_batches < 3 - if attn_metadata is None: - return [attn_metadata] - [token_index, - seq_index] = compute_split_seq_index(attn_metadata.query_lens, - attn_metadata.attn_state, - attn_metadata.num_actual_tokens) - if token_index == 0 or seq_index == 0 or seq_index == len( - attn_metadata.query_lens): - return [attn_metadata] - - - # split attn metadata - + attn_metadata: AscendMetadata, + _metadata_cls, + ms_split_config: MSAttentionMetadataSplitConfig, +) -> List[Any]: + assert 0 < ms_split_config.num_micro_batches < 3 + if attn_metadata is None: + return [attn_metadata] + [token_index, + seq_index] = compute_split_seq_index(attn_metadata.query_lens, + attn_metadata.attn_state, + attn_metadata.num_actual_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len( + attn_metadata.query_lens): + return [attn_metadata] - [block_table_pre, block_table_post] = split_attn_tensor_type(attn_metadata.block_tables, seq_index) + # split attn metadata - query_start_loc_pre = query_start_loc_post = None - if attn_metadata.query_start_loc is not None: - query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1] - query_start_loc_post = deepcopy( - attn_metadata.query_start_loc[seq_index:] - ) - attn_metadata.query_start_loc[seq_index] + [block_table_pre, + block_table_post] = split_attn_tensor_type(attn_metadata.block_tables, + seq_index) - [query_lens_pre, query_lens_post] = split_attn_tensor_type(attn_metadata.query_lens, seq_index) - [seq_lens_pre, seq_lens_post] = split_attn_tensor_type(attn_metadata.seq_lens, seq_index) + query_start_loc_pre = query_start_loc_post = None + if attn_metadata.query_start_loc is not None: + query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1] + query_start_loc_post = deepcopy( + attn_metadata.query_start_loc[seq_index:] + ) - attn_metadata.query_start_loc[seq_index] - max_query_len_pre = max_query_len_post = None - if attn_metadata.max_query_len is not None: - max_query_len_pre, max_query_len_post = max(query_lens_pre), max(query_lens_post) + [query_lens_pre, + query_lens_post] = split_attn_tensor_type(attn_metadata.query_lens, + seq_index) + [seq_lens_pre, + seq_lens_post] = split_attn_tensor_type(attn_metadata.seq_lens, seq_index) - [slot_mapping_pre, slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping, token_index) + max_query_len_pre = max_query_len_post = None + if attn_metadata.max_query_len is not None: + max_query_len_pre, max_query_len_post = max(query_lens_pre), max( + query_lens_post) - is_only_prefill_pre = is_only_prefill_post = attn_metadata.is_only_prefill - has_prefill_pre, has_prefill_post = torch.any(query_lens_pre > 1).item(), torch.any(query_lens_post > 1).item() + [slot_mapping_pre, + slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping, + token_index) - if not attn_metadata.is_only_prefill: - is_only_prefill_post = torch.all(query_lens_post > 1).item() + is_only_prefill_pre = is_only_prefill_post = attn_metadata.is_only_prefill + has_prefill_pre, has_prefill_post = torch.any( + query_lens_pre > 1).item(), torch.any(query_lens_post > 1).item() + if not attn_metadata.is_only_prefill: + is_only_prefill_post = torch.all(query_lens_post > 1).item() - if 
attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: - # the attn_mla kernel in torch npu only accept 128*128 attn mask - attn_mask_pre = attn_mask_post = attn_metadata.attn_mask - attn_state_pre = attn_state_post = attn_metadata.attn_state - elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: - # should be none in decode only state - attn_mask_pre = attn_mask_post = attn_metadata.attn_mask - attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly - else: - # chunked prefill - if has_prefill_pre: - attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill - attn_mask_pre = attn_metadata.attn_mask[:token_index, :max( - seq_lens_pre)].contiguous() - attn_state_post = AscendAttentionState.ChunkedPrefill - attn_mask_post = attn_metadata.attn_mask[ - token_index:, :max(seq_lens_post)].contiguous() - else: - attn_state_pre = AscendAttentionState.DecodeOnly - attn_mask_pre = None - attn_state_post = AscendAttentionState.ChunkedPrefill - attn_mask_post = attn_metadata.attn_mask[ - token_index:, :max(seq_lens_post)].contiguous() + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: + # the attn_mla kernel in torch npu only accept 128*128 attn mask + attn_mask_pre = attn_mask_post = attn_metadata.attn_mask + attn_state_pre = attn_state_post = attn_metadata.attn_state + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + # should be none in decode only state + attn_mask_pre = attn_mask_post = attn_metadata.attn_mask + attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly + else: + # chunked prefill + if has_prefill_pre: + attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill + attn_mask_pre = attn_metadata.attn_mask[:token_index, :max( + seq_lens_pre)].contiguous() + attn_state_post = AscendAttentionState.ChunkedPrefill + attn_mask_post = attn_metadata.attn_mask[ + token_index:, :max(seq_lens_post)].contiguous() + else: + attn_state_pre = AscendAttentionState.DecodeOnly + attn_mask_pre = None + attn_state_post = AscendAttentionState.ChunkedPrefill + attn_mask_post = attn_metadata.attn_mask[ + token_index:, :max(seq_lens_post)].contiguous() - # construct metadata - attention_metadata_pre = _metadata_cls( - num_actual_tokens=token_index, - block_tables=block_table_pre, - query_start_loc=query_start_loc_pre, - query_lens=query_lens_pre, - seq_lens=seq_lens_pre, - seq_lens_list=seq_lens_pre.tolist(), - max_query_len=max_query_len_pre, - slot_mapping=slot_mapping_pre, - is_only_prefill=is_only_prefill_pre, - attn_state=attn_state_pre, - attn_mask=attn_mask_pre, - num_input_tokens=token_index, - enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, - ) + # construct metadata + attention_metadata_pre = _metadata_cls( + num_actual_tokens=token_index, + block_tables=block_table_pre, + query_start_loc=query_start_loc_pre, + query_lens=query_lens_pre, + seq_lens=seq_lens_pre, + seq_lens_list=seq_lens_pre.tolist(), + max_query_len=max_query_len_pre, + slot_mapping=slot_mapping_pre, + is_only_prefill=is_only_prefill_pre, + attn_state=attn_state_pre, + attn_mask=attn_mask_pre, + num_input_tokens=token_index, + enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, + ) - attention_metadata_post = _metadata_cls( - num_actual_tokens=attn_metadata.num_actual_tokens - token_index, - block_tables=block_table_post, - query_start_loc=query_start_loc_post, - 
query_lens=query_lens_post, - seq_lens=seq_lens_post, - seq_lens_list=seq_lens_post.tolist(), - max_query_len=max_query_len_post, - slot_mapping=slot_mapping_post, - is_only_prefill=is_only_prefill_post, - attn_state=attn_state_post, - attn_mask=attn_mask_post, - num_input_tokens=attn_metadata.num_input_tokens - token_index, - enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, - ) + attention_metadata_post = _metadata_cls( + num_actual_tokens=attn_metadata.num_actual_tokens - token_index, + block_tables=block_table_post, + query_start_loc=query_start_loc_post, + query_lens=query_lens_post, + seq_lens=seq_lens_post, + seq_lens_list=seq_lens_post.tolist(), + max_query_len=max_query_len_post, + slot_mapping=slot_mapping_post, + is_only_prefill=is_only_prefill_post, + attn_state=attn_state_post, + attn_mask=attn_mask_post, + num_input_tokens=attn_metadata.num_input_tokens - token_index, + enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, + ) - return [attention_metadata_pre, attention_metadata_post] \ No newline at end of file + return [attention_metadata_pre, attention_metadata_post] diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index af9c4082d6..04cb53f96c 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -17,7 +17,7 @@ import math import os -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union import torch import torch.distributed as dist @@ -338,7 +338,7 @@ def fused_experts_with_all2all( row_idx_len, dtype=torch.int32, device=device).view(top_k, -1).permute( - 1, 0).contiguous()) + 1, 0).contiguous()) hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, @@ -377,7 +377,7 @@ def fused_experts_with_all2all( row_idx_len, dtype=torch.int32, device=topk_weights.device).view( - top_k, -1).permute(1, 0).contiguous() + top_k, -1).permute(1, 0).contiguous() hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, @@ -480,8 +480,9 @@ def fused_experts_with_all2all_buffer( expert_idx=topk_ids, active_num=num_tokens) - max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) * max_model_len * - get_dp_group().world_size // ep_group.world_size + 1) * top_k * 2 + max_row_per_ep_rank = ( + -(-global_batch_size // ep_group.world_size) * max_model_len * + get_dp_group().world_size // ep_group.world_size + 1) * top_k * 2 expert_idx_buffer_scatter, unpad_indices = process_topk_ids( expanded_expert_idx, global_num_experts, ep_group.world_size, max_row_per_ep_rank, num_tokens, top_k) @@ -493,9 +494,9 @@ def fused_experts_with_all2all_buffer( (expert_idx_buffer_scatter != global_num_experts).to(torch.int32)) hidden_states_pad_idx[ expert_idx_buffer_scatter != global_num_experts] = torch.arange( - non_pad_len, - dtype=expert_idx_buffer_scatter.dtype, - device=hidden_states.device) + non_pad_len, + dtype=expert_idx_buffer_scatter.dtype, + device=hidden_states.device) hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] expert_idx_buffer_gather = torch.empty_like( @@ -570,18 +571,15 @@ def fused_experts_with_all2all_buffer( return final_hidden_states -def fused_experts_with_all2allv(token_dispatcher, probs, routing_map, hidden_states: torch.Tensor, - w1: torch.Tensor, +def fused_experts_with_all2allv(token_dispatcher, probs, routing_map, + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor): # Enable moe alltoallv, 
it's a balanced policy for precision and efficiency. - (share_experts_output, dispatched_input, tokens_per_expert) = token_dispatcher.token_permutation( - hidden_states, probs, routing_map - ) + (share_experts_output, dispatched_input, + tokens_per_expert) = token_dispatcher.token_permutation( + hidden_states, probs, routing_map) - expert_output = apply_mlp(dispatched_input, - w1, - w2, - tokens_per_expert) + expert_output = apply_mlp(dispatched_input, w1, w2, tokens_per_expert) output, mlp_bias = token_dispatcher.token_unpermutation(expert_output) return output @@ -1004,12 +1002,13 @@ def apply( ep_group=get_ep_group()) elif fused_moe_state == FusedMoEState.All2AllSeq: token_dispatcher = kwargs.get('token_dispatcher') - return fused_experts_with_all2allv(token_dispatcher=token_dispatcher, - probs=topk_weights, - routing_map=topk_ids, - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight) + return fused_experts_with_all2allv( + token_dispatcher=token_dispatcher, + probs=topk_weights, + routing_map=topk_ids, + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight) else: return fused_experts_with_all2all(hidden_states=x, w1=layer.w13_weight, @@ -1159,17 +1158,20 @@ def __init__( self.quant_method, AscendUnquantizedFusedMoEMethod): self.reduce_results = False moe_dispatcher_config = ( - MoeDispatcherConfig().set_num_moe_experts(self.global_num_experts) - .set_num_local_experts(self.local_num_experts) - .set_moe_router_topk(top_k) - .set_group_topk(topk_group) - .set_num_groups(num_expert_group) - .set_expert_bias(e_score_correction_bias) - .set_scaling_factor(1.0).build()) - self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher(moe_dispatcher_config) + MoeDispatcherConfig().set_num_moe_experts( + self.global_num_experts).set_num_local_experts( + self.local_num_experts).set_moe_router_topk( + top_k).set_group_topk(topk_group). 
+ set_num_groups(num_expert_group).set_expert_bias( + e_score_correction_bias).set_scaling_factor(1.0).build()) + self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher( + moe_dispatcher_config) if envs_ascend.VLLM_ASCEND_ENABLE_DBO: - token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher(moe_dispatcher_config) - self.token_dispatchers = [self.token_dispatcher, token_dispatcher1] + token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher( + moe_dispatcher_config) + self.token_dispatchers = [ + self.token_dispatcher, token_dispatcher1 + ] def forward(self, hidden_states: torch.Tensor, @@ -1208,7 +1210,8 @@ def forward(self, shared_hidden_states = shared_experts(hidden_states) attn_metadata = get_forward_context().attn_metadata - mc2_mask = attn_metadata.decode.mc2_mask if attn_metadata is not None and getattr(attn_metadata, "decode", None) is not None else None + mc2_mask = attn_metadata.decode.mc2_mask if attn_metadata is not None and getattr( + attn_metadata, "decode", None) is not None else None tp_size = get_tensor_model_parallel_world_size() if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: @@ -1273,8 +1276,7 @@ def forward(self, quantized_x_for_share=quantized_x_for_share, dynamic_scale_for_share=dynamic_scale_for_share, mc2_mask=mc2_mask, - token_dispatcher=self.token_dispatcher - ) + token_dispatcher=self.token_dispatcher) if shared_experts: if isinstance(e_hidden_states, tuple): diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index 6906577778..aa6143b8a2 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -32,9 +32,8 @@ reduce_scatter_last_dim_to_tensor_parallel_region) from vllm_ascend.ops.comm_utils import async_all_to_all from vllm_ascend.ops.moe_dispatcher.moe_utils import ( - get_capacity, permute, sort_chunks_by_idxs, topk_softmax_with_capacity, + get_capacity, permute, topk_softmax_with_capacity, unpermute) - """ We use the following notation throughout this file: H: hidden size B: micro batch size @@ -147,7 +146,6 @@ def tp_ep_size(self): class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher): overlap_stream = None - """ The implementation of the AlltoAll-based token dispatcher, which handles token dispatching on the sequence level instead of token level. 
The core of this implementation @@ -178,20 +176,18 @@ def __init__(self, config: MoeDispatcherConfig): device=torch.npu.current_device(), ) - local_expert_indices_offset = ( - self.ep_rank * self.num_local_experts - ) + local_expert_indices_offset = (self.ep_rank * self.num_local_experts) self.local_expert_indices = [ - local_expert_indices_offset + i for i in range(self.num_local_experts) + local_expert_indices_offset + i + for i in range(self.num_local_experts) ] - assert ( - len(self.local_expert_indices) == self.num_local_experts - ), "Invalid local expert indices" + assert (len(self.local_expert_indices) == self.num_local_experts + ), "Invalid local expert indices" for i in range(len(self.local_expert_indices) - 1): - assert ( - self.local_expert_indices[i] == self.local_expert_indices[i + 1] - 1 - ), "local_expert_indices must be continuous" + assert (self.local_expert_indices[i] == + self.local_expert_indices[i + 1] - + 1), "local_expert_indices must be continuous" self.probs = None self.input_splits = None self.output_splits = None @@ -205,12 +201,11 @@ def __init__(self, config: MoeDispatcherConfig): input_chunk_idxs = torch.arange(self.num_experts) # [num_local_experts, ep_size]. Sort the input chunks by local experts. self.sort_input_by_local_experts = input_chunk_idxs.reshape( - -1, self.num_local_experts - ).T.ravel() + -1, self.num_local_experts).T.ravel() # [ep_size, num_local_experts]. Restore the output chunks by local experts. self.restore_output_by_local_experts = input_chunk_idxs.reshape( - self.num_local_experts, -1 - ).T.ravel().to(torch.device("cpu"), non_blocking=True) + self.num_local_experts, -1).T.ravel().to(torch.device("cpu"), + non_blocking=True) # Token drop and padding. # We need to keep track of the token num if we drop tokens without padding them. @@ -240,7 +235,9 @@ def __init__(self, config: MoeDispatcherConfig): self.overlap_stream = MoEAlltoAllSeqOverLapDispatcher.overlap_stream - def preprocess(self, indices: torch.Tensor, with_sync=True) -> torch.Tensor: + def preprocess(self, + indices: torch.Tensor, + with_sync=True) -> torch.Tensor: """ Preprocess routing map for AlltoAll communication and token permutation. This method computes the number of tokens assigned to each expert based on @@ -255,9 +252,10 @@ def preprocess(self, indices: torch.Tensor, with_sync=True) -> torch.Tensor: Returns: torch.Tensor: Tensor containing the number of tokens assigned to local expert. """ - num_local_tokens_per_expert = torch.histc( - indices, bins=self.num_experts, min=0, max=self.num_experts - ) + num_local_tokens_per_expert = torch.histc(indices, + bins=self.num_experts, + min=0, + max=self.num_experts) # num_local_tokens_per_expert: [num_experts] @@ -272,18 +270,19 @@ def preprocess(self, indices: torch.Tensor, with_sync=True) -> torch.Tensor: ) self.num_out_tokens = self.capacity * self.num_experts num_tokens_per_local_expert = torch.full( - (self.num_local_experts,), self.capacity * self.ep_size, dtype=torch.long - ) + (self.num_local_experts, ), + self.capacity * self.ep_size, + dtype=torch.long) self.num_global_tokens_per_local_expert_cpu = torch.full( - (self.num_experts * self.tp_ep_size,), self.capacity, dtype=torch.long - ) + (self.num_experts * self.tp_ep_size, ), + self.capacity, + dtype=torch.long) return num_tokens_per_local_expert elif self.config.moe_expert_capacity_factor is not None: # Token drop but no pad. A synchronization is needed before the first # permutation to get the `num_out_tokens` CPU value. 
self.num_out_tokens = num_local_tokens_per_expert.sum().to( - torch.device("cpu"), non_blocking=True - ) + torch.device("cpu"), non_blocking=True) self.cuda_sync_point = "before_permutation_1" else: # Dropless @@ -301,23 +300,18 @@ def preprocess(self, indices: torch.Tensor, with_sync=True) -> torch.Tensor: # =================================================== # Calculate input_splits, output_splits for alltoall-v. # =================================================== - self.input_splits = ( - num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts) - .sum(axis=1) - .to(torch.device("cpu"), non_blocking=True) - .numpy() - ) + self.input_splits = (num_local_tokens_per_expert.reshape( + ep_size, self.num_local_experts).sum(axis=1).to( + torch.device("cpu"), non_blocking=True).numpy()) num_global_tokens_per_expert = gather_from_sequence_parallel_region( - num_local_tokens_per_expert, group=self.ep_group - ).reshape(ep_size, self.num_experts) - self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[0]: - self.local_expert_indices[-1] + 1] - self.output_splits = ( - self.num_global_tokens_per_local_expert.sum(axis=-1) - .to(torch.device("cpu"), non_blocking=True) - .numpy() - ) - num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(axis=0) + num_local_tokens_per_expert, + group=self.ep_group).reshape(ep_size, self.num_experts) + self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[ + 0]:self.local_expert_indices[-1] + 1] + self.output_splits = (self.num_global_tokens_per_local_expert.sum( + axis=-1).to(torch.device("cpu"), non_blocking=True).numpy()) + num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum( + axis=0) # =================================================== # num_global_tokens_per_expert: [ep_size, num_experts] # num_global_tokens_per_local_expert: [ep_size, num_local_experts] @@ -325,15 +319,14 @@ def preprocess(self, indices: torch.Tensor, with_sync=True) -> torch.Tensor: # =================================================== else: self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape( - -1, self.num_experts - ) + -1, self.num_experts) num_tokens_per_local_expert = num_local_tokens_per_expert if self.num_local_experts > 1 and with_sync: self.cuda_sync_point = "no_sync" self.global_input_tokens_local_experts_indices = torch.repeat_interleave( - self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel() - ) + self.expert_ids_per_ep_rank, + self.num_global_tokens_per_local_expert.ravel()) # self.num_global_tokens_per_local_expert_cpu = ( # self.num_global_tokens_per_local_expert.view(-1, self.num_local_experts).to( @@ -363,8 +356,7 @@ def routing(self, probs): num_groups=self.config.num_groups, expert_bias=self.config.expert_bias, scaling_factor=self.config.scaling_factor, - score_function=score_function - ) + score_function=score_function) self.top_indices = top_indices return scores, routing_map @@ -378,10 +370,10 @@ def preprocess_overlap(self, routing_map): return num_tokens_per_local_expert def token_permutation( - self, - hidden_states: torch.Tensor, - probs: torch.Tensor, - routing_map: torch.Tensor, + self, + hidden_states: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, ): """ Dispatch tokens to local experts using AlltoAllSeq communication. 
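preprocess() above turns per-expert token counts into input_splits/output_splits for a variable-size all-to-all across the expert-parallel group. A minimal sketch of that exchange with plain torch.distributed follows (the patch itself routes this through async_all_to_all; the split computation and group handling here are simplified placeholders):

import torch
import torch.distributed as dist

def alltoallv_dispatch(permuted_tokens: torch.Tensor,
                       input_splits: list,
                       output_splits: list,
                       ep_group) -> torch.Tensor:
    # permuted_tokens: rows already sorted so that tokens bound for the same
    # EP rank are contiguous; input_splits[r] rows go to rank r and
    # output_splits[r] rows arrive from rank r.
    out = permuted_tokens.new_empty(
        (sum(output_splits), permuted_tokens.shape[-1]))
    dist.all_to_all_single(out,
                           permuted_tokens.contiguous(),
                           output_split_sizes=output_splits,
                           input_split_sizes=input_splits,
                           group=ep_group)
    return out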
@@ -410,7 +402,8 @@ def alltoall_token_permutation1(hidden_states, routing_map): hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) tokens_per_expert = self.preprocess(routing_map) if self.tp_ep_size > 1: - hidden_states = all_to_all_sp2hp(hidden_states, group=self.tp_ep_group) + hidden_states = all_to_all_sp2hp(hidden_states, + group=self.tp_ep_group) self.hidden_shape_before_permute = hidden_states.shape if self.cuda_sync_point == "before_permutation_1": @@ -460,15 +453,13 @@ def alltoall_token_permutation2(global_input_tokens): if self.num_local_experts > 1: global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( global_input_tokens, - self.global_input_tokens_local_experts_indices - ) + self.global_input_tokens_local_experts_indices) # Perform tensor parallel AllGather on the hidden dimension to obtain the input tokens. # global_input_tokens: [SEQL, H/TP] -> [SEQL, H] if self.tp_ep_size > 1 and self.config.moe_grouped_gemm: global_input_tokens = all_gather_last_dim_from_tensor_parallel_region( - global_input_tokens, self.tp_ep_group - ) + global_input_tokens, self.tp_ep_group) if self.cuda_sync_point == "before_finish": torch.npu.current_stream().synchronize() @@ -479,14 +470,12 @@ def alltoall_token_permutation2(global_input_tokens): return share_experts_output, global_input_tokens, tokens_per_expert - def preprocess_and_permtute1( - self, - hidden_states: torch.Tensor, - probs: torch.Tensor, - routing_map: torch.Tensor, - shared_experts=None, - shared_experts_input: torch.Tensor = None - ): + def preprocess_and_permtute1(self, + hidden_states: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + shared_experts=None, + shared_experts_input: torch.Tensor = None): self.hidden_shape = hidden_states.shape self.probs = probs self.top_indices = routing_map @@ -528,8 +517,8 @@ def preprocess_and_permtute1( if self.num_local_experts > 1: self.cuda_sync_point = "no_sync" self.global_input_tokens_local_experts_indices = torch.repeat_interleave( - self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel() - ) + self.expert_ids_per_ep_rank, + self.num_global_tokens_per_local_expert.ravel()) self.cached_permutated_local_input_tokens = hidden_states self.tokens_per_expert = tokens_per_expert @@ -558,23 +547,17 @@ def permute2(self): if self.num_local_experts > 1: global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( self.cached_global_input_tokens, - self.global_input_tokens_local_experts_indices - ) + self.global_input_tokens_local_experts_indices) self.cached_global_input_tokens.untyped_storage().resize_(0) self.cached_global_input_tokens = None return global_input_tokens, self.tokens_per_expert - def unpermute1( - self, - hidden_states: torch.Tensor - ): + def unpermute1(self, hidden_states: torch.Tensor): # Unpermutation 2: expert output to AlltoAll input if hidden_states.shape[0] > 0 and self.num_local_experts > 1: hidden_states = torch_npu.npu_moe_token_unpermute( - hidden_states, - self.reversed_global_input_permutation_mapping - ) + hidden_states, self.reversed_global_input_permutation_mapping) self.cached_global_output_tokens = hidden_states self.reversed_global_input_permutation_mapping = None @@ -583,11 +566,8 @@ def combine_alltoall(self): # Perform expert parallel AlltoAll communication # hidden_states: [SEQL, H] -> [SEQL, H/TP] _, self.cached_local_output_tokens, handle = async_all_to_all( - self.cached_global_output_tokens, - self.input_splits, - 
self.output_splits, - ep_group - ) + self.cached_global_output_tokens, self.input_splits, + self.output_splits, ep_group) handle.wait() self.cached_global_output_tokens.untyped_storage().resize_(0) self.cached_global_output_tokens = None @@ -597,10 +577,10 @@ def combine_alltoall(self): def unpermute2(self): output = torch_npu.npu_moe_token_unpermute( permuted_tokens=self.cached_local_output_tokens, - sorted_indices=self.reversed_local_input_permutation_mapping.to(torch.int32), + sorted_indices=self.reversed_local_input_permutation_mapping.to( + torch.int32), probs=self.probs, - restore_shape=self.hidden_shape_before_permute - ) + restore_shape=self.hidden_shape_before_permute) output = output.view(self.hidden_shape) @@ -611,11 +591,9 @@ def unpermute2(self): return output - def token_unpermutation( - self, - hidden_states: torch.Tensor, - bias: torch.Tensor = None - ): + def token_unpermutation(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): """ Reverse the token permutation to restore the original order. @@ -634,14 +612,14 @@ def alltoall_token_unpermutation1(hidden_states): # Perform tensor parallel Reduce-Scatter # hidden_states: [SEQL, H] -> [SEQL, H/TP] if self.tp_ep_size > 1: - hidden_states = reduce_scatter_last_dim_to_tensor_parallel_region(hidden_states, group=self.tp_ep_group) + hidden_states = reduce_scatter_last_dim_to_tensor_parallel_region( + hidden_states, group=self.tp_ep_group) # Unpermutation 2: expert output to AlltoAll input if hidden_states.shape[0] > 0 and self.num_local_experts > 1: hidden_states = torch_npu.npu_moe_token_unpermute( hidden_states, - self.reversed_global_input_permutation_mapping - ) + self.reversed_global_input_permutation_mapping) # hidden_states = sort_chunks_by_idxs( # hidden_states, # self.num_global_tokens_per_local_expert_cpu.T.ravel(), @@ -655,11 +633,7 @@ def alltoall_token_unpermutation1(hidden_states): # Perform expert parallel AlltoAll communication # hidden_states: [SEQL, H] -> [SEQL, H/TP] _, permutated_local_input_tokens, handle = async_all_to_all( - hidden_states, - self.input_splits, - self.output_splits, - ep_group - ) + hidden_states, self.input_splits, self.output_splits, ep_group) handle.wait() hidden_states.untyped_storage().resize_(0) @@ -670,10 +644,10 @@ def alltoall_token_unpermutation2(permutated_local_input_tokens): # .view(-1, self.config.moe_router_topk)) output = torch_npu.npu_moe_token_unpermute( permuted_tokens=permutated_local_input_tokens, - sorted_indices=self.reversed_local_input_permutation_mapping.to(torch.int32), + sorted_indices=self. 
+ reversed_local_input_permutation_mapping.to(torch.int32), probs=self.probs, - restore_shape=self.hidden_shape_before_permute - ) + restore_shape=self.hidden_shape_before_permute) else: output = unpermute( permutated_local_input_tokens, From af85566b1655629991e0090277952ad26a39272a Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Wed, 9 Jul 2025 23:48:15 +0800 Subject: [PATCH 34/60] handle code clean Signed-off-by: weijinqian_v1 --- vllm_ascend/models/qwen3_dbo.py | 2 -- vllm_ascend/multistream/ms_split.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py index 35877cb0ae..73e04e0342 100644 --- a/vllm_ascend/models/qwen3_dbo.py +++ b/vllm_ascend/models/qwen3_dbo.py @@ -267,11 +267,9 @@ def _forward_ms_layer_alltoallv_finegrained( ] * num_micro_batchs tokens_per_expert = [None] * num_micro_batchs dispatched_input = [None] * num_micro_batchs - shared_expert_output = [None] * num_micro_batchs router_expert_output = [None] * num_micro_batchs chunked_hidden_states_sizes = [None] * num_micro_batchs token_dispatchers = self.mlp.experts.token_dispatchers - has_shared_expert = hasattr(self.mlp, 'shared_experts') def discard_tensor(tensor): if isinstance(tensor, torch.Tensor): diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 1c765d9bcd..40281a918b 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -293,7 +293,7 @@ def model_input_split_v1_attn( token_index) is_only_prefill_pre = is_only_prefill_post = attn_metadata.is_only_prefill - has_prefill_pre, has_prefill_post = torch.any( + has_prefill_pre, _ = torch.any( query_lens_pre > 1).item(), torch.any(query_lens_post > 1).item() if not attn_metadata.is_only_prefill: From d4ad734514b492196b1a278281cbb5baa7917ab9 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Thu, 10 Jul 2025 00:11:43 +0800 Subject: [PATCH 35/60] handle code clean Signed-off-by: weijinqian_v1 --- tests/ut/test_distributed_tensor_parallel.py | 8 +++-- tests/ut/test_moe_util.py | 11 ++++-- tests/ut/test_token_dispatcher.py | 11 +++--- vllm_ascend/models/__init__.py | 2 +- vllm_ascend/models/deepseek_dbo.py | 8 ++--- vllm_ascend/models/moe_block.py | 12 +++---- vllm_ascend/models/qwen3_dbo.py | 36 ++++++++++--------- vllm_ascend/multistream/ms_split.py | 3 +- vllm_ascend/ops/fused_moe.py | 4 +-- .../ops/moe_dispatcher/token_dispatcher.py | 6 ++-- 10 files changed, 55 insertions(+), 46 deletions(-) diff --git a/tests/ut/test_distributed_tensor_parallel.py b/tests/ut/test_distributed_tensor_parallel.py index ff4b8cde64..7072e295a0 100644 --- a/tests/ut/test_distributed_tensor_parallel.py +++ b/tests/ut/test_distributed_tensor_parallel.py @@ -1,14 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. 
-import pytest -import torch import importlib from unittest.mock import MagicMock, patch + +import pytest +import torch + from vllm_ascend.distributed.tensor_parallel import ( _gather_along_first_dim, _gather_along_last_dim, _reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim, - all_to_all_sp2hp, all_to_all_hp2sp) + all_to_all_hp2sp, all_to_all_sp2hp) # 测试用的固定数据 diff --git a/tests/ut/test_moe_util.py b/tests/ut/test_moe_util.py index c88d2071ec..cdf08d7bf0 100644 --- a/tests/ut/test_moe_util.py +++ b/tests/ut/test_moe_util.py @@ -1,12 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -import torch -import pytest import math + +import pytest +import torch + +from vllm_ascend.ops.moe_dispatcher.moe_utils import ( + get_capacity, group_limited_topk, permute, sort_chunks_by_idxs, + topk_softmax_with_capacity, unpermute) + import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa -from vllm_ascend.ops.moe_dispatcher.moe_utils import permute, get_capacity, topk_softmax_with_capacity, group_limited_topk, unpermute, sort_chunks_by_idxs class TestMoeUtils: diff --git a/tests/ut/test_token_dispatcher.py b/tests/ut/test_token_dispatcher.py index a5d313cf12..44b625e037 100644 --- a/tests/ut/test_token_dispatcher.py +++ b/tests/ut/test_token_dispatcher.py @@ -2,14 +2,17 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -import torch import pytest -import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa - +import torch from pytest_mock import MockerFixture + +from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( + MoEAlltoAllSeqOverLapDispatcher, MoeDispatcherConfig) from vllm_ascend.utils import adapt_patch # noqa E402 -from vllm_ascend.ops.moe_dispatcher.token_dispatcher import MoeDispatcherConfig, MoEAlltoAllSeqOverLapDispatcher +import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa + + adapt_patch(True) diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 1c989cb5c6..47380286c0 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -8,10 +8,10 @@ def register_model(): from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401 from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401 from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401 + from .moe_block import AscendSparseMoeBlock # noqa: F401 from .qwen2_5_vl import \ AscendQwen2_5_VLForConditionalGeneration # noqa: F401 from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401 - from .moe_block import AscendSparseMoeBlock # noqa: F401 from .qwen3 import CustomQwen3ForCausalLM # noqa: F401 ModelRegistry.register_model( diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 6f02dc2a95..d9abb9be87 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -34,8 +34,7 @@ from transformers import PretrainedConfig from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, 
get_tp_group, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import get_dp_group, get_ep_group @@ -55,8 +54,9 @@ from vllm.sequence import IntermediateTensors import vllm_ascend.envs as envs_ascend -from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region from vllm_ascend.ascend_forward_context import FusedMoEState +from vllm_ascend.distributed.tensor_parallel import \ + gather_from_sequence_parallel_region from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2DecoderLayer, CustomDeepseekV2MLP, CustomDeepseekV2MoE) @@ -69,9 +69,9 @@ from vllm_ascend.multistream.metadata import (MultiStreamConfig, MultiStreamStepMetadata, make_multistream_metadata_ds) +from vllm_ascend.ops.fused_moe import select_experts from vllm_ascend.quantization.w8a8_dynamic import ( AscendW8A8DynamicLinearMethod, apply_mlp) -from vllm_ascend.ops.fused_moe import select_experts from vllm_ascend.utils import dispose_tensor VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO diff --git a/vllm_ascend/models/moe_block.py b/vllm_ascend/models/moe_block.py index 1129aa6716..3a191411d6 100644 --- a/vllm_ascend/models/moe_block.py +++ b/vllm_ascend/models/moe_block.py @@ -19,22 +19,18 @@ import torch import vllm.model_executor.models.qwen3_moe as qwen3 - from torch import nn +from transformers import PretrainedConfig from vllm.attention import AttentionMetadata -from vllm.distributed import (get_tensor_model_parallel_world_size, - get_tp_group) -from vllm.distributed.parallel_state import get_dp_group +from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group +from vllm.distributed.parallel_state import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm_ascend.ascend_config import get_ascend_config -from vllm.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import AscendFusedMoE -from transformers import PretrainedConfig -from vllm.model_executor.layers.quantization import QuantizationConfig - class AscendSparseMoeBlock(nn.Module): diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py index 73e04e0342..e0982cf89e 100644 --- a/vllm_ascend/models/qwen3_dbo.py +++ b/vllm_ascend/models/qwen3_dbo.py @@ -21,42 +21,44 @@ # limitations under the License. 
# # Adapted from # """Inference-only Qwen3 model.""" -from typing import Optional, Union, List from types import SimpleNamespace +from typing import List, Optional, Union import torch import torch_npu from torch import nn from transformers import PretrainedConfig - -from vllm.model_executor.models.qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeModel -from vllm.config import CacheConfig, VllmConfig -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.attention import AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + get_tp_group) from vllm.forward_context import get_forward_context, set_forward_context -from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group, \ - get_pp_group -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.qwen3_moe import (Qwen3MoeDecoderLayer, + Qwen3MoeForCausalLM, + Qwen3MoeModel) from vllm.model_executor.models.utils import ( make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.sequence import IntermediateTensors -from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.compilation.decorators import support_torch_compile +import vllm_ascend.envs as envs_ascend +from vllm_ascend.distributed.tensor_parallel import \ + gather_from_sequence_parallel_region +from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.context import ( advance_step_multistream_layer_context, get_multistream_layer_context) -from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, MultiStreamPreTransformerLayer) from vllm_ascend.multistream.metadata import (MultiStreamConfig, MultiStreamStepMetadata, make_multistream_metadata_ds) -from vllm_ascend.ops.fused_moe import select_experts, apply_mlp -from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region -import vllm_ascend.envs as envs_ascend +from vllm_ascend.ops.fused_moe import apply_mlp, select_experts VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 40281a918b..605e6065c2 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -4,7 +4,8 @@ import numpy as np import torch -from vllm_ascend.attention.attention_v1 import AscendAttentionState, AscendMetadata +from vllm_ascend.attention.attention_v1 import (AscendAttentionState, + AscendMetadata) from .base import MSAttentionMetadataSplitConfig diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 570dc38140..8679fd1f2b 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -40,11 +40,11 @@ from vllm_ascend.ascend_config import 
get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer +from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( + MoEAlltoAllSeqOverLapDispatcher, MoeDispatcherConfig) from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, get_ascend_soc_version, npu_stream_switch, npu_wait_tensor) -from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( - MoEAlltoAllSeqOverLapDispatcher, MoeDispatcherConfig) VLLM_ASCEND_MOE_ALL2ALL_BUFFER: bool = envs_ascend.VLLM_ASCEND_MOE_ALL2ALL_BUFFER diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index aa6143b8a2..0d5e96ac31 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -24,16 +24,16 @@ import torch import torch_npu - from vllm.distributed.parallel_state import get_ep_group + from vllm_ascend.distributed.tensor_parallel import ( all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp, all_to_all_sp2hp, gather_from_sequence_parallel_region, reduce_scatter_last_dim_to_tensor_parallel_region) from vllm_ascend.ops.comm_utils import async_all_to_all from vllm_ascend.ops.moe_dispatcher.moe_utils import ( - get_capacity, permute, topk_softmax_with_capacity, - unpermute) + get_capacity, permute, topk_softmax_with_capacity, unpermute) + """ We use the following notation throughout this file: H: hidden size B: micro batch size From 3b7269a9c6802cff5d657f44b37295297c5e238e Mon Sep 17 00:00:00 2001 From: harygo22 Date: Thu, 10 Jul 2025 14:31:16 +0800 Subject: [PATCH 36/60] fix comment Signed-off-by: duyangkai --- tests/ut/test_distributed_tensor_parallel.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/ut/test_distributed_tensor_parallel.py b/tests/ut/test_distributed_tensor_parallel.py index 7072e295a0..ae540cc08f 100644 --- a/tests/ut/test_distributed_tensor_parallel.py +++ b/tests/ut/test_distributed_tensor_parallel.py @@ -13,7 +13,6 @@ all_to_all_hp2sp, all_to_all_sp2hp) -# 测试用的固定数据 @pytest.fixture def test_tensor(): return torch.randn(8, 16) @@ -29,7 +28,6 @@ def mock_group(): return MagicMock() -# 模拟分布式环境 @pytest.fixture(autouse=True) def mock_dist(): with patch("torch.distributed") as mock: @@ -39,12 +37,11 @@ def mock_dist(): class TestDistributedCommunication: - """测试分布式通信函数""" @pytest.mark.parametrize("world_size", [1, 4]) def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist, world_size): - """测试_gather_along_first_dim""" + """test _gather_along_first_dim""" mock_dist.get_world_size.return_value = world_size result = _gather_along_first_dim(test_tensor, mock_group) @@ -56,7 +53,7 @@ def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist, def test_gather_along_first_dim_unequal_split(self, test_tensor, mock_group): - """测试不等分分割情况""" + """test unequal split""" output_split_sizes = [5, 10, 15, 2] result = _gather_along_first_dim(test_tensor, mock_group, output_split_sizes) @@ -65,7 +62,7 @@ def test_gather_along_first_dim_unequal_split(self, test_tensor, @pytest.mark.parametrize("world_size", [1, 4]) def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group, mock_dist, world_size): - """测试_gather_along_last_dim""" + """test _gather_along_last_dim""" mock_dist.get_world_size.return_value = world_size result = _gather_along_last_dim(test_tensor_last_dim, mock_group) @@ -100,7 +97,7 @@ def 
test_reduce_scatter_along_last_dim(self, mock_group): ]) def test_wrapper_functions(self, mock_group, func, input_shape, expected_shape): - """测试包装函数""" + """test wrapper funcs""" mod = importlib.import_module( 'vllm_ascend.distributed.tensor_parallel') globals = mod.__dict__ From deb431906eba593ca559939b73ccfb6db887f466 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Thu, 10 Jul 2025 00:11:43 +0800 Subject: [PATCH 37/60] handle code clean Signed-off-by: weijinqian_v1 --- tests/ut/test_distributed_tensor_parallel.py | 8 +++-- tests/ut/test_moe_util.py | 11 ++++-- tests/ut/test_token_dispatcher.py | 11 +++--- vllm_ascend/models/__init__.py | 2 +- vllm_ascend/models/deepseek_dbo.py | 8 ++--- vllm_ascend/models/moe_block.py | 12 +++---- vllm_ascend/models/qwen3_dbo.py | 36 ++++++++++--------- vllm_ascend/multistream/ms_split.py | 3 +- vllm_ascend/ops/fused_moe.py | 4 +-- .../ops/moe_dispatcher/token_dispatcher.py | 6 ++-- 10 files changed, 55 insertions(+), 46 deletions(-) diff --git a/tests/ut/test_distributed_tensor_parallel.py b/tests/ut/test_distributed_tensor_parallel.py index ff4b8cde64..7072e295a0 100644 --- a/tests/ut/test_distributed_tensor_parallel.py +++ b/tests/ut/test_distributed_tensor_parallel.py @@ -1,14 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -import pytest -import torch import importlib from unittest.mock import MagicMock, patch + +import pytest +import torch + from vllm_ascend.distributed.tensor_parallel import ( _gather_along_first_dim, _gather_along_last_dim, _reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim, - all_to_all_sp2hp, all_to_all_hp2sp) + all_to_all_hp2sp, all_to_all_sp2hp) # 测试用的固定数据 diff --git a/tests/ut/test_moe_util.py b/tests/ut/test_moe_util.py index c88d2071ec..cdf08d7bf0 100644 --- a/tests/ut/test_moe_util.py +++ b/tests/ut/test_moe_util.py @@ -1,12 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -import torch -import pytest import math + +import pytest +import torch + +from vllm_ascend.ops.moe_dispatcher.moe_utils import ( + get_capacity, group_limited_topk, permute, sort_chunks_by_idxs, + topk_softmax_with_capacity, unpermute) + import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa -from vllm_ascend.ops.moe_dispatcher.moe_utils import permute, get_capacity, topk_softmax_with_capacity, group_limited_topk, unpermute, sort_chunks_by_idxs class TestMoeUtils: diff --git a/tests/ut/test_token_dispatcher.py b/tests/ut/test_token_dispatcher.py index a5d313cf12..44b625e037 100644 --- a/tests/ut/test_token_dispatcher.py +++ b/tests/ut/test_token_dispatcher.py @@ -2,14 +2,17 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. 
-import torch import pytest -import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa - +import torch from pytest_mock import MockerFixture + +from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( + MoEAlltoAllSeqOverLapDispatcher, MoeDispatcherConfig) from vllm_ascend.utils import adapt_patch # noqa E402 -from vllm_ascend.ops.moe_dispatcher.token_dispatcher import MoeDispatcherConfig, MoEAlltoAllSeqOverLapDispatcher +import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa + + adapt_patch(True) diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 1c989cb5c6..47380286c0 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -8,10 +8,10 @@ def register_model(): from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401 from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401 from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401 + from .moe_block import AscendSparseMoeBlock # noqa: F401 from .qwen2_5_vl import \ AscendQwen2_5_VLForConditionalGeneration # noqa: F401 from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401 - from .moe_block import AscendSparseMoeBlock # noqa: F401 from .qwen3 import CustomQwen3ForCausalLM # noqa: F401 ModelRegistry.register_model( diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 6f02dc2a95..d9abb9be87 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -34,8 +34,7 @@ from transformers import PretrainedConfig from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import get_dp_group, get_ep_group @@ -55,8 +54,9 @@ from vllm.sequence import IntermediateTensors import vllm_ascend.envs as envs_ascend -from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region from vllm_ascend.ascend_forward_context import FusedMoEState +from vllm_ascend.distributed.tensor_parallel import \ + gather_from_sequence_parallel_region from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2DecoderLayer, CustomDeepseekV2MLP, CustomDeepseekV2MoE) @@ -69,9 +69,9 @@ from vllm_ascend.multistream.metadata import (MultiStreamConfig, MultiStreamStepMetadata, make_multistream_metadata_ds) +from vllm_ascend.ops.fused_moe import select_experts from vllm_ascend.quantization.w8a8_dynamic import ( AscendW8A8DynamicLinearMethod, apply_mlp) -from vllm_ascend.ops.fused_moe import select_experts from vllm_ascend.utils import dispose_tensor VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO diff --git a/vllm_ascend/models/moe_block.py b/vllm_ascend/models/moe_block.py index 1129aa6716..3a191411d6 100644 --- a/vllm_ascend/models/moe_block.py +++ b/vllm_ascend/models/moe_block.py @@ -19,22 +19,18 @@ import torch import vllm.model_executor.models.qwen3_moe as qwen3 - from torch import nn +from transformers import PretrainedConfig from vllm.attention import AttentionMetadata -from vllm.distributed import (get_tensor_model_parallel_world_size, - get_tp_group) -from vllm.distributed.parallel_state import get_dp_group +from vllm.distributed import 
get_tensor_model_parallel_world_size, get_tp_group +from vllm.distributed.parallel_state import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm_ascend.ascend_config import get_ascend_config -from vllm.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import AscendFusedMoE -from transformers import PretrainedConfig -from vllm.model_executor.layers.quantization import QuantizationConfig - class AscendSparseMoeBlock(nn.Module): diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py index 73e04e0342..e0982cf89e 100644 --- a/vllm_ascend/models/qwen3_dbo.py +++ b/vllm_ascend/models/qwen3_dbo.py @@ -21,42 +21,44 @@ # limitations under the License. # # Adapted from # """Inference-only Qwen3 model.""" -from typing import Optional, Union, List from types import SimpleNamespace +from typing import List, Optional, Union import torch import torch_npu from torch import nn from transformers import PretrainedConfig - -from vllm.model_executor.models.qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeModel -from vllm.config import CacheConfig, VllmConfig -from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.attention import AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + get_tp_group) from vllm.forward_context import get_forward_context, set_forward_context -from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group, \ - get_pp_group -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.qwen3_moe import (Qwen3MoeDecoderLayer, + Qwen3MoeForCausalLM, + Qwen3MoeModel) from vllm.model_executor.models.utils import ( make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.sequence import IntermediateTensors -from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.compilation.decorators import support_torch_compile +import vllm_ascend.envs as envs_ascend +from vllm_ascend.distributed.tensor_parallel import \ + gather_from_sequence_parallel_region +from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.context import ( advance_step_multistream_layer_context, get_multistream_layer_context) -from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, MultiStreamPreTransformerLayer) from vllm_ascend.multistream.metadata import (MultiStreamConfig, MultiStreamStepMetadata, make_multistream_metadata_ds) -from vllm_ascend.ops.fused_moe import select_experts, apply_mlp -from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region -import 
vllm_ascend.envs as envs_ascend +from vllm_ascend.ops.fused_moe import apply_mlp, select_experts VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 40281a918b..605e6065c2 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -4,7 +4,8 @@ import numpy as np import torch -from vllm_ascend.attention.attention_v1 import AscendAttentionState, AscendMetadata +from vllm_ascend.attention.attention_v1 import (AscendAttentionState, + AscendMetadata) from .base import MSAttentionMetadataSplitConfig diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 570dc38140..8679fd1f2b 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -40,11 +40,11 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer +from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( + MoEAlltoAllSeqOverLapDispatcher, MoeDispatcherConfig) from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, get_ascend_soc_version, npu_stream_switch, npu_wait_tensor) -from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( - MoEAlltoAllSeqOverLapDispatcher, MoeDispatcherConfig) VLLM_ASCEND_MOE_ALL2ALL_BUFFER: bool = envs_ascend.VLLM_ASCEND_MOE_ALL2ALL_BUFFER diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index aa6143b8a2..0d5e96ac31 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -24,16 +24,16 @@ import torch import torch_npu - from vllm.distributed.parallel_state import get_ep_group + from vllm_ascend.distributed.tensor_parallel import ( all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp, all_to_all_sp2hp, gather_from_sequence_parallel_region, reduce_scatter_last_dim_to_tensor_parallel_region) from vllm_ascend.ops.comm_utils import async_all_to_all from vllm_ascend.ops.moe_dispatcher.moe_utils import ( - get_capacity, permute, topk_softmax_with_capacity, - unpermute) + get_capacity, permute, topk_softmax_with_capacity, unpermute) + """ We use the following notation throughout this file: H: hidden size B: micro batch size From a8b3e156edfcd0baec48574717054b6b3b338d0f Mon Sep 17 00:00:00 2001 From: duyangkai Date: Thu, 10 Jul 2025 19:42:02 +0800 Subject: [PATCH 38/60] fix init Signed-off-by: duyangkai --- vllm_ascend/models/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 47380286c0..c525849cb8 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -58,3 +58,6 @@ def register_model(): ModelRegistry.register_model( "Qwen3MoeForCausalLM", "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM") + + ModelRegistry.register_model( + "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM") From d290b7d22ebd5b811423ea8ad47505cedef659d7 Mon Sep 17 00:00:00 2001 From: duyangkai Date: Thu, 10 Jul 2025 22:07:57 +0800 Subject: [PATCH 39/60] remove files & move sparsemoeblock to ops Signed-off-by: duyangkai --- tests/ut/test_moe_util.py | 169 -------- vllm_ascend/models/__init__.py | 1 - vllm_ascend/models/moe_block.py | 114 ------ vllm_ascend/models/qwen3_dbo.py | 18 +- vllm_ascend/models/qwen3_moe.py | 5 +- 
vllm_ascend/ops/fused_moe.py | 78 ++++ vllm_ascend/ops/moe_dispatcher/moe_utils.py | 379 ------------------ .../ops/moe_dispatcher/token_dispatcher.py | 170 ++------ 8 files changed, 118 insertions(+), 816 deletions(-) delete mode 100644 tests/ut/test_moe_util.py delete mode 100644 vllm_ascend/models/moe_block.py delete mode 100644 vllm_ascend/ops/moe_dispatcher/moe_utils.py diff --git a/tests/ut/test_moe_util.py b/tests/ut/test_moe_util.py deleted file mode 100644 index cdf08d7bf0..0000000000 --- a/tests/ut/test_moe_util.py +++ /dev/null @@ -1,169 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -import math - -import pytest -import torch - -from vllm_ascend.ops.moe_dispatcher.moe_utils import ( - get_capacity, group_limited_topk, permute, sort_chunks_by_idxs, - topk_softmax_with_capacity, unpermute) - -import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa - - - -class TestMoeUtils: - - @pytest.fixture - def setup(self): - self.num_tokens = 16 - self.num_experts = 4 - self.hidden_size = 8 - self.topk = 2 - self.capacity_factor = 1.0 - self.group_topk = 2 - self.num_groups = 2 - self.scaling_factor = 1.0 - - def test_group_limited_topk(self, setup): - # Test group-limited topk routing - scores = torch.randn(self.num_tokens, self.num_experts) - probs, indices = group_limited_topk(scores, - topk=self.topk, - num_tokens=self.num_tokens, - num_experts=self.num_experts, - num_groups=self.num_groups, - group_topk=self.group_topk) - - assert probs.shape == (self.num_tokens, self.topk) - assert indices.shape == (self.num_tokens, self.topk) - assert torch.all(indices < self.num_experts) - - @pytest.mark.parametrize("score_function", ["softmax"]) - def test_topk_softmax_with_capacity(self, setup, score_function): - # Test topk softmax with capacity - logits = torch.randn(self.num_tokens, self.num_experts) - - # Test without capacity - probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity( - logits, topk=self.topk, score_function=score_function) - assert probs.shape == (self.num_tokens, self.num_experts) - assert routing_map.shape == (self.num_tokens, self.num_experts) - assert tokens_per_expert.shape == (self.num_experts, ) - - # Test with group routing - probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity( - logits, - topk=self.topk, - num_groups=self.num_groups, - group_topk=self.group_topk, - score_function=score_function) - assert probs.shape == (self.num_tokens, self.num_experts) - - def test_get_capacity(self, setup): - # Test capacity calculation - capacity = get_capacity(num_tokens=self.num_tokens, - num_experts=self.num_experts, - capacity_factor=self.capacity_factor) - expected = math.ceil( - (self.num_tokens / self.num_experts) * self.capacity_factor) - assert capacity == expected - - # Test with min capacity - min_capacity = 5 - capacity = get_capacity(num_tokens=self.num_tokens, - num_experts=self.num_experts, - capacity_factor=self.capacity_factor, - min_capacity=min_capacity) - assert capacity == min_capacity - - def test_permute(self, setup): - # Test token permutation - tokens = torch.randn(self.num_tokens, self.hidden_size) - routing_map = torch.randint( - 0, 2, (self.num_tokens, self.num_experts)).bool() - - # Basic permutation - permuted_tokens, sorted_indices = permute(tokens, routing_map) - assert permuted_tokens.shape[0] == 
routing_map.sum() - assert sorted_indices.shape[0] == routing_map.sum() - - # With drop and pad - capacity = get_capacity(num_tokens=self.num_tokens * self.topk, - num_experts=self.num_experts, - capacity_factor=self.capacity_factor) - num_out_tokens = capacity * self.num_experts - permuted_tokens, sorted_indices = permute( - tokens, - routing_map, - num_out_tokens=num_out_tokens, - drop_and_pad=True) - assert permuted_tokens.shape[0] == num_out_tokens - assert sorted_indices.shape[0] == num_out_tokens - - def test_unpermute(self, setup): - # Test token unpermutation - tokens = torch.randn(self.num_tokens, self.hidden_size) - routing_map = torch.randint( - 0, 2, (self.num_tokens, self.num_experts)).bool() - probs = torch.rand(self.num_tokens, self.num_experts) - - # First permute - permuted_tokens, sorted_indices = permute(tokens, routing_map) - - # Then unpermute - restored_tokens = unpermute(permuted_tokens, - sorted_indices, - tokens.shape, - probs=probs, - routing_map=routing_map) - assert restored_tokens.shape == tokens.shape - - # With drop and pad - capacity = get_capacity(num_tokens=self.num_tokens * self.topk, - num_experts=self.num_experts, - capacity_factor=self.capacity_factor) - num_out_tokens = capacity * self.num_experts - permuted_tokens, sorted_indices = permute( - tokens, - routing_map, - num_out_tokens=num_out_tokens, - drop_and_pad=True) - restored_tokens = unpermute(permuted_tokens, - sorted_indices, - tokens.shape, - probs=probs, - routing_map=routing_map, - drop_and_pad=True) - assert restored_tokens.shape == tokens.shape - - def test_sort_chunks_by_idxs(self, setup): - # Test chunk sorting - input_tensor = torch.randn(10, self.hidden_size) - split_sizes = torch.tensor([3, 2, 5]) - sorted_idxs = torch.tensor([2, 0, 1]) - - output = sort_chunks_by_idxs(input_tensor, split_sizes, sorted_idxs) - assert output.shape == input_tensor.shape - - # Verify the order is correct - expected = torch.cat( - [input_tensor[5:], input_tensor[0:3], input_tensor[3:5]]) - assert torch.allclose(output, expected) - - @pytest.mark.parametrize("score_function", ["softmax"]) - def test_score_functions(self, setup, score_function): - # Test different score functions - logits = torch.randn(self.num_tokens, self.num_experts) - expert_bias = torch.randn(self.num_experts) - - probs, routing_map, tokens_per_expert, top_indices = topk_softmax_with_capacity( - logits, - topk=self.topk, - score_function=score_function, - expert_bias=expert_bias) - assert probs.shape == (self.num_tokens, self.num_experts) - assert routing_map.shape == (self.num_tokens, self.num_experts) - assert tokens_per_expert.shape == (self.num_experts, ) diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index c525849cb8..895382cb95 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -8,7 +8,6 @@ def register_model(): from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401 from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401 from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401 - from .moe_block import AscendSparseMoeBlock # noqa: F401 from .qwen2_5_vl import \ AscendQwen2_5_VLForConditionalGeneration # noqa: F401 from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401 diff --git a/vllm_ascend/models/moe_block.py b/vllm_ascend/models/moe_block.py deleted file mode 100644 index 3a191411d6..0000000000 --- a/vllm_ascend/models/moe_block.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. 
All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. - -from typing import Optional - -import torch -import vllm.model_executor.models.qwen3_moe as qwen3 -from torch import nn -from transformers import PretrainedConfig -from vllm.attention import AttentionMetadata -from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group -from vllm.distributed.parallel_state import get_dp_group, get_ep_group -from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.linear import ReplicatedLinear -from vllm.model_executor.layers.quantization import QuantizationConfig - -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.ops.fused_moe import AscendFusedMoE - - -class AscendSparseMoeBlock(nn.Module): - - top_k: int - - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() - if self.tp_size > config.num_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_multistream_moe = \ - ascend_config.torchair_graph_config.enable_multistream_moe - - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") - - self.experts = AscendFusedMoE( - num_experts=config.num_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - reduce_results=False, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - prefix=f"{prefix}.experts") - - self.top_k = config.num_experts_per_tok - - self.dp_size = get_dp_group().world_size - - self.tp_group = get_tp_group().device_group - self.tp_rank = get_tp_group().rank_in_group - self.ep_group = get_ep_group() - - self.params_dtype = torch.get_default_dtype() - - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - if attn_metadata is None: - attn_metadata = get_forward_context().attn_metadata - # when profile runs, force experts to load balanced tokens - # to avoid high memory consumption on a single rank. 
- if attn_metadata is None: - # for profile run - is_prefill = True - enable_force_load_balance = True - else: - is_prefill = get_forward_context().with_prefill - enable_force_load_balance = False - - # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) - - hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - is_prefill=is_prefill, - top_k=self.top_k, - enable_force_load_balance=enable_force_load_balance, - shared_experts=None, - ) - - return hidden_states - - -qwen3.Qwen3MoeSparseMoeBlock = AscendSparseMoeBlock diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py index e0982cf89e..7860bee643 100644 --- a/vllm_ascend/models/qwen3_dbo.py +++ b/vllm_ascend/models/qwen3_dbo.py @@ -28,6 +28,7 @@ import torch_npu from torch import nn from transformers import PretrainedConfig +import vllm.model_executor.models.qwen3_moe as qwen3 from vllm.attention import AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig @@ -58,7 +59,7 @@ from vllm_ascend.multistream.metadata import (MultiStreamConfig, MultiStreamStepMetadata, make_multistream_metadata_ds) -from vllm_ascend.ops.fused_moe import apply_mlp, select_experts +from vllm_ascend.ops.fused_moe import apply_mlp, select_experts, AscendSparseMoeBlock VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO @@ -143,17 +144,7 @@ def _forward_op_gating( attn_metadata = get_forward_context().attn_metadata # when profile runs, force experts to load balanced tokens # to avoid high memory consumption on a single rank. - # TODO: need a better flag to indicate whether in profile run or not. - if attn_metadata is None: - # for profile run - self.is_prefill = True - self.enable_force_load_balance = True - else: - # is_prefill = attn_metadata.num_prefills > 0 - is_prefill = False - self.enable_force_load_balance = False - if hasattr(attn_metadata, 'with_prefill_across_dp'): - self.is_prefill = is_prefill or attn_metadata.with_prefill_across_dp + enable_force_load_balance = get_forward_context().in_profile_run num_tokens, hidden_dim = hidden_states.shape @@ -212,7 +203,7 @@ def _forward_op_gating( # this is a naive implementation for experts load balance so as # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. - if self.enable_force_load_balance: + if enable_force_load_balance: topk_ids = torch.randint_like(topk_ids, 0, self.config.num_experts) return topk_weights, topk_ids, local_hidden_states, chunked_hidden_states_sizes @@ -521,6 +512,7 @@ class CustomQwen3MoeForCausalLMDBO(Qwen3MoeForCausalLM): "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], } + qwen3.Qwen3MoeSparseMoeBlock = AscendSparseMoeBlock def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): nn.Module.__init__(self) diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index 8ff1b52a7a..aa21455957 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -17,7 +17,8 @@ # This file is a part of the vllm-ascend project. 
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM - +import vllm.model_executor.models.qwen3_moe as qwen3 +from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): packed_modules_mapping = { @@ -33,3 +34,5 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], } + qwen3.Qwen3MoeSparseMoeBlock = AscendSparseMoeBlock + diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 6e1615e0e1..09fca4dec8 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -23,6 +23,7 @@ import torch.distributed as dist import torch_npu from torch import nn +from transformers import PretrainedConfig from vllm.config import get_current_vllm_config from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -35,6 +36,9 @@ determine_expert_map) from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.attention import AttentionMetadata + import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config @@ -1338,3 +1342,77 @@ def _forward_ms_fused_moe_comp( enable_force_load_balance=enable_force_load_balance) return hidden_states + + +class AscendSparseMoeBlock(nn.Module): + + top_k: int + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_multistream_moe = \ + ascend_config.torchair_graph_config.enable_multistream_moe + + self.gate = ReplicatedLinear(config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = AscendFusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts") + + self.top_k = config.num_experts_per_tok + + self.dp_size = get_dp_group().world_size + + self.tp_group = get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group + self.ep_group = get_ep_group() + + self.params_dtype = torch.get_default_dtype() + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. 
+ enable_force_load_balance = get_forward_context().in_profile_run + is_prefill = get_forward_context().with_prefill + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=self.top_k, + enable_force_load_balance=enable_force_load_balance, + shared_experts=None, + ) + + return hidden_states \ No newline at end of file diff --git a/vllm_ascend/ops/moe_dispatcher/moe_utils.py b/vllm_ascend/ops/moe_dispatcher/moe_utils.py deleted file mode 100644 index dc19f75b33..0000000000 --- a/vllm_ascend/ops/moe_dispatcher/moe_utils.py +++ /dev/null @@ -1,379 +0,0 @@ -# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved. -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. - -import math -from typing import Optional - -import torch -import torch_npu - - -def group_limited_topk( - scores: torch.Tensor, - topk: int, - num_tokens: int, - num_experts: int, - num_groups: int, - group_topk: int, -): - """Perform top-k routing on a subset of expert groups. - - When using group-limited routing: - 1. Experts are divided into 'moe_router_num_groups' equal-sized groups - 2. For each token, 'moe_router_group_topk' groups are selected based on routing scores - (specifically, the sum of top-2 expert scores within each group) - 3. From these selected groups, 'moe_router_topk' individual experts are chosen - - Two common use cases: - - Device-limited routing: Set 'moe_router_num_groups' equal to expert parallel size (EP) - to limit each token to experts on a subset of devices - (See DeepSeek-V2: https://arxiv.org/pdf/2405.04434) - - - Node-limited routing: Set 'moe_router_num_groups' equal to number of nodes in EP group - to limit each token to experts on a subset of nodes - (See DeepSeek-V3: https://arxiv.org/pdf/2412.19437) - - Args: - scores (torch.Tensor): Softmax scores generated by the router. - topk (int): The number of experts to select for each token. - num_tokens (int): The number of tokens. - num_experts (int): The number of experts. - num_groups (int): Number of groups for routed experts. - group_topk (int): Number of groups selected for each token. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Probs and indices tensor. 
- """ - # Organize the experts into groups - # Select groups based on sum of top-(num_groups/group_topk) routing scores within each group - group_scores = (scores.view(num_tokens, - num_groups, -1).topk(num_groups // group_topk, - dim=-1)[0].sum(dim=-1)) - group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1] - group_mask = torch.zeros_like(group_scores) - group_mask.scatter_(1, group_idx, 1) - - # Mask the experts based on selection groups - score_mask = (group_mask.unsqueeze(-1).expand( - num_tokens, num_groups, - num_experts // num_groups).reshape(num_tokens, -1)) - - masked_scores = scores.masked_fill(~score_mask.bool(), float('-inf')) - probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1) - - return probs, top_indices - - -def topk_softmax_with_capacity( - logits: torch.Tensor, - topk: int, - capacity_factor: Optional[float] = None, - pad_to_capacity: bool = False, - drop_policy: str = "probs", - use_pre_softmax: bool = False, - num_groups: Optional[int] = None, - group_topk: Optional[int] = None, - scaling_factor: Optional[float] = None, - deterministic_mode: bool = False, - score_function: str = "sigmoid", - expert_bias: Optional[torch.Tensor] = None, -): - """Apply capacity and padding to the top-k selection. - Args: - logits (torch.Tensor): Logits tensor. - topk (int): The number of experts to select for each token. - capacity_factor (float): The capacity factor of each expert. Will drop tokens if the number - of tokens exceeds the capacity. - pad_to_capacity (bool): Whether to need padding in token drop mode. The probs for padded - tokens will be 0. - drop_policy (str): The policy to drop tokens. Can be either "prob" or "position". - If "prob", the tokens with the lowest probabilities will be dropped. - If "position", tokens at the end of each batch will be dropped. - use_pre_softmax (bool): Whether to apply softmax before top-k selection. - num_groups (int): Number of groups for routed experts. - group_topk (int): Number of selected groups for each token. - scaling_factor (float): Scaling factor of routing score in top-k selection. - deterministic_mode (bool): Deprecated. - score_function (str): The score function to use. Can be either "softmax" or "sigmoid". - expert_bias (torch.Tensor): The bias added to logits for expert routing. - - Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing - the routing probabilities for each token to each expert. - - routing_map (torch.Tensor): A mask tensor of shape [num_tokens, num_experts] - indicating which experts were selected for each token. True values represent - the selected experts. - - tokens_per_expert (torch.Tensor): A tensor of shape [num_experts] containing - the number of local tokens assigned to each expert before dropping and padding. - """ - assert logits.dim( - ) == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}." 
- num_tokens, num_experts = logits.shape - - def compute_topk(scores, topk, num_groups=None, group_topk=None): - if group_topk: - return group_limited_topk( - scores=scores, - topk=topk, - num_tokens=num_tokens, - num_experts=num_experts, - num_groups=num_groups, - group_topk=group_topk, - ) - else: - return torch.topk(scores, k=topk, dim=1) - - if score_function == "softmax": - if use_pre_softmax: - scores = torch.softmax(logits, dim=-1, - dtype=torch.float32).type_as(logits) - probs, top_indices = compute_topk(scores, topk, num_groups, - group_topk) - else: - scores, top_indices = compute_topk(logits, topk, num_groups, - group_topk) - probs = torch.softmax(scores, dim=-1, - dtype=torch.float32).type_as(logits) - if scaling_factor: - probs = probs * scaling_factor - elif score_function == "sigmoid": - probs, top_indices, _ = torch_npu.npu_moe_gating_top_k( - logits, - k=topk, # topk当前写8 - bias=expert_bias, - k_group=group_topk, # fix: 4 - group_count=num_groups, # fix 8 - group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix) - renorm=0, # 0: softmax->topk(fix); 1: topk->softmax - norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # 第三个输出是否输出 - # y2_flag=False, # old api; 第三个输出是否输出 - routed_scaling_factor=scaling_factor, - eps=float(1e-20)) - else: - raise ValueError(f"Invalid score_function: {score_function}") - - # Try using element-wise operations instead of scatter? - topk_masked_gates = torch.zeros_like(logits).scatter( - 1, top_indices.type(torch.int64), probs) - topk_map = torch.zeros_like(logits).int().scatter( - 1, top_indices.type(torch.int64), 1).bool() - tokens_per_expert = topk_map.sum(dim=0) - - if capacity_factor is None: - # TopK without capacity - return topk_masked_gates, topk_map, tokens_per_expert, top_indices - else: - # TopK with capacity - expert_capacity = get_capacity(num_tokens=num_tokens * topk, - num_experts=num_experts, - capacity_factor=capacity_factor) - - # Maskout exceeded tokens - if drop_policy == "probs": - _, capacity_indices = torch.topk(topk_masked_gates, - k=expert_capacity, - dim=0, - sorted=False) - capacity_mask = torch.zeros_like(logits).scatter( - 0, capacity_indices, 1).bool() - elif drop_policy == "position": - _, capacity_indices = torch.topk(topk_map.int(), - k=expert_capacity, - dim=0, - sorted=False) - capacity_mask = torch.zeros_like(logits).scatter( - 0, capacity_indices, 1).bool() - else: - raise ValueError(f"Invalid drop_policy: {drop_policy}") - - if pad_to_capacity: - final_map = capacity_mask - final_probs = topk_masked_gates * final_map - else: - # Get exceed mask and maskout exceeded probs and indices - final_map = torch.logical_and(topk_map, capacity_mask) - final_probs = topk_masked_gates * final_map - return final_probs, final_map, tokens_per_expert, top_indices - - -def get_capacity(num_tokens: int, - num_experts: int, - capacity_factor: float, - min_capacity=None): - """ - Calculate the capacity of each expert. - - Args: - num_tokens (int): num of the input tokens. - num_experts (int): num of the experts. - capacity_factor (float): Capacity factor. - min_capacity (int, optional): Minimum capacity. Defaults to None. - - Returns: - Tensor: Capacity of each expert. 
- """ - capacity = math.ceil((num_tokens / num_experts) * capacity_factor) - if min_capacity is not None and capacity < min_capacity: - capacity = min_capacity - return capacity - - -def permute( - tokens, - routing_map, - num_out_tokens: Optional[int] = None, - fused: bool = False, - drop_and_pad: bool = False, -): - """Permute the tokens and probs based on the mask. - Tokens with the same designated expert will be grouped together. - The shape of mask is [tokens, num_experts], it indicates which experts were selected - by each token. - - When drop_and_pad=True, in routing_map, the number of non-zeros in each column equals to - expert capacity. This function exploits this feature to use ops that support cuda graph. - - Args: - tokens (torch.Tensor): The input token tensor, [num_tokens, hidden]. - routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts]. - num_out_tokens (int, optional): The number of output tokens. If None, it's set to - the number of input tokens. - fused (bool, optional): Whether use the fused permute function. - drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop - and pads the number of tokens to the expert capacity. - If set to true, routing_map has a fixed number of non-zeros - in each column. - """ - - num_tokens, hidden = tokens.shape - num_experts = routing_map.shape[1] - if drop_and_pad and (num_out_tokens is not None): - capacity = num_out_tokens // num_experts - assert not routing_map.requires_grad - # mask [num_tokens, num_experts] -> [num_experts, num_tokens] - routing_map = routing_map.to(dtype=torch.int8).T.contiguous() - # use argsort to put indices of all non-zeros in the beginning of list - # and keep the first `capacity` number of indices - sorted_indices = routing_map.argsort( - dim=-1, descending=True, stable=True)[:, :capacity].contiguous() - # flatten from [num_experts, capacity] to 1D - sorted_indices = sorted_indices.view(-1) - else: - # mask [num_tokens, num_experts] -> [num_experts, num_tokens] - routing_map = routing_map.bool().T.contiguous() - - # Create a dense expert-to-token mapping from the sparse token-to-expert mapping - token_indices = (torch.arange( - num_tokens, - device=routing_map.device).unsqueeze(0).expand(num_experts, -1)) - sorted_indices = token_indices.masked_select(routing_map) - - # use the mapping to permute the tokens - permuted_input = tokens.index_select(0, sorted_indices) - - return permuted_input, sorted_indices - - -def unpermute( - permuted_tokens: torch.Tensor, - sorted_indices: torch.Tensor, - restore_shape: torch.Size, - probs: torch.Tensor = None, - routing_map: torch.Tensor = None, - fused: bool = False, - drop_and_pad: bool = False, -): - """ - Restore the original order of tokens after permutation. If probs are provided, it - will also apply them to the tokens before restoring the order. - - When drop_and_pad=True, the tensors will have the following properties: - - In routing_map, the number of non-zeros in each column equals to expert capacity - - The size of sorted_indices equals to num_experts * capacity, each split of `capacity` - contains the indices of tokens routed to an expert. - This function exploits these features to use ops that support cuda graph. - - Args: - permuted_tokens (torch.Tensor): The permuted token tensor. - sorted_indices (torch.Tensor): The indices used to sort the tokens. - restore_shape (torch.Size): The shape of the unpermuted tensor. 
- probs (torch.Tensor, optional): The unpermuted probs tensor, - routing_map (torch.Tensor, optional): Token to expert mapping, shape - [num_tokens, num_experts]. - fused (bool, optional): Whether use the fused unpermute function. - drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop - and pads the number of tokens to the expert capacity. - - Returns: - torch.Tensor: The tokens restored to their original order. - """ - - _, hidden = restore_shape - input_dtype = permuted_tokens.dtype - - if probs is not None: - assert routing_map is not None, "Mask must be provided to permute the probs." - if drop_and_pad: - num_experts = routing_map.size(1) - num_permuted_tokens = sorted_indices.size(0) - capacity = num_permuted_tokens // num_experts - num_unpermuted_tokens = probs.size(0) - - # [num_unpermuted_tokens, num_experts] -> num_experts * num_unpermuted_tokens - probs_T_1D = probs.T.contiguous().view(-1) - - # get 1D indices of the probs selected by routing_map - indices_dim0 = torch.arange( - num_experts, device=routing_map.device).unsqueeze(-1) - indices_dim1 = sorted_indices.view(num_experts, capacity) - indices_1D = (indices_dim0 * num_unpermuted_tokens + - indices_dim1).view(-1) - - # get probs from indices - permuted_probs = probs_T_1D.index_select(0, indices_1D) - else: - permuted_probs = probs.T.contiguous().masked_select( - routing_map.T.contiguous()) - # Here may promote permuted_tokens to higher precision (fp32/fp64) if probs is in - # higher precision due to moe_router_dtype being enabled. This can lead to - # additional GPU memory usage. Use --moe-permute-fusion flag to avoid this extra memory - # allocation. - permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1) - - # Create an output tensor filled with zeros - output_tokens = torch.zeros(restore_shape, - dtype=permuted_tokens.dtype, - device=permuted_tokens.device) - # Scatter add the permuted_input back to the original positions - output_tokens.scatter_add_(0, - sorted_indices.unsqueeze(1).expand(-1, hidden), - permuted_tokens) - return output_tokens.to(dtype=input_dtype) - - -def sort_chunks_by_idxs(input: torch.Tensor, - split_sizes: torch.Tensor, - sorted_idxs: torch.Tensor, - fused: bool = False): - """Split and sort the input tensor based on the split_sizes and sorted indices.""" - if input.shape[0] == 0: - return input - - input = torch.split(input, split_sizes.tolist(), dim=0) - output = torch.cat([input[i] for i in sorted_idxs.tolist()], dim=0) - return output diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index 0d5e96ac31..c3dabaaaf4 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -31,8 +31,6 @@ all_to_all_sp2hp, gather_from_sequence_parallel_region, reduce_scatter_last_dim_to_tensor_parallel_region) from vllm_ascend.ops.comm_utils import async_all_to_all -from vllm_ascend.ops.moe_dispatcher.moe_utils import ( - get_capacity, permute, topk_softmax_with_capacity, unpermute) """ We use the following notation throughout this file: H: hidden size @@ -198,23 +196,6 @@ def __init__(self, config: MoeDispatcherConfig): # to each local expert by all ranks. self.num_global_tokens_per_local_expert_cpu = None self.num_global_tokens_per_local_expert = None - input_chunk_idxs = torch.arange(self.num_experts) - # [num_local_experts, ep_size]. Sort the input chunks by local experts. 
- self.sort_input_by_local_experts = input_chunk_idxs.reshape( - -1, self.num_local_experts).T.ravel() - # [ep_size, num_local_experts]. Restore the output chunks by local experts. - self.restore_output_by_local_experts = input_chunk_idxs.reshape( - self.num_local_experts, -1).T.ravel().to(torch.device("cpu"), - non_blocking=True) - - # Token drop and padding. - # We need to keep track of the token num if we drop tokens without padding them. - self.num_out_tokens = None - # Drop and pad the input to capacity. - self.drop_and_pad = self.config.moe_pad_expert_input_to_capacity - if self.drop_and_pad: - assert self.config.moe_expert_capacity_factor is not None - self.capacity = None # A cuda stream synchronization is needed in self.token_permutation() # in some cases, because there are several non-blocking DtoH data @@ -260,41 +241,18 @@ def preprocess(self, # num_local_tokens_per_expert: [num_experts] ep_size = self.ep_size - if self.drop_and_pad: - # Drop and pad the input to capacity. - num_tokens = indices.numel() - self.capacity = get_capacity( - num_tokens=num_tokens, - num_experts=self.num_experts, - capacity_factor=self.config.moe_expert_capacity_factor, - ) - self.num_out_tokens = self.capacity * self.num_experts - num_tokens_per_local_expert = torch.full( - (self.num_local_experts, ), - self.capacity * self.ep_size, - dtype=torch.long) - self.num_global_tokens_per_local_expert_cpu = torch.full( - (self.num_experts * self.tp_ep_size, ), - self.capacity, - dtype=torch.long) - return num_tokens_per_local_expert - elif self.config.moe_expert_capacity_factor is not None: - # Token drop but no pad. A synchronization is needed before the first - # permutation to get the `num_out_tokens` CPU value. - self.num_out_tokens = num_local_tokens_per_expert.sum().to( - torch.device("cpu"), non_blocking=True) - self.cuda_sync_point = "before_permutation_1" + + + # Dropless + self.num_out_tokens = indices.numel() + if self.ep_size > 1 or self.num_local_experts > 1: + # Token dropless and enable ep. A synchronization is needed before expert parallel + # AlltoAll communication to get the `input_splits` and `output_splits` CPU values. + self.cuda_sync_point = "before_ep_alltoall" else: - # Dropless - self.num_out_tokens = indices.numel() - if self.ep_size > 1 or self.num_local_experts > 1: - # Token dropless and enable ep. A synchronization is needed before expert parallel - # AlltoAll communication to get the `input_splits` and `output_splits` CPU values. - self.cuda_sync_point = "before_ep_alltoall" - else: - # Token dropless and no ep. A synchronization is needed to get the - # `tokens_per_expert` CPU value. - self.cuda_sync_point = "before_finish" + # Token dropless and no ep. A synchronization is needed to get the + # `tokens_per_expert` CPU value. 
+ self.cuda_sync_point = "before_finish" if ep_size > 1: # =================================================== @@ -328,45 +286,6 @@ def preprocess(self, self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()) - # self.num_global_tokens_per_local_expert_cpu = ( - # self.num_global_tokens_per_local_expert.view(-1, self.num_local_experts).to( - # torch.device("cpu"), non_blocking=True - # ) - # ) - # if not hasattr(self, "comm_stream"): - # self.comm_stream = torch.npu.Stream() - # self.comm_stream.wait_stream(torch.npu.current_stream()) - - return num_tokens_per_local_expert - - def routing(self, probs): - seq_length, bsz = probs.shape[:2] - probs = probs.view(-1, self.config.num_moe_experts) - if self.config.is_fused: - score_function = "sigmoid" - else: - score_function = "softmax" - - scores, routing_map, _, top_indices = topk_softmax_with_capacity( - probs, - self.config.moe_router_topk, - capacity_factor=self.config.moe_expert_capacity_factor, - pad_to_capacity=self.config.moe_pad_expert_input_to_capacity, - group_topk=self.config.group_topk, - num_groups=self.config.num_groups, - expert_bias=self.config.expert_bias, - scaling_factor=self.config.scaling_factor, - score_function=score_function) - self.top_indices = top_indices - return scores, routing_map - - def preprocess_overlap(self, routing_map): - num_tokens_per_local_expert = self.preprocess(routing_map) - self.num_global_tokens_per_local_expert = self.num_global_tokens_per_local_expert - self.input_splits = self.input_splits - self.output_splits = self.output_splits - self.num_out_tokens = self.num_out_tokens - self.num_global_tokens_per_local_expert_cpu = self.num_global_tokens_per_local_expert_cpu return num_tokens_per_local_expert def token_permutation( @@ -392,7 +311,6 @@ def token_permutation( """ self.hidden_shape = hidden_states.shape self.probs = probs - self.routing_map = routing_map self.top_indices = routing_map assert probs.dim() == 2, "Expected 2D tensor for probs" assert routing_map.dim() == 2, "Expected 2D tensor for routing map" @@ -408,18 +326,12 @@ def alltoall_token_permutation1(hidden_states, routing_map): if self.cuda_sync_point == "before_permutation_1": torch.npu.current_stream().synchronize() - if not self.config.is_fused: - permutated_local_input_tokens, reversed_local_input_permutation_mapping = permute( - hidden_states, - routing_map, - num_out_tokens=self.num_out_tokens, - ) - else: - permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( - tokens=hidden_states, - indices=self.top_indices, - num_out_tokens=self.num_out_tokens, - ) + + permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( + tokens=hidden_states, + indices=self.top_indices, + num_out_tokens=self.num_out_tokens, + ) return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = alltoall_token_permutation1( @@ -498,18 +410,12 @@ def preprocess_and_permtute1(self, shared_output = shared_experts(shared_experts_input) self.cached_shared_expert_output = shared_output - if not self.config.is_fused: - hidden_states, self.reversed_local_input_permutation_mapping = permute( - hidden_states, - routing_map, - num_out_tokens=self.num_out_tokens, - ) - else: - hidden_states, self.reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( - tokens=hidden_states, - 
indices=self.top_indices, - num_out_tokens=self.num_out_tokens, - ) + + hidden_states, self.reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( + tokens=hidden_states, + indices=self.top_indices, + num_out_tokens=self.num_out_tokens, + ) self.perm1_finish_event.record() @@ -620,11 +526,7 @@ def alltoall_token_unpermutation1(hidden_states): hidden_states = torch_npu.npu_moe_token_unpermute( hidden_states, self.reversed_global_input_permutation_mapping) - # hidden_states = sort_chunks_by_idxs( - # hidden_states, - # self.num_global_tokens_per_local_expert_cpu.T.ravel(), - # self.restore_output_by_local_experts, - # ) + return hidden_states hidden_states = alltoall_token_unpermutation1(hidden_states) @@ -639,23 +541,13 @@ def alltoall_token_unpermutation1(hidden_states): def alltoall_token_unpermutation2(permutated_local_input_tokens): # Unpermutation 1: AlltoAll output to output - if self.config.is_fused: - # permuted_probs = (self.probs.T.contiguous().masked_select(self.routing_map.T.contiguous()) - # .view(-1, self.config.moe_router_topk)) - output = torch_npu.npu_moe_token_unpermute( - permuted_tokens=permutated_local_input_tokens, - sorted_indices=self. - reversed_local_input_permutation_mapping.to(torch.int32), - probs=self.probs, - restore_shape=self.hidden_shape_before_permute) - else: - output = unpermute( - permutated_local_input_tokens, - self.reversed_local_input_permutation_mapping, - probs=self.probs, - restore_shape=self.hidden_shape_before_permute, - routing_map=self.routing_map, - ) + + output = torch_npu.npu_moe_token_unpermute( + permuted_tokens=permutated_local_input_tokens, + sorted_indices=self. + reversed_local_input_permutation_mapping.to(torch.int32), + probs=self.probs, + restore_shape=self.hidden_shape_before_permute) # Perform tensor parallel AlltoAll communication # output: [S*B, H/TP] -> [S*B/TP, H] From 2c102d32c7f13d82df3f24a61f0ce0d8d5f62f23 Mon Sep 17 00:00:00 2001 From: duyangkai Date: Fri, 11 Jul 2025 09:23:20 +0800 Subject: [PATCH 40/60] remove test Signed-off-by: duyangkai --- tests/ut/test_token_dispatcher.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/ut/test_token_dispatcher.py b/tests/ut/test_token_dispatcher.py index 44b625e037..a7546ee7fe 100644 --- a/tests/ut/test_token_dispatcher.py +++ b/tests/ut/test_token_dispatcher.py @@ -54,9 +54,3 @@ def test_initialization(self, dispatcher, config): assert dispatcher.ep_rank == 0 assert dispatcher.ep_size == 2 assert dispatcher.overlap_stream is not None - - def test_routing(self, dispatcher): - probs = torch.randn(4, 4) # 4 tokens, 4 experts - scores, routing_map = dispatcher.routing(probs) - assert scores.shape == (4, 4) # topk=2 - assert routing_map.shape == (4, 4) From 565fa2d73954b79f638830786e6caed0b1bd52d6 Mon Sep 17 00:00:00 2001 From: duyangkai Date: Fri, 11 Jul 2025 14:30:14 +0800 Subject: [PATCH 41/60] clean code Signed-off-by: duyangkai --- tests/ut/test_distributed_tensor_parallel.py | 42 +++++++++++++------- tests/ut/test_token_dispatcher.py | 32 ++++++++++----- vllm_ascend/models/deepseek_dbo.py | 8 ++-- vllm_ascend/models/qwen3_dbo.py | 9 +---- vllm_ascend/ops/fused_moe.py | 2 - 5 files changed, 55 insertions(+), 38 deletions(-) diff --git a/tests/ut/test_distributed_tensor_parallel.py b/tests/ut/test_distributed_tensor_parallel.py index ae540cc08f..5792fb6df5 100644 --- a/tests/ut/test_distributed_tensor_parallel.py +++ b/tests/ut/test_distributed_tensor_parallel.py @@ -1,7 +1,22 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: 
Copyright contributors to the vLLM project +# # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + import importlib +import unittest from unittest.mock import MagicMock, patch import pytest @@ -36,7 +51,7 @@ def mock_dist(): yield mock -class TestDistributedCommunication: +class TestDistributedCommunication(unittest.TestCase): @pytest.mark.parametrize("world_size", [1, 4]) def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist, @@ -47,9 +62,9 @@ def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist, result = _gather_along_first_dim(test_tensor, mock_group) if world_size == 1: - assert torch.equal(result, test_tensor) + self.assertEqual(result.shape, (8, 16)) else: - assert result.shape == (32, 16) # 8*4=32 + self.assertEqual(result.shape, (32, 16)) # 8*4=32 def test_gather_along_first_dim_unequal_split(self, test_tensor, mock_group): @@ -57,7 +72,7 @@ def test_gather_along_first_dim_unequal_split(self, test_tensor, output_split_sizes = [5, 10, 15, 2] result = _gather_along_first_dim(test_tensor, mock_group, output_split_sizes) - assert result.shape == (32, 16) # 5+10+15+2=32 + self.assertEqual(result.shape, (32, 16)) # 5+10+15+2=32 @pytest.mark.parametrize("world_size", [1, 4]) def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group, @@ -67,10 +82,7 @@ def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group, result = _gather_along_last_dim(test_tensor_last_dim, mock_group) - if world_size == 1: - assert torch.equal(result, test_tensor_last_dim) - else: - assert result.shape == (8, 16, 32 * world_size) # 8*4=32 + self.assertEqual(result.shape, (8, 16, 32 * world_size)) @pytest.mark.parametrize("input_shape,expected_shape", [ ((32, 16), (8, 16)), @@ -80,12 +92,12 @@ def test_reduce_scatter_along_first_dim(self, mock_group, input_shape, expected_shape): input_tensor = torch.randn(*input_shape) result = _reduce_scatter_along_first_dim(input_tensor, mock_group) - assert result.shape == expected_shape + self.assertEqual(result.shape, expected_shape) def test_reduce_scatter_along_last_dim(self, mock_group): input_tensor = torch.randn(8, 16, 32) result = _reduce_scatter_along_last_dim(input_tensor, mock_group) - assert result.shape == (8, 16, 8) # 32/4=8 + self.assertEqual(result.shape, (8, 16, 8)) @pytest.mark.parametrize("func,input_shape,expected_shape", [ ("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32), @@ -104,7 +116,7 @@ def test_wrapper_functions(self, mock_group, func, input_shape, test_func = globals[func] input_tensor = torch.randn(*input_shape) result = test_func(input_tensor, mock_group) - assert result.shape == expected_shape + self.assertEqual(result.shape, expected_shape) @pytest.mark.parametrize( "input_shape,output_shape", @@ -114,7 +126,7 @@ def test_wrapper_functions(self, mock_group, func, input_shape, def test_all_to_all_sp2hp(self, mock_group, 
input_shape, output_shape): input_tensor = torch.randn(*input_shape) result = all_to_all_sp2hp(input_tensor, mock_group) - assert result.shape == output_shape + self.assertEqual(result.shape, output_shape) @pytest.mark.parametrize( "input_shape,output_shape", @@ -124,4 +136,4 @@ def test_all_to_all_sp2hp(self, mock_group, input_shape, output_shape): def test_all_to_all_hp2sp(self, mock_group, input_shape, output_shape): input_tensor = torch.randn(*input_shape) result = all_to_all_hp2sp(input_tensor, mock_group) - assert result.shape == output_shape + self.assertEqual(result.shape, output_shape) diff --git a/tests/ut/test_token_dispatcher.py b/tests/ut/test_token_dispatcher.py index a7546ee7fe..497eaad6be 100644 --- a/tests/ut/test_token_dispatcher.py +++ b/tests/ut/test_token_dispatcher.py @@ -1,9 +1,23 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. import pytest import torch +import unittest from pytest_mock import MockerFixture from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( @@ -17,7 +31,7 @@ adapt_patch(True) -class TestMoEAlltoAllSeqOverLapDispatcher: +class TestMoEAlltoAllSeqOverLapDispatcher(unittest.TestCase): @pytest.fixture def config(self): @@ -48,9 +62,9 @@ def dispatcher(self, config, mocker: MockerFixture): return MoEAlltoAllSeqOverLapDispatcher(config) def test_initialization(self, dispatcher, config): - assert dispatcher.num_local_experts == config.num_local_experts - assert dispatcher.num_experts == config.num_moe_experts - assert dispatcher.local_expert_indices == [0, 1] - assert dispatcher.ep_rank == 0 - assert dispatcher.ep_size == 2 - assert dispatcher.overlap_stream is not None + self.assertEqual(dispatcher.num_local_experts, config.num_local_experts) + self.assertEqual(dispatcher.num_experts, config.num_moe_experts) + self.assertEqual(dispatcher.local_expert_indices, [0, 1]) + self.assertEqual(dispatcher.ep_rank, 0) + self.assertEqual(dispatcher.ep_size, 2) + self.assertIsNotNone(dispatcher.overlap_stream) diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index d9abb9be87..841eed6b07 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -147,7 +147,7 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=True if not envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ else False, + reduce_results=not envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ, # shared experts tp comm is seperated in alltoallv for better overlap. 
prefix=f"{prefix}.shared_experts", ) CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok @@ -245,15 +245,13 @@ def _forward_op_gating( if self.config.n_routed_experts == 256: topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( router_logits, - k=self.config.num_experts_per_tok, # topk当前写8 + k=self.config.num_experts_per_tok, bias=self.gate.e_score_correction_bias, k_group=self.config.topk_group, # fix: 4 group_count=self.config.n_group, # fix 8 - group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix) + group_select_mode=1, # 0: max in group; 1: topk2.sum(fix) renorm=0, # 0: softmax->topk(fix); 1: topk->softmax norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # todo new api; 第三个输出是否输出 - # y2_flag=False, # old api; 第三个输出是否输出 routed_scaling_factor=1, eps=float(1e-20)) else: diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py index 7860bee643..7035cfe936 100644 --- a/vllm_ascend/models/qwen3_dbo.py +++ b/vllm_ascend/models/qwen3_dbo.py @@ -1,12 +1,6 @@ -# SPDX-License-Identifier: Apache-2.0 # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. # Copyright 2023 The vLLM team. -# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. # -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# # Adapted from +# This file is a part of the vllm-ascend project. 
+ # """Inference-only Qwen3 model.""" from types import SimpleNamespace from typing import List, Optional, Union diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 09fca4dec8..0f8dce525a 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1346,8 +1346,6 @@ def _forward_ms_fused_moe_comp( class AscendSparseMoeBlock(nn.Module): - top_k: int - def __init__( self, config: PretrainedConfig, From f980ad03b2a6d3297bd77e10f17fa814cd2784e4 Mon Sep 17 00:00:00 2001 From: duyangkai Date: Fri, 11 Jul 2025 14:34:51 +0800 Subject: [PATCH 42/60] fix clean code Signed-off-by: duyangkai --- vllm_ascend/models/qwen3_dbo.py | 4 +--- vllm_ascend/ops/fused_moe.py | 4 ++-- vllm_ascend/ops/moe_dispatcher/token_dispatcher.py | 2 +- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py index 7035cfe936..94fc68c78e 100644 --- a/vllm_ascend/models/qwen3_dbo.py +++ b/vllm_ascend/models/qwen3_dbo.py @@ -172,11 +172,9 @@ def _forward_op_gating( bias=self.mlp.gate.e_score_correction_bias, k_group=mlp_config.topk_group, # fix: 4 group_count=mlp_config.n_group, # fix 8 - group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix) + group_select_mode=1, # 0: max in group; 1: topk2.sum(fix) renorm=0, # 0: softmax->topk(fix); 1: topk->softmax norm_type=1, # 0: softmax; 1: sigmoid(fix) - # out_flag=False, # todo new api; 第三个输出是否输出 - # y2_flag=False, # old api; 第三个输出是否输出 routed_scaling_factor=1, eps=float(1e-20)) else: diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 0f8dce525a..3b082bb18f 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -46,7 +46,7 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( - MoEAlltoAllSeqOverLapDispatcher, MoeDispatcherConfig) + MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, get_ascend_soc_version, npu_stream_switch, npu_wait_tensor) @@ -1164,7 +1164,7 @@ def __init__( self.quant_method, AscendUnquantizedFusedMoEMethod): self.reduce_results = False moe_dispatcher_config = ( - MoeDispatcherConfig().set_num_moe_experts( + MoEDispatcherConfig().set_num_moe_experts( self.global_num_experts).set_num_local_experts( self.local_num_experts).set_moe_router_topk( top_k).set_group_topk(topk_group). 
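For context on the rename above: MoEDispatcherConfig is a small builder-style object that the fused-MoE layer chains setters on and that the unit test constructs directly. A minimal sketch of the construction and the per-microbatch call order, mirroring the fixture in tests/ut/test_token_dispatcher.py and the DBO decoder layer in this series (values are illustrative, and creating the dispatcher assumes the Ascend EP/TP process groups are already initialized; the unit test mocks them):

    from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
        MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)

    config = MoEDispatcherConfig()
    config.set_num_local_experts(2)                     # experts owned by this EP rank
    config.set_num_moe_experts(4)                       # global expert count
    config.set_moe_pad_expert_input_to_capacity(False)  # dropless path used in this series
    # further set_* calls (router top-k, group top-k, ...) follow the same pattern
    dispatcher = MoEAlltoAllSeqOverLapDispatcher(config)

    # Per-microbatch call order used by the DBO decoder layer:
    #   dispatcher.preprocess_and_permtute1(hidden_states, topk_weights, topk_ids)
    #   dispatcher.dispatch_alltoall()
    #   dispatched_input, tokens_per_expert = dispatcher.permute2()
    #   expert_output = apply_mlp(dispatched_input, w13_weight, w2_weight, tokens_per_expert)
    #   final_hidden_states, _ = dispatcher.token_unpermutation(expert_output)
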
diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index c3dabaaaf4..7feead31d7 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -43,7 +43,7 @@ """ -class MoeDispatcherConfig: +class MoEDispatcherConfig: def __init__(self): self.num_local_experts: int = 0 From 969ee25fb467d5878bf2ac1eda49a16e9965d637 Mon Sep 17 00:00:00 2001 From: duyangkai Date: Fri, 11 Jul 2025 14:41:44 +0800 Subject: [PATCH 43/60] typo Signed-off-by: duyangkai --- tests/ut/test_token_dispatcher.py | 4 ++-- vllm_ascend/ops/moe_dispatcher/token_dispatcher.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/ut/test_token_dispatcher.py b/tests/ut/test_token_dispatcher.py index 497eaad6be..64865d0677 100644 --- a/tests/ut/test_token_dispatcher.py +++ b/tests/ut/test_token_dispatcher.py @@ -21,7 +21,7 @@ from pytest_mock import MockerFixture from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( - MoEAlltoAllSeqOverLapDispatcher, MoeDispatcherConfig) + MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) from vllm_ascend.utils import adapt_patch # noqa E402 import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa @@ -35,7 +35,7 @@ class TestMoEAlltoAllSeqOverLapDispatcher(unittest.TestCase): @pytest.fixture def config(self): - config = MoeDispatcherConfig() + config = MoEDispatcherConfig() config.set_num_local_experts(2) config.set_num_moe_experts(4) config.set_moe_pad_expert_input_to_capacity(False) diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index 7feead31d7..8664cfc917 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -109,7 +109,7 @@ def build(self): class MoEDispatcher: - def __init__(self, config: MoeDispatcherConfig) -> None: + def __init__(self, config: MoEDispatcherConfig) -> None: """ Initialize the MoE Token Dispatcher. """ @@ -151,12 +151,12 @@ class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher): """ - def __init__(self, config: MoeDispatcherConfig): + def __init__(self, config: MoEDispatcherConfig): """ Initialize the AlltoAllSeq token dispatcher. Args: - config (MoeDispatcherConfig): Configuration for the transformer model. + config (MoEDispatcherConfig): Configuration for the transformer model. """ super().__init__(config) self.num_local_experts = config.num_local_experts From a70be9acd1dc84f5f7c5c8af9fc0c77296e71f56 Mon Sep 17 00:00:00 2001 From: duyangkai Date: Fri, 11 Jul 2025 17:32:57 +0800 Subject: [PATCH 44/60] renaming cuda sync point Signed-off-by: duyangkai --- .../ops/moe_dispatcher/token_dispatcher.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index 8664cfc917..1e2352000d 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -203,7 +203,7 @@ def __init__(self, config: MoEDispatcherConfig): # at different points based on MoE settings as late as possible. # Valid sync points are "before_permutation_1", "before_ep_alltoall", # "before_finish", and "no_sync". - self.cuda_sync_point = "no_sync" + self.device_sync_point = "no_sync" # cached intermediate tensors. 
self.cached_permutated_local_input_tokens = None @@ -248,11 +248,11 @@ def preprocess(self, if self.ep_size > 1 or self.num_local_experts > 1: # Token dropless and enable ep. A synchronization is needed before expert parallel # AlltoAll communication to get the `input_splits` and `output_splits` CPU values. - self.cuda_sync_point = "before_ep_alltoall" + self.device_sync_point = "before_ep_alltoall" else: # Token dropless and no ep. A synchronization is needed to get the # `tokens_per_expert` CPU value. - self.cuda_sync_point = "before_finish" + self.device_sync_point = "before_finish" if ep_size > 1: # =================================================== @@ -281,7 +281,7 @@ def preprocess(self, num_tokens_per_local_expert = num_local_tokens_per_expert if self.num_local_experts > 1 and with_sync: - self.cuda_sync_point = "no_sync" + self.device_sync_point = "no_sync" self.global_input_tokens_local_experts_indices = torch.repeat_interleave( self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()) @@ -324,7 +324,7 @@ def alltoall_token_permutation1(hidden_states, routing_map): group=self.tp_ep_group) self.hidden_shape_before_permute = hidden_states.shape - if self.cuda_sync_point == "before_permutation_1": + if self.device_sync_point == "before_permutation_1": torch.npu.current_stream().synchronize() permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( @@ -342,7 +342,7 @@ def alltoall_token_permutation1(hidden_states, routing_map): ep_group = self.ep_group # Perform expert parallel AlltoAll communication - if self.cuda_sync_point == "before_ep_alltoall": + if self.device_sync_point == "before_ep_alltoall": torch.npu.current_stream().synchronize() _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all( permutated_local_input_tokens, @@ -372,7 +372,7 @@ def alltoall_token_permutation2(global_input_tokens): if self.tp_ep_size > 1 and self.config.moe_grouped_gemm: global_input_tokens = all_gather_last_dim_from_tensor_parallel_region( global_input_tokens, self.tp_ep_group) - if self.cuda_sync_point == "before_finish": + if self.device_sync_point == "before_finish": torch.npu.current_stream().synchronize() return global_input_tokens @@ -398,7 +398,7 @@ def preprocess_and_permtute1(self, tokens_per_expert = self.preprocess(routing_map, with_sync=False) self.hidden_shape_before_permute = hidden_states.shape - if self.cuda_sync_point == "before_permutation_1": + if self.device_sync_point == "before_permutation_1": torch.npu.current_stream().synchronize() event = torch.npu.current_stream().record_event() @@ -421,7 +421,7 @@ def preprocess_and_permtute1(self, # repeat interleve will launch a sync on current_stream. 
if self.num_local_experts > 1: - self.cuda_sync_point = "no_sync" + self.device_sync_point = "no_sync" self.global_input_tokens_local_experts_indices = torch.repeat_interleave( self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()) @@ -433,7 +433,7 @@ def dispatch_alltoall(self): ep_group = self.ep_group # Perform expert parallel AlltoAll communication - if self.cuda_sync_point == "before_ep_alltoall": + if self.device_sync_point == "before_ep_alltoall": torch.npu.current_stream().synchronize() torch.npu.current_stream().wait_event(self.perm1_finish_event) From 402f88931e04279392f77cae77e6ed4d937d55e3 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Fri, 11 Jul 2025 17:51:56 +0800 Subject: [PATCH 45/60] handle code clean Signed-off-by: weijinqian_v1 --- tests/multicard/test_qwen3_moe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/multicard/test_qwen3_moe.py b/tests/multicard/test_qwen3_moe.py index 122f8024fb..e24770b792 100644 --- a/tests/multicard/test_qwen3_moe.py +++ b/tests/multicard/test_qwen3_moe.py @@ -38,7 +38,7 @@ "VLLM_ASCEND_ENABLE_DBO": "1" }) def test_qwen3_moe_inference(model, max_tokens): - script = "examples/offline_data_parallel.py" + script = "examples/dp_offline/data_parallel.py" env = os.environ.copy() @@ -56,7 +56,6 @@ def test_qwen3_moe_inference(model, max_tokens): "--node-rank", "0", "--trust-remote-code", - "--enforce-eager", ] print(f"Running subprocess: {' '.join(cmd)}") From 141407d9cdf9e84786f95ff7c74f46aca11d1f34 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Fri, 11 Jul 2025 18:23:43 +0800 Subject: [PATCH 46/60] handle code clean Signed-off-by: weijinqian_v1 --- vllm_ascend/models/deepseek_dbo.py | 2 +- .../ops/moe_dispatcher/token_dispatcher.py | 18 ++++++++---------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 841eed6b07..6562bb46bd 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -147,7 +147,7 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=not envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ, # shared experts tp comm is seperated in alltoallv for better overlap. + reduce_results=not envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ, # shared experts tp comm is separated in alltoallv for better overlap. 
prefix=f"{prefix}.shared_experts", ) CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index 1e2352000d..1e18900870 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -32,16 +32,6 @@ reduce_scatter_last_dim_to_tensor_parallel_region) from vllm_ascend.ops.comm_utils import async_all_to_all -""" We use the following notation throughout this file: - H: hidden size - B: micro batch size - S: sequence length - TP: tensor model parallel size - EP: expert model parallel size - num_local_tokens: S/TP*B - num_global_tokens: num_local_tokens*TP*EP -""" - class MoEDispatcherConfig: @@ -266,6 +256,10 @@ def preprocess(self, group=self.ep_group).reshape(ep_size, self.num_experts) self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[ 0]:self.local_expert_indices[-1] + 1] + if self.num_global_tokens_per_local_expert is None: + raise ValueError( + "num_global_tokens_per_local_expert must be set before sum." + ) self.output_splits = (self.num_global_tokens_per_local_expert.sum( axis=-1).to(torch.device("cpu"), non_blocking=True).numpy()) num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum( @@ -281,6 +275,10 @@ def preprocess(self, num_tokens_per_local_expert = num_local_tokens_per_expert if self.num_local_experts > 1 and with_sync: + if self.num_global_tokens_per_local_expert is None: + raise ValueError( + "num_global_tokens_per_local_expert must be set before operations." + ) self.device_sync_point = "no_sync" self.global_input_tokens_local_experts_indices = torch.repeat_interleave( self.expert_ids_per_ep_rank, From b1d7305f4526c19fb8ad0ff6a25db804b349e4aa Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Fri, 11 Jul 2025 18:26:53 +0800 Subject: [PATCH 47/60] handle code clean Signed-off-by: weijinqian_v1 --- tests/ut/test_token_dispatcher.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/ut/test_token_dispatcher.py b/tests/ut/test_token_dispatcher.py index 64865d0677..15b1ac15fb 100644 --- a/tests/ut/test_token_dispatcher.py +++ b/tests/ut/test_token_dispatcher.py @@ -16,7 +16,6 @@ # This file is a part of the vllm-ascend project. import pytest -import torch import unittest from pytest_mock import MockerFixture From 62cebe1a277e966338d4c46dd5f49520ef5b63be Mon Sep 17 00:00:00 2001 From: weijinqian Date: Fri, 11 Jul 2025 22:11:16 +0800 Subject: [PATCH 48/60] handle clean code Signed-off-by: weijinqian_v1 --- vllm_ascend/ascend_forward_context.py | 92 +-- vllm_ascend/models/qwen3_dbo.py | 310 +++++----- vllm_ascend/models/qwen3_moe.py | 8 +- vllm_ascend/ops/fused_moe.py | 805 +++++++++++++++----------- 4 files changed, 693 insertions(+), 522 deletions(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 04685f5a43..c2d81037ff 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -10,7 +10,6 @@ from vllm.platforms import current_platform import vllm_ascend.envs as envs - import vllm_ascend.envs as envs_ascend @@ -29,8 +28,11 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool): return FusedMoEState.AllGather elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage. 
- return FusedMoEState.All2AllSeq if ( - ep_size < 16 or with_prefill) else FusedMoEState.MC2 + return ( + FusedMoEState.All2AllSeq + if (ep_size < 16 or with_prefill) + else FusedMoEState.MC2 + ) elif ep_size >= 16 and with_prefill and enable_chunk_mc2: return FusedMoEState.MC2_PREFILL # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. @@ -42,27 +44,33 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool): @contextmanager def set_ascend_forward_context( - attn_metadata: Any, - vllm_config: VllmConfig, - virtual_engine: int = 0, - num_tokens: Optional[int] = None, - num_tokens_across_dp: Optional[torch.Tensor] = None, - with_prefill: bool = True, - in_profile_run: bool = False, - num_actual_tokens: Optional[int] = None): + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None, + with_prefill: bool = True, + in_profile_run: bool = False, + num_actual_tokens: Optional[int] = None, +): """A context manager that stores the current forward context, can be attention metadata, etc. We add some additional param into forward_context. """ - with set_forward_context(attn_metadata, - vllm_config, - virtual_engine=virtual_engine, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp): + with set_forward_context( + attn_metadata, + vllm_config, + virtual_engine=virtual_engine, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + ): forward_context = get_forward_context() forward_context.with_prefill = with_prefill - ep_size = torch.distributed.get_world_size( - ) if vllm_config.parallel_config.enable_expert_parallel else 1 + ep_size = ( + torch.distributed.get_world_size() + if vllm_config.parallel_config.enable_expert_parallel + else 1 + ) fused_moe_state = get_fused_moe_state(ep_size, with_prefill) @@ -75,19 +83,22 @@ def set_ascend_forward_context( forward_context.capturing = False if num_tokens is None and attn_metadata is not None: - if hasattr(attn_metadata, 'num_actual_tokens'): + if hasattr(attn_metadata, "num_actual_tokens"): # for v1 engine num_tokens = attn_metadata.num_actual_tokens else: # for v0 engine - num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens + num_tokens = ( + attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens + ) if num_actual_tokens is None: num_actual_tokens = num_tokens dp_world_size = get_dp_group().world_size if dp_world_size > 1 and forward_context.dp_metadata is not None: - max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item( + max_tokens_across_dp = ( + forward_context.dp_metadata.max_tokens_across_dp_cpu.item() ) else: max_tokens_across_dp = num_tokens @@ -98,29 +109,38 @@ def set_ascend_forward_context( tp_world_size = get_tp_group().world_size world_size = torch.distributed.get_world_size() # NOTE: token num which need to pad to when mc2 - forward_context.padded_num_tokens = math.ceil( - max_tokens_across_dp / tp_world_size) * tp_world_size + forward_context.padded_num_tokens = ( + math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size + ) # NOTE: mc2 op's param `global_bs`, add `world_size` to make `global_bs` absolutely larger than actual global_bs. 
- forward_context.global_bs = math.ceil( - max_tokens_across_dp / tp_world_size) * world_size + forward_context.global_bs = ( + math.ceil(max_tokens_across_dp / tp_world_size) * world_size + ) if fused_moe_state == FusedMoEState.MC2_PREFILL: chunk_size = envs.VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE forward_context.max_num_chunks = math.ceil( - math.ceil(max_tokens_across_dp / tp_world_size) / - chunk_size) + math.ceil(max_tokens_across_dp / tp_world_size) / chunk_size + ) - forward_context.global_bs = math.ceil( - math.ceil(max_tokens_across_dp / tp_world_size) / - forward_context.max_num_chunks) * world_size + forward_context.global_bs = ( + math.ceil( + math.ceil(max_tokens_across_dp / tp_world_size) + / forward_context.max_num_chunks + ) + * world_size + ) min_num_tokens = forward_context.max_num_chunks * tp_world_size - forward_context.padded_num_tokens = math.ceil( - max_tokens_across_dp / min_num_tokens) * min_num_tokens - - mc2_mask = torch.zeros(forward_context.padded_num_tokens, - dtype=torch.bool, - device=current_platform.device_type) + forward_context.padded_num_tokens = ( + math.ceil(max_tokens_across_dp / min_num_tokens) * min_num_tokens + ) + + mc2_mask = torch.zeros( + forward_context.padded_num_tokens, + dtype=torch.bool, + device=current_platform.device_type, + ) mc2_mask[:num_actual_tokens] = True forward_context.mc2_mask = mc2_mask diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py index 94fc68c78e..4e7dc12df7 100644 --- a/vllm_ascend/models/qwen3_dbo.py +++ b/vllm_ascend/models/qwen3_dbo.py @@ -21,40 +21,54 @@ import torch import torch_npu +import vllm.model_executor.models.qwen3_moe as qwen3 from torch import nn from transformers import PretrainedConfig -import vllm.model_executor.models.qwen3_moe as qwen3 from vllm.attention import AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, - get_tensor_model_parallel_world_size, - get_tp_group) +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_world_size, + get_tp_group, +) from vllm.forward_context import get_forward_context, set_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.models.qwen3_moe import (Qwen3MoeDecoderLayer, - Qwen3MoeForCausalLM, - Qwen3MoeModel) + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.models.qwen3_moe import ( + Qwen3MoeDecoderLayer, + Qwen3MoeForCausalLM, + Qwen3MoeModel, +) from vllm.model_executor.models.utils import ( - make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) from vllm.sequence import IntermediateTensors import vllm_ascend.envs as envs_ascend -from vllm_ascend.distributed.tensor_parallel import \ - gather_from_sequence_parallel_region +from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.context import ( - advance_step_multistream_layer_context, get_multistream_layer_context) -from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, - MultiStreamPreTransformerLayer) 
-from vllm_ascend.multistream.metadata import (MultiStreamConfig, - MultiStreamStepMetadata, - make_multistream_metadata_ds) -from vllm_ascend.ops.fused_moe import apply_mlp, select_experts, AscendSparseMoeBlock + advance_step_multistream_layer_context, + get_multistream_layer_context, +) +from vllm_ascend.multistream.layers import ( + MultiStreamPostTransformerLayer, + MultiStreamPreTransformerLayer, +) +from vllm_ascend.multistream.metadata import ( + MultiStreamConfig, + MultiStreamStepMetadata, + make_multistream_metadata_ds, +) +from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock, apply_mlp, select_experts VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO @@ -68,15 +82,21 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: - super(Qwen3MoeDecoderLayerDBO, self).__init__(config, cache_config, - quant_config, prefix) + super(Qwen3MoeDecoderLayerDBO, self).__init__( + config, cache_config, quant_config, prefix + ) self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tp_group().rank_in_group self.tp_group = get_tp_group().device_group self.dummy_vllm_config = SimpleNamespace( - parallel_config=SimpleNamespace(data_parallel_size=1, ), - compilation_config=SimpleNamespace(static_forward_context=None, ), - other_setting="value") + parallel_config=SimpleNamespace( + data_parallel_size=1, + ), + compilation_config=SimpleNamespace( + static_forward_context=None, + ), + other_setting="value", + ) self.config = config def forward(self, *args, **kwargs): @@ -92,8 +112,7 @@ def _forward_ms_op_input_layernorm( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) return hidden_states, residual def _forward_ms_op_attn( @@ -104,8 +123,9 @@ def _forward_ms_op_attn( kv_cache: Optional[torch.Tensor] = None, attn_metadata: Optional[AttentionMetadata] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - self.dummy_vllm_config.compilation_config.static_forward_context = get_forward_context( - ).no_compile_layers + self.dummy_vllm_config.compilation_config.static_forward_context = ( + get_forward_context().no_compile_layers + ) with set_forward_context(attn_metadata, self.dummy_vllm_config): hidden_states = self.self_attn( positions=positions, @@ -115,11 +135,11 @@ def _forward_ms_op_attn( # Fix FP16 overflow # We scale both hidden_states and residual before # rmsnorm, and rmsnorm result would not affect by scale. - hidden_states *= 1. / self.routed_scaling_factor + hidden_states *= 1.0 / self.routed_scaling_factor if self.layer_idx == 0: # The residual is shared by all layers, we only scale it on # first layer. - residual *= 1. 
/ self.routed_scaling_factor + residual *= 1.0 / self.routed_scaling_factor return hidden_states, residual def _forward_ms_op_post_attn_layernorm( @@ -127,14 +147,14 @@ def _forward_ms_op_post_attn_layernorm( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ): - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) return hidden_states, residual def _forward_op_gating( - self, - hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: if attn_metadata is None: attn_metadata = get_forward_context().attn_metadata # when profile runs, force experts to load balanced tokens @@ -148,13 +168,10 @@ def _forward_op_gating( num_tokens, hidden_size = hidden_states.shape if num_tokens < self.tp_size: hidden_states = nn.functional.pad( - hidden_states, (0, 0, 0, self.tp_size - num_tokens)) - chunk_hidden_states = torch.tensor_split(hidden_states, - self.tp_size, - dim=0) - chunked_hidden_states_sizes = [ - x.shape[0] for x in chunk_hidden_states - ] + hidden_states, (0, 0, 0, self.tp_size - num_tokens) + ) + chunk_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0) + chunked_hidden_states_sizes = [x.shape[0] for x in chunk_hidden_states] local_hidden_states = chunk_hidden_states[self.tp_rank] else: local_hidden_states = hidden_states @@ -176,7 +193,8 @@ def _forward_op_gating( renorm=0, # 0: softmax->topk(fix); 1: topk->softmax norm_type=1, # 0: softmax; 1: sigmoid(fix) routed_scaling_factor=1, - eps=float(1e-20)) + eps=float(1e-20), + ) else: topk_weights, topk_ids = select_experts( hidden_states=local_hidden_states, @@ -187,10 +205,11 @@ def _forward_op_gating( topk_group=getattr(mlp_config, "topk_group", None), num_expert_group=getattr(mlp_config, "n_group", None), custom_routing_function=None, - scoring_func=getattr(mlp_config, "scoring_func", 'softmax'), - e_score_correction_bias=getattr(self.mlp.gate, - "e_score_correction_bias", - None)) + scoring_func=getattr(mlp_config, "scoring_func", "softmax"), + e_score_correction_bias=getattr( + self.mlp.gate, "e_score_correction_bias", None + ), + ) topk_weights = topk_weights.to(hidden_states.dtype) # this is a naive implementation for experts load balance so as @@ -202,28 +221,33 @@ def _forward_op_gating( return topk_weights, topk_ids, local_hidden_states, chunked_hidden_states_sizes def _forward_op_grouped_mlp(self, dispatched_input, tokens_per_expert): - return apply_mlp(dispatched_input, self.mlp.experts.w13_weight, - self.mlp.experts.w2_weight, tokens_per_expert) + return apply_mlp( + dispatched_input, + self.mlp.experts.w13_weight, + self.mlp.experts.w2_weight, + tokens_per_expert, + ) - def _forward_combine_comm(self, hidden_states, microbatch_id, num_tokens, - chunked_hidden_states_sizes): + def _forward_combine_comm( + self, hidden_states, microbatch_id, num_tokens, chunked_hidden_states_sizes + ): token_dispatcher = self.mlp.experts.token_dispatchers[microbatch_id] - final_hidden_states, _ = token_dispatcher.token_unpermutation( - hidden_states) - if hasattr(self.mlp, 'routed_scaling_factor'): + final_hidden_states, _ = token_dispatcher.token_unpermutation(hidden_states) + if hasattr(self.mlp, "routed_scaling_factor"): final_hidden_states = final_hidden_states * self.mlp.routed_scaling_factor if self.tp_size > 1: final_hidden_states = 
gather_from_sequence_parallel_region( - final_hidden_states, self.tp_group, - chunked_hidden_states_sizes) + final_hidden_states, self.tp_group, chunked_hidden_states_sizes + ) if num_tokens < self.tp_size: final_hidden_states = final_hidden_states[:num_tokens] if hasattr(self.mlp, "shared_experts"): - final_hidden_states = final_hidden_states + token_dispatcher.cached_shared_expert_output - token_dispatcher.cached_shared_expert_output.untyped_storage( - ).resize_(0) + final_hidden_states = ( + final_hidden_states + token_dispatcher.cached_shared_expert_output + ) + token_dispatcher.cached_shared_expert_output.untyped_storage().resize_(0) token_dispatcher.cached_shared_expert_output = None final_hidden_states = final_hidden_states.view(num_tokens, -1) @@ -238,8 +262,7 @@ def _forward_ms_layer_alltoallv_finegrained( attn_metadata: List[AttentionMetadata], kv_cache: Optional[torch.Tensor] = None, ): - layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( - ) + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context() assert layer_index >= 0 and ms_metadata is not None num_micro_batchs = ms_metadata.ms_config.num_micro_batches assert len(positions) == num_micro_batchs @@ -248,9 +271,7 @@ def _forward_ms_layer_alltoallv_finegrained( assert attn_metadata is not None num_tokens = [None] * num_micro_batchs hidden_dims = [None] * num_micro_batchs - topk_weights, topk_ids = [None] * num_micro_batchs, [ - None - ] * num_micro_batchs + topk_weights, topk_ids = [None] * num_micro_batchs, [None] * num_micro_batchs tokens_per_expert = [None] * num_micro_batchs dispatched_input = [None] * num_micro_batchs router_expert_output = [None] * num_micro_batchs @@ -271,85 +292,95 @@ def discard_tensor(tensor): # can be overlapped with the attn communication of microbatch 1 for i in range(num_micro_batchs): forward_context = get_forward_context() - layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( - ) - ms_metadata.try_wait_event(layer_index - 1, i, - MSEventKey.FFN_AR_FINISH) + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context() + ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH) forward_context.attn_metadata = attn_metadata[i] # input layernorm - hidden_states[i], residual[ - i] = self._forward_ms_op_input_layernorm( - hidden_states[i], residual[i]) + hidden_states[i], residual[i] = self._forward_ms_op_input_layernorm( + hidden_states[i], residual[i] + ) # attention and tp allreduce hidden_states[i], residual[i] = self._forward_ms_op_attn( - positions[i], hidden_states[i], residual[i], kv_cache, - attn_metadata[i]) + positions[i], hidden_states[i], residual[i], kv_cache, attn_metadata[i] + ) # post attention layer norm - hidden_states[i], residual[ - i] = self._forward_ms_op_post_attn_layernorm( - hidden_states[i], residual[i]) + hidden_states[i], residual[i] = self._forward_ms_op_post_attn_layernorm( + hidden_states[i], residual[i] + ) num_tokens[i], hidden_dims[i] = hidden_states[i].shape # If TP is enabled, hidden_states will be chunked. 
- topk_weights[i], topk_ids[i], dispatched_input[ - i], chunked_hidden_states_sizes[i] = self._forward_op_gating( - hidden_states[i], attn_metadata[i]) + ( + topk_weights[i], + topk_ids[i], + dispatched_input[i], + chunked_hidden_states_sizes[i], + ) = self._forward_op_gating(hidden_states[i], attn_metadata[i]) token_dispatchers[i].preprocess_and_permtute1( dispatched_input[i], topk_weights[i], topk_ids[i], shared_experts=None, - shared_experts_input=None) + shared_experts_input=None, + ) # Launch DisPatch Comm in a New Stream. dispatch_context = MultiStreamStepMetadata( comm_stream=ms_metadata.communicate_stream, before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_BEFORE_COMM], + MSEventKey.MOE_BEFORE_COMM + ], after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_AFTER_COMM], + MSEventKey.MOE_AFTER_COMM + ], ) dispatch_context.before_comm_event.record() # print_with_sync(f'begin token dispatch{i}...', torch.distributed.get_rank()) with torch.npu.stream(dispatch_context.comm_stream): dispatch_context.comm_stream.wait_event( - dispatch_context.before_comm_event) + dispatch_context.before_comm_event + ) token_dispatchers[i].dispatch_alltoall() dispatched_input[i], tokens_per_expert[i] = token_dispatchers[ - i].permute2() + i + ].permute2() dispatch_context.after_comm_event.record() # print_with_sync('begin experts...', torch.distributed.get_rank()) # block 4 : Router Experts Computation # block 5 : Token Combine Communication for i in range(num_micro_batchs): - ms_metadata.try_wait_event(layer_index, i, - MSEventKey.MOE_AFTER_COMM) + ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_AFTER_COMM) discard_tensor(hidden_states[i]) router_expert_output[i] = self._forward_op_grouped_mlp( - dispatched_input[i], tokens_per_expert[i]) + dispatched_input[i], tokens_per_expert[i] + ) discard_tensor(dispatched_input[i]) # Launch Combine Comm in a New Stream. combine_context = MultiStreamStepMetadata( comm_stream=ms_metadata.communicate_stream, before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.FFN_COM_FINISH], + MSEventKey.FFN_COM_FINISH + ], after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.FFN_AR_FINISH], + MSEventKey.FFN_AR_FINISH + ], ) combine_context.before_comm_event.record() - ms_metadata.try_wait_event(layer_index, i, - MSEventKey.MOE_SE_COMM_FINISH) + ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH) with torch.npu.stream(combine_context.comm_stream): combine_context.comm_stream.wait_event( - combine_context.before_comm_event) + combine_context.before_comm_event + ) hidden_states[i] = self._forward_combine_comm( - router_expert_output[i], i, num_tokens[i], - chunked_hidden_states_sizes[i]) + router_expert_output[i], + i, + num_tokens[i], + chunked_hidden_states_sizes[i], + ) ms_metadata.ms_events[layer_index][i][ - MSEventKey. 
- FFN_AR_FINISH] = combine_context.comm_stream.record_event( - ) + MSEventKey.FFN_AR_FINISH + ] = combine_context.comm_stream.record_event() return hidden_states, residual @@ -368,21 +399,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vocab_size = config.vocab_size self.config = config self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - prefix=f"{prefix}.embed_tokens") + config.vocab_size, config.hidden_size, prefix=f"{prefix}.embed_tokens" + ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Qwen3MoeDecoderLayerDBO(config=config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: Qwen3MoeDecoderLayerDBO( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) # dbo related members if VLLM_ASCEND_ENABLE_DBO: @@ -394,10 +426,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): causal_lm=getattr(config, "causal_lm", True), multistream_config=self.multistream_config, ) - self.ms_pre_layer = MultiStreamPreTransformerLayer( - multistream_metadata) - self.ms_post_layer = MultiStreamPostTransformerLayer( - multistream_metadata) + self.ms_pre_layer = MultiStreamPreTransformerLayer(multistream_metadata) + self.ms_post_layer = MultiStreamPostTransformerLayer(multistream_metadata) def forward( self, @@ -417,8 +447,11 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - num_normal_layers = (0 if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms() - else self.end_layer - self.start_layer) + num_normal_layers = ( + 0 + if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms() + else self.end_layer - self.start_layer + ) moe_start_layer = self.start_layer + num_normal_layers for i in range(self.start_layer, min(moe_start_layer, self.end_layer)): @@ -431,13 +464,13 @@ def forward( positions=positions, hidden_states=hidden_states, residual=residual, - moe_start_layer=moe_start_layer) + moe_start_layer=moe_start_layer, + ) if not get_pp_group().is_last_rank: - return IntermediateTensors({ - "hidden_states": hidden_states, - "residual": residual - }) + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -446,7 +479,11 @@ def can_run_ms(self): attn_metadata = get_forward_context().attn_metadata # enable prefill overlap with_prefill = get_forward_context().with_prefill - if attn_metadata is None or not with_prefill or not attn_metadata.enable_dbo_across_dp: + if ( + attn_metadata is None + or not with_prefill + or not attn_metadata.enable_dbo_across_dp + ): return False return True @@ -463,9 +500,9 @@ def _forward_ms_layers( if moe_start_layer == self.end_layer: return hidden_states, residual - attn_metadata, [positions, hidden_states, - residual] = self.ms_pre_layer( - [positions, hidden_states, residual], ) + attn_metadata, [positions, hidden_states, residual] = self.ms_pre_layer( + [positions, hidden_states, residual], + ) num_micro_batch = len(attn_metadata) # the 
rest layers for i in range(moe_start_layer, self.end_layer): @@ -480,14 +517,13 @@ def _forward_ms_layers( ) advance_step_multistream_layer_context() - layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( - ) + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context() for i in range(num_micro_batch): - ms_metadata.try_wait_event(layer_index - 1, i, - MSEventKey.FFN_AR_FINISH) + ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH) - [hidden_states, - residual] = self.ms_post_layer([hidden_states, residual], ) + [hidden_states, residual] = self.ms_post_layer( + [hidden_states, residual], + ) return hidden_states, residual @@ -502,8 +538,7 @@ class CustomQwen3MoeForCausalLMDBO(Qwen3MoeForCausalLM): "gate_proj", "up_proj", ], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], } qwen3.Qwen3MoeSparseMoeBlock = AscendSparseMoeBlock @@ -513,19 +548,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = CustomQwen3DBOMoEModel(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.model = CustomQwen3DBOMoEModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) + self.model.make_empty_intermediate_tensors + ) def forward(self, *args, **kwargs): if "graph_enable" in kwargs: - kwargs.pop('graph_enable') + kwargs.pop("graph_enable") return super().forward(*args, **kwargs) diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index aa21455957..af09eb01cb 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -16,10 +16,12 @@ # Adapted from vllm/model_executor/models/qwen3_moe.py # This file is a part of the vllm-ascend project. 
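Both Qwen3 variants in this series rely on the same override pattern touched by the hunk below: the Ascend sparse-MoE block is installed into the upstream vLLM module before any decoder layer is constructed, so Qwen3MoeDecoderLayer builds AscendSparseMoeBlock instead of the stock Qwen3MoeSparseMoeBlock. Condensed, the pattern is (illustrative; assumes vllm and vllm-ascend are importable):

    import vllm.model_executor.models.qwen3_moe as qwen3
    from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock

    # Swap the block class at import time, before the model is instantiated,
    # so every subsequently built decoder layer uses the Ascend implementation.
    qwen3.Qwen3MoeSparseMoeBlock = AscendSparseMoeBlock
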
-from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM import vllm.model_executor.models.qwen3_moe as qwen3 +from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM + from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock + class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): packed_modules_mapping = { "qkv_proj": [ @@ -31,8 +33,6 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): "gate_proj", "up_proj", ], - "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], } qwen3.Qwen3MoeSparseMoeBlock = AscendSparseMoeBlock - diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 3b082bb18f..6ea5d796b6 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -24,21 +24,25 @@ import torch_npu from torch import nn from transformers import PretrainedConfig +from vllm.attention import AttentionMetadata from vllm.config import get_current_vllm_config -from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, - get_tp_group) +from vllm.distributed import ( + GroupCoordinator, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from vllm.distributed.parallel_state import get_dp_group, get_ep_group, get_tp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod, - determine_expert_map) -from vllm.model_executor.layers.quantization.base_config import \ - QuantizationConfig + FusedMoE, + FusedMoEParallelConfig, + MoEConfig, + UnquantizedFusedMoEMethod, + determine_expert_map, +) from vllm.model_executor.layers.linear import ReplicatedLinear -from vllm.attention import AttentionMetadata - +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config @@ -46,48 +50,59 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( - MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) -from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, - get_ascend_soc_version, npu_stream_switch, - npu_wait_tensor) + MoEAlltoAllSeqOverLapDispatcher, + MoEDispatcherConfig, +) +from vllm_ascend.utils import ( + AscendSocVersion, + dispose_tensor, + get_ascend_soc_version, + npu_stream_switch, + npu_wait_tensor, +) VLLM_ASCEND_MOE_ALL2ALL_BUFFER: bool = envs_ascend.VLLM_ASCEND_MOE_ALL2ALL_BUFFER -def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, - max_row_per_ep_rank: int, num_tokens: int, - top_k: int) -> tuple[torch.Tensor, torch.Tensor]: +def process_topk_ids( + topk_ids: torch.Tensor, + expert_num: int, + ep_size: int, + max_row_per_ep_rank: int, + num_tokens: int, + top_k: int, +) -> tuple[torch.Tensor, torch.Tensor]: original_total_elements = num_tokens * top_k device = topk_ids.device original_dtype = topk_ids.dtype if original_total_elements == 0: output_len = ep_size * max_row_per_ep_rank - topk_ids_pad = torch.full((output_len, ), - expert_num, - dtype=original_dtype, - 
device=device) - unpad_indices = torch.full((original_total_elements, ), - -1, - dtype=torch.long, - device=device) + topk_ids_pad = torch.full( + (output_len,), expert_num, dtype=original_dtype, device=device + ) + unpad_indices = torch.full( + (original_total_elements,), -1, dtype=torch.long, device=device + ) return topk_ids_pad, unpad_indices experts_per_ep_rank_val = expert_num // ep_size if experts_per_ep_rank_val == 0: raise ValueError( "expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. " - "Ensure expert_num >= ep_size.") + "Ensure expert_num >= ep_size." + ) - assigned_ep_rank = (topk_ids.float() / - experts_per_ep_rank_val).to(original_dtype) + assigned_ep_rank = (topk_ids.float() / experts_per_ep_rank_val).to(original_dtype) indices_arange = torch.arange(topk_ids.shape[0], device=device) - is_new_segment = torch.cat((torch.tensor([True], device=device), - assigned_ep_rank[1:] != assigned_ep_rank[:-1])) - temp_start_markers = torch.full_like(indices_arange, - -1, - dtype=indices_arange.dtype) + is_new_segment = torch.cat( + ( + torch.tensor([True], device=device), + assigned_ep_rank[1:] != assigned_ep_rank[:-1], + ) + ) + temp_start_markers = torch.full_like(indices_arange, -1, dtype=indices_arange.dtype) temp_start_markers[is_new_segment] = indices_arange[is_new_segment] start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0] token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token @@ -95,24 +110,25 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long) indices_in_rec_cond_list_for_all = cumsum_kept - 1 unpad_indices = torch.where( - is_kept_mask, indices_in_rec_cond_list_for_all, - torch.tensor(-1, device=device, dtype=torch.long)) + is_kept_mask, + indices_in_rec_cond_list_for_all, + torch.tensor(-1, device=device, dtype=torch.long), + ) output_len = ep_size * max_row_per_ep_rank - topk_ids_pad = torch.full((output_len, ), - expert_num, - dtype=original_dtype, - device=device) + topk_ids_pad = torch.full( + (output_len,), expert_num, dtype=original_dtype, device=device + ) if topk_ids.shape[0] > 0: - all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx - temp_pad_buffer = torch.full((output_len + 1, ), - expert_num, - dtype=original_dtype, - device=device) - output_len_tensor = torch.tensor(output_len, - dtype=torch.long, - device=device) - scatter_indices = torch.where(is_kept_mask, all_destination_indices, - output_len_tensor) + all_destination_indices = ( + assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx + ) + temp_pad_buffer = torch.full( + (output_len + 1,), expert_num, dtype=original_dtype, device=device + ) + output_len_tensor = torch.tensor(output_len, dtype=torch.long, device=device) + scatter_indices = torch.where( + is_kept_mask, all_destination_indices, output_len_tensor + ) temp_pad_buffer.scatter_(0, scatter_indices, topk_ids) topk_ids_pad = temp_pad_buffer[:output_len] return topk_ids_pad, unpad_indices @@ -139,12 +155,13 @@ def fused_experts_with_mc2( # NOTE: `global_bs` should be equal to `max_num_tokens_across_dp` * `ep_world_size`, # and `max_num_tokens_across_dp` has been split into `tp_world_size` parts before. 
- global_bs = math.ceil(get_forward_context().max_tokens_across_dp / - tp_world_size) * ep_world_size + global_bs = ( + math.ceil(get_forward_context().max_tokens_across_dp / tp_world_size) + * ep_world_size + ) # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine - need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 - or is_torchair) + need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 or is_torchair # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 @@ -167,22 +184,25 @@ def fused_experts_with_mc2( "ep_rank_id": ep_rank_id, } if need_extra_args: - stage1_kwargs.update({ - "group_tp": moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) + stage1_kwargs.update( + { + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + } + ) if a3_need_extra_args: - stage1_kwargs.update({ - "x_active_mask": mc2_mask, - }) + stage1_kwargs.update( + { + "x_active_mask": mc2_mask, + } + ) kwargs_mc2.update(stage1_kwargs) output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[ - 0:5] + expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[0:5] if shared_experts is not None: with npu_stream_switch("moe_secondary", 0): @@ -239,16 +259,20 @@ def fused_experts_with_mc2( "ep_rank_id": ep_rank_id, } if need_extra_args: - stage3_kwargs.update({ - "tp_send_counts": tp_recv_counts, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) + stage3_kwargs.update( + { + "tp_send_counts": tp_recv_counts, + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + } + ) if a3_need_extra_args: - stage3_kwargs.update({ - "x_active_mask": mc2_mask, - }) + stage3_kwargs.update( + { + "x_active_mask": mc2_mask, + } + ) kwargs_mc2.update(stage3_kwargs) hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) @@ -262,11 +286,13 @@ def fused_experts_with_mc2( return hidden_states, shared_hidden_states -def apply_mlp(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - group_list: torch.Tensor, - group_list_type: int = 1) -> torch.Tensor: +def apply_mlp( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1, +) -> torch.Tensor: """ apply MLP: gate_up_proj -> swiglu -> down_proj @@ -339,58 +365,66 @@ def fused_experts_with_all2all( global_num_experts = len(expert_map) local_num_experts = global_num_experts // ep_group.world_size row_idx_len = num_tokens * top_k - row_idx = (torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=device).view(top_k, -1).permute( - 1, 0).contiguous()) - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) + row_idx = ( + torch.arange(0, row_idx_len, dtype=torch.int32, device=device) + .view(top_k, -1) + .permute(1, 0) + .contiguous() + ) + hidden_states, expanded_row_idx, expanded_expert_idx = ( + torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens, + ) + ) - global_expert_tokens = torch.bincount(expanded_expert_idx, - minlength=global_num_experts) - 
scatter_sizes = global_expert_tokens.view(ep_group.world_size, - -1).sum(-1) + global_expert_tokens = torch.bincount( + expanded_expert_idx, minlength=global_num_experts + ) + scatter_sizes = global_expert_tokens.view(ep_group.world_size, -1).sum(-1) gather_sizes = torch.empty_like(scatter_sizes) - dist.all_to_all_single(gather_sizes, - scatter_sizes, - group=ep_group.device_group) + dist.all_to_all_single(gather_sizes, scatter_sizes, group=ep_group.device_group) scatter_size_list = scatter_sizes.cpu().tolist() gather_size_list = gather_sizes.cpu().tolist() expanded_expert_idx = expanded_expert_idx % local_num_experts - hidden_states = ep_group.all_to_all(hidden_states, 0, 0, - scatter_size_list, - gather_size_list) - local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0, - scatter_size_list, - gather_size_list) + hidden_states = ep_group.all_to_all( + hidden_states, 0, 0, scatter_size_list, gather_size_list + ) + local_expert_idx = ep_group.all_to_all( + expanded_expert_idx, 0, 0, scatter_size_list, gather_size_list + ) sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx) expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - sorted_local_expert_idx, local_num_experts).to(torch.int64) + sorted_local_expert_idx, local_num_experts + ).to(torch.int64) hidden_states = hidden_states[sorted_idx] else: row_idx_len = num_tokens * top_k - row_idx = torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=topk_weights.device).view( - top_k, -1).permute(1, 0).contiguous() - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) + row_idx = ( + torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device) + .view(top_k, -1) + .permute(1, 0) + .contiguous() + ) + hidden_states, expanded_row_idx, expanded_expert_idx = ( + torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens, + ) + ) expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, num_experts) + expanded_expert_idx, num_experts + ) expert_tokens = expert_tokens.to(torch.int64) w1 = w1.transpose(1, 2) @@ -422,9 +456,9 @@ def fused_experts_with_all2all( if expert_map is not None: resorted_idx = torch.argsort(sorted_idx) hidden_states = hidden_states[resorted_idx] - hidden_states = ep_group.all_to_all(hidden_states, 0, 0, - gather_size_list, - scatter_size_list) + hidden_states = ep_group.all_to_all( + hidden_states, 0, 0, gather_size_list, scatter_size_list + ) final_hidden_states = torch_npu.npu_moe_finalize_routing( hidden_states, @@ -476,87 +510,116 @@ def fused_experts_with_all2all_buffer( global_num_experts = len(expert_map) local_num_experts = global_num_experts // ep_group.world_size row_idx_len = num_tokens * top_k - row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32, - device=device).view(top_k, - -1).permute(1, 0).contiguous()) - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) + row_idx = ( + torch.arange(0, row_idx_len, dtype=torch.int32, device=device) + .view(top_k, -1) + .permute(1, 0) + .contiguous() + ) + hidden_states, expanded_row_idx, expanded_expert_idx = ( + torch_npu.npu_moe_init_routing( + hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens + ) + ) max_row_per_ep_rank = ( - -(-global_batch_size // ep_group.world_size) * 
max_model_len * - get_dp_group().world_size // ep_group.world_size + 1) * top_k * 2 + ( + -(-global_batch_size // ep_group.world_size) + * max_model_len + * get_dp_group().world_size + // ep_group.world_size + + 1 + ) + * top_k + * 2 + ) expert_idx_buffer_scatter, unpad_indices = process_topk_ids( - expanded_expert_idx, global_num_experts, ep_group.world_size, - max_row_per_ep_rank, num_tokens, top_k) + expanded_expert_idx, + global_num_experts, + ep_group.world_size, + max_row_per_ep_rank, + num_tokens, + top_k, + ) hidden_states_pad_idx = torch.zeros( expert_idx_buffer_scatter.shape, dtype=expert_idx_buffer_scatter.dtype, - device=expert_idx_buffer_scatter.device) + device=expert_idx_buffer_scatter.device, + ) non_pad_len = torch.sum( - (expert_idx_buffer_scatter != global_num_experts).to(torch.int32)) - hidden_states_pad_idx[ - expert_idx_buffer_scatter != global_num_experts] = torch.arange( + (expert_idx_buffer_scatter != global_num_experts).to(torch.int32) + ) + hidden_states_pad_idx[expert_idx_buffer_scatter != global_num_experts] = ( + torch.arange( non_pad_len, dtype=expert_idx_buffer_scatter.dtype, - device=hidden_states.device) + device=hidden_states.device, + ) + ) hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] expert_idx_buffer_gather = torch.empty_like( expert_idx_buffer_scatter, dtype=expert_idx_buffer_scatter.dtype, - device=expert_idx_buffer_scatter.device) + device=expert_idx_buffer_scatter.device, + ) hidden_states_buffer_gather = torch.empty_like( hidden_states_buffer_scatter, dtype=hidden_states_buffer_scatter.dtype, - device=hidden_states_buffer_scatter.device) - dist.all_to_all_single(expert_idx_buffer_gather, - expert_idx_buffer_scatter, - group=ep_group.device_group) - dist.all_to_all_single(hidden_states_buffer_gather, - hidden_states_buffer_scatter, - group=ep_group.device_group) + device=hidden_states_buffer_scatter.device, + ) + dist.all_to_all_single( + expert_idx_buffer_gather, expert_idx_buffer_scatter, group=ep_group.device_group + ) + dist.all_to_all_single( + hidden_states_buffer_gather, + hidden_states_buffer_scatter, + group=ep_group.device_group, + ) mask = expert_idx_buffer_gather != global_num_experts local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( - global_num_experts // ep_group.world_size) + global_num_experts // ep_group.world_size + ) hidden_states = hidden_states_buffer_gather[mask] idx_type = local_expert_idx.dtype sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float()) sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type) expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - sorted_local_expert_idx, local_num_experts).to(torch.int64) + sorted_local_expert_idx, local_num_experts + ).to(torch.int64) hidden_states = hidden_states[sorted_idx] group_list_type = 0 - hidden_states = apply_mlp(hidden_states, - w1, - w2, - expert_tokens, - group_list_type=group_list_type) + hidden_states = apply_mlp( + hidden_states, w1, w2, expert_tokens, group_list_type=group_list_type + ) resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype) hidden_states = hidden_states[resorted_idx] hidden_states_scatter = torch.zeros( (mask.shape[0], hidden_states.shape[1]), dtype=hidden_states.dtype, - device=hidden_states.device) + device=hidden_states.device, + ) hidden_states_scatter[mask] = hidden_states hidden_states_gatter = torch.empty_like( hidden_states_scatter, dtype=hidden_states_scatter.dtype, - device=hidden_states_scatter.device) - 
dist.all_to_all_single(hidden_states_gatter, - hidden_states_scatter, - group=ep_group.device_group) + device=hidden_states_scatter.device, + ) + dist.all_to_all_single( + hidden_states_gatter, hidden_states_scatter, group=ep_group.device_group + ) hidden_states_gatter = hidden_states_gatter[ - expert_idx_buffer_scatter != global_num_experts] + expert_idx_buffer_scatter != global_num_experts + ] if hidden_states_gatter.shape[0] != row_idx_len: - hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]), - dtype=hidden_states.dtype, - device=hidden_states.device) + hidden_states = torch.zeros( + (row_idx_len, hidden_states.shape[1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) hidden_states[unpad_indices != -1] = hidden_states_gatter else: # TODO: Reorder device memory 2 times here, replace the current @@ -576,13 +639,18 @@ def fused_experts_with_all2all_buffer( return final_hidden_states -def fused_experts_with_all2allv(token_dispatcher, probs, routing_map, - hidden_states: torch.Tensor, w1: torch.Tensor, - w2: torch.Tensor): +def fused_experts_with_all2allv( + token_dispatcher, + probs, + routing_map, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, +): # Enable moe alltoallv, it's a balanced policy for precision and efficiency. - (share_experts_output, dispatched_input, - tokens_per_expert) = token_dispatcher.token_permutation( - hidden_states, probs, routing_map) + (share_experts_output, dispatched_input, tokens_per_expert) = ( + token_dispatcher.token_permutation(hidden_states, probs, routing_map) + ) expert_output = apply_mlp(dispatched_input, w1, w2, tokens_per_expert) output, mlp_bias = token_dispatcher.token_unpermutation(expert_output) @@ -637,8 +705,9 @@ def fused_experts( # ], "Only float32, float16, and bfloat16 are supported" if apply_router_weight_on_input: - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" + assert ( + topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" _, topk = topk_weights.shape assert ( topk == 1 @@ -647,10 +716,12 @@ def fused_experts( if expert_map is not None: # Generate token indices and flatten - token_indices = (torch.arange(num_tokens, - device=device, - dtype=torch.int64).unsqueeze(1).expand( - -1, top_k).reshape(-1)) + token_indices = ( + torch.arange(num_tokens, device=device, dtype=torch.int64) + .unsqueeze(1) + .expand(-1, top_k) + .reshape(-1) + ) # Flatten token-to-expert mappings and map to local experts weights_flat = topk_weights.view(-1) @@ -660,11 +731,11 @@ def fused_experts( # Filter valid token-expert pairs mask = local_experts_flat != -1 filtered_weights = torch.where( - mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype) + mask, weights_flat, torch.zeros_like(weights_flat) + ).to(dtype) filtered_experts = torch.where( - mask, local_experts_flat, - torch.full_like(local_experts_flat, - num_experts)).to(topk_ids.dtype) + mask, local_experts_flat, torch.full_like(local_experts_flat, num_experts) + ).to(topk_ids.dtype) # Sort by local expert IDs sort_indices = torch.argsort(filtered_experts.view(torch.float32)) @@ -674,9 +745,7 @@ def fused_experts( # Compute token counts with minlength of num_experts # This is equivalent to but faster than: # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] - token_counts = torch.zeros(num_experts + 1, - device=device, - dtype=torch.int64) + token_counts = torch.zeros(num_experts + 1, device=device, dtype=torch.int64) ones = 
torch.ones_like(filtered_experts, dtype=torch.int64) token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) token_counts = token_counts[:num_experts] @@ -686,19 +755,24 @@ def fused_experts( sorted_hidden_states = hidden_states[sorted_token_indices] else: row_idx_len = num_tokens * top_k - row_idx = (torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=device).view(top_k, -1).permute( - 1, 0).contiguous()) - sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) + row_idx = ( + torch.arange(0, row_idx_len, dtype=torch.int32, device=device) + .view(top_k, -1) + .permute(1, 0) + .contiguous() + ) + sorted_hidden_states, expanded_row_idx, expanded_expert_idx = ( + torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens, + ) + ) expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, num_experts) + expanded_expert_idx, num_experts + ) expert_tokens = expert_tokens.to(torch.int64) w1 = w1.transpose(1, 2) @@ -730,24 +804,28 @@ def fused_experts( if expert_map is not None: weighted_down_out = down_out_list * sorted_weights.unsqueeze(1) - final_hidden_states = torch.zeros(*original_shape, - device=hidden_states.device, - dtype=dtype) + final_hidden_states = torch.zeros( + *original_shape, device=hidden_states.device, dtype=dtype + ) # TODO: npu_grouped_matmul output random values at [num_valid_tokens:, ...] # This created multiple NaN and index_add_ will mix them up which harms accuracy # remove this mask and filter after it being fixed num_valid_tokens = mask.sum() - valid_token_mask = torch.arange( - 0, sorted_token_indices.shape[0], - device=device).unsqueeze(1) < num_valid_tokens + valid_token_mask = ( + torch.arange(0, sorted_token_indices.shape[0], device=device).unsqueeze(1) + < num_valid_tokens + ) valid_output = torch.where( - valid_token_mask, weighted_down_out, - torch.zeros_like(weighted_down_out)).to(dtype) + valid_token_mask, weighted_down_out, torch.zeros_like(weighted_down_out) + ).to(dtype) final_hidden_states.index_add_(0, sorted_token_indices, valid_output) else: - scales = torch.ones_like( - topk_weights) if apply_router_weight_on_input else topk_weights + scales = ( + torch.ones_like(topk_weights) + if apply_router_weight_on_input + else topk_weights + ) # TODO: Reorder device memory 2 times here, replace the current # implementation here when suitable operators become available. 
final_hidden_states = torch_npu.npu_moe_finalize_routing( @@ -772,17 +850,19 @@ def native_grouped_topk( num_expert_group = 0 if num_expert_group is None else num_expert_group num_token = topk_weights.shape[0] - grouped_weights = topk_weights.view(num_token, num_expert_group, - -1).max(dim=-1).values - topk_group_indices = torch.topk(grouped_weights.to(torch.float32), - k=topk_group, - dim=-1, - sorted=False)[1] + grouped_weights = ( + topk_weights.view(num_token, num_expert_group, -1).max(dim=-1).values + ) + topk_group_indices = torch.topk( + grouped_weights.to(torch.float32), k=topk_group, dim=-1, sorted=False + )[1] topk_group_mask = torch.zeros_like(grouped_weights) topk_group_mask.scatter_(1, topk_group_indices, 1) - topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand( - num_token, num_expert_group, - topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)) + topk_weight_mask = ( + topk_group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, topk_weights.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0) return topk_weights @@ -843,21 +923,18 @@ def select_experts( # TODO: Change to npu_group_topk when the latest CANN and NNAL is available # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group) - topk_weights = native_grouped_topk(topk_weights, num_expert_group, - topk_group) + topk_weights = native_grouped_topk(topk_weights, num_expert_group, topk_group) # TODO bfloat16 is not supported in torch.topk with ge graph. if e_score_correction_bias is not None: - topk_ids = torch.topk(topk_weights.to(torch.float32), - k=top_k, - dim=-1, - sorted=False)[1] + topk_ids = torch.topk( + topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False + )[1] # Use original unbiased scores for the routing weights topk_weights = original_weights.gather(1, topk_ids) else: - topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), - k=top_k, - dim=-1, - sorted=False) + topk_weights, topk_ids = torch.topk( + topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False + ) elif custom_routing_function is None: topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1) topk_weights = topk_weights.to(hidden_states.dtype) @@ -866,7 +943,8 @@ def select_experts( hidden_states=hidden_states, gating_output=router_logits, topk=top_k, - renormalize=renormalize) + renormalize=renormalize, + ) # Required by npu_moe_init_routing topk_ids = topk_ids.to(torch.int32) return topk_weights, topk_ids @@ -898,20 +976,18 @@ def __init__(self, moe: MoEConfig = None): # TODO: Try local_rank = ep_group.rank_in_group local_rank = torch.distributed.get_rank(group=device_group) backend = device_group._get_backend(torch.device("npu")) - self.moe_all_to_all_group_name = backend.get_hccl_comm_name( - local_rank) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) except AttributeError: self.moe_all_to_all_group_name = None def process_weights_after_loading(self, layer): - super(UnquantizedFusedMoEMethod, - self).process_weights_after_loading(layer) - layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight( - layer.w13_weight.data), - requires_grad=False) - layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight( - layer.w2_weight.data), - requires_grad=False) + super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer) + layer.w13_weight = torch.nn.Parameter( + self._maybe_pad_weight(layer.w13_weight.data), requires_grad=False + ) + 
layer.w2_weight = torch.nn.Parameter( + self._maybe_pad_weight(layer.w2_weight.data), requires_grad=False + ) def apply( self, @@ -948,7 +1024,8 @@ def apply( # out_flag=False, # todo new api; 第三个输出是否输出 # y2_flag=False, # old api; 第三个输出是否输出 routed_scaling_factor=1, - eps=float(1e-20)) + eps=float(1e-20), + ) else: topk_weights, topk_ids = select_experts( hidden_states=x, @@ -984,15 +1061,18 @@ def apply( moe_all_to_all_group_name=self.moe_all_to_all_group_name, shared_experts=shared_experts, is_torchair=self.torchair_graph_enabled, - mc2_mask=mc2_mask) + mc2_mask=mc2_mask, + ) elif fused_moe_state == FusedMoEState.AllGather: - return fused_experts(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map) + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + ) elif VLLM_ASCEND_MOE_ALL2ALL_BUFFER: return fused_experts_with_all2all_buffer( hidden_states=x, @@ -1004,25 +1084,29 @@ def apply( max_model_len=self.max_model_len, global_batch_size=self.global_batch_size, expert_map=expert_map, - ep_group=get_ep_group()) + ep_group=get_ep_group(), + ) elif fused_moe_state == FusedMoEState.All2AllSeq: - token_dispatcher = kwargs.get('token_dispatcher') + token_dispatcher = kwargs.get("token_dispatcher") return fused_experts_with_all2allv( token_dispatcher=token_dispatcher, probs=topk_weights, routing_map=topk_ids, hidden_states=x, w1=layer.w13_weight, - w2=layer.w2_weight) + w2=layer.w2_weight, + ) else: - return fused_experts_with_all2all(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - ep_group=get_ep_group()) + return fused_experts_with_all2all( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + ep_group=get_ep_group(), + ) class AscendFusedMoE(FusedMoE): @@ -1066,13 +1150,15 @@ def __init__( vllm_config = get_current_vllm_config() - self.moe_parallel_config: FusedMoEParallelConfig = ( - FusedMoEParallelConfig.make( - tp_size_=(tp_size if tp_size is not None else - get_tensor_model_parallel_world_size()), - dp_size_=(dp_size if dp_size is not None else - get_dp_group().world_size), - vllm_parallel_config=vllm_config.parallel_config)) + self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( + tp_size_=( + tp_size + if tp_size is not None + else get_tensor_model_parallel_world_size() + ), + dp_size_=(dp_size if dp_size is not None else get_dp_group().world_size), + vllm_parallel_config=vllm_config.parallel_config, + ) self.top_k = top_k self.num_experts = num_experts @@ -1098,29 +1184,36 @@ def __init__( expert_map_path = ascend_config.expert_map_path if expert_map_path and os.path.exists(expert_map_path): # moe expert load balance - expert_load_balancer = ExpertLoadBalancer(expert_map_path, - self.global_num_experts) - self.local_num_experts, self.expert_map = \ - expert_load_balancer.get_rank_placement_map( - self.moe_instance_id, - self.ep_rank) + expert_load_balancer = ExpertLoadBalancer( + expert_map_path, self.global_num_experts + ) + self.local_num_experts, self.expert_map = ( + expert_load_balancer.get_rank_placement_map( + self.moe_instance_id, self.ep_rank + ) + ) self.log2phy = expert_load_balancer.get_rank_log2phy_map( - 
self.moe_instance_id, self.ep_rank) - self.global_redundant_expert_num = \ - expert_load_balancer.get_global_redundant_expert_num() + self.moe_instance_id, self.ep_rank + ) + self.global_redundant_expert_num = ( + expert_load_balancer.get_global_redundant_expert_num() + ) else: # Create a tensor of size num_experts filled with -1 self.local_num_experts, self.expert_map = determine_expert_map( - self.ep_size, self.ep_rank, self.global_num_experts) + self.ep_size, self.ep_rank, self.global_num_experts + ) self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_multistream_moe = \ - ascend_config.torchair_graph_config.enable_multistream_moe and \ - self.torchair_graph_enabled + self.enable_multistream_moe = ( + ascend_config.torchair_graph_config.enable_multistream_moe + and self.torchair_graph_enabled + ) if self.scoring_func != "softmax" and not self.use_grouped_topk: - raise ValueError("Only softmax scoring function is supported for " - "non-grouped topk.") + raise ValueError( + "Only softmax scoring function is supported for " "non-grouped topk." + ) moe = MoEConfig( num_experts=self.global_num_experts, @@ -1139,20 +1232,24 @@ def __init__( assert self.quant_method is not None - local_num_experts = torch.sum(self.expert_map != -1) \ - if self.expert_map is not None else num_experts + local_num_experts = ( + torch.sum(self.expert_map != -1) + if self.expert_map is not None + else num_experts + ) moe_quant_params = { "num_experts": local_num_experts, "hidden_size": hidden_size, - "intermediate_size_per_partition": - self.intermediate_size_per_partition, + "intermediate_size_per_partition": self.intermediate_size_per_partition, "params_dtype": params_dtype, "weight_loader": self.weight_loader, } # need full intermediate size pre-sharding for WNA16 act order - if (self.quant_method.__class__.__name__ - in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): + if self.quant_method.__class__.__name__ in ( + "GPTQMarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod", + ): moe_quant_params["intermediate_size_full"] = intermediate_size self.ep_group = get_ep_group() @@ -1161,32 +1258,39 @@ def __init__( self.quant_method.create_weights(layer=self, **moe_quant_params) self.token_dispatcher = None if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance( - self.quant_method, AscendUnquantizedFusedMoEMethod): + self.quant_method, AscendUnquantizedFusedMoEMethod + ): self.reduce_results = False moe_dispatcher_config = ( - MoEDispatcherConfig().set_num_moe_experts( - self.global_num_experts).set_num_local_experts( - self.local_num_experts).set_moe_router_topk( - top_k).set_group_topk(topk_group). 
- set_num_groups(num_expert_group).set_expert_bias( - e_score_correction_bias).set_scaling_factor(1.0).build()) + MoEDispatcherConfig() + .set_num_moe_experts(self.global_num_experts) + .set_num_local_experts(self.local_num_experts) + .set_moe_router_topk(top_k) + .set_group_topk(topk_group) + .set_num_groups(num_expert_group) + .set_expert_bias(e_score_correction_bias) + .set_scaling_factor(1.0) + .build() + ) self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher( - moe_dispatcher_config) + moe_dispatcher_config + ) if envs_ascend.VLLM_ASCEND_ENABLE_DBO: token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher( - moe_dispatcher_config) - self.token_dispatchers = [ - self.token_dispatcher, token_dispatcher1 - ] - - def forward(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - is_prefill: bool, - enable_force_load_balance: bool = False, - top_k: Optional[int] = None, - shared_experts: Optional[Any] = None, - gate: Optional[Any] = None): + moe_dispatcher_config + ) + self.token_dispatchers = [self.token_dispatcher, token_dispatcher1] + + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_prefill: bool, + enable_force_load_balance: bool = False, + top_k: Optional[int] = None, + shared_experts: Optional[Any] = None, + gate: Optional[Any] = None, + ): assert self.quant_method is not None if top_k: @@ -1200,43 +1304,46 @@ def forward(self, fused_moe_state = get_forward_context().fused_moe_state # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel. quantized_x_for_share, dynamic_scale_for_share = None, None - from vllm_ascend.quantization.w8a8_dynamic import \ - AscendW8A8DynamicFusedMoEMethod + from vllm_ascend.quantization.w8a8_dynamic import ( + AscendW8A8DynamicFusedMoEMethod, + ) + if self.enable_multistream_moe: assert gate is not None router_logits, _ = gate(hidden_states) - if isinstance(self.quant_method.quant_method, - AscendW8A8DynamicFusedMoEMethod - ) and fused_moe_state == FusedMoEState.MC2: + if ( + isinstance( + self.quant_method.quant_method, AscendW8A8DynamicFusedMoEMethod + ) + and fused_moe_state == FusedMoEState.MC2 + ): with npu_stream_switch("moe_secondary", 0): - quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant( - hidden_states) + quantized_x_for_share, dynamic_scale_for_share = ( + torch_npu.npu_dynamic_quant(hidden_states) + ) if shared_experts: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: shared_hidden_states = shared_experts(hidden_states) - mc2_mask = forward_context.mc2_mask tp_size = get_tensor_model_parallel_world_size() if fused_moe_state != FusedMoEState.AllGather: if num_tokens < forward_context.padded_num_tokens: hidden_states = nn.functional.pad( hidden_states, - (0, 0, 0, forward_context.padded_num_tokens - num_tokens)) + (0, 0, 0, forward_context.padded_num_tokens - num_tokens), + ) router_logits = nn.functional.pad( router_logits, - (0, 0, 0, forward_context.padded_num_tokens - num_tokens)) + (0, 0, 0, forward_context.padded_num_tokens - num_tokens), + ) if tp_size > 1: - chunk_hidden_states = torch.tensor_split(hidden_states, - tp_size, - dim=0) - chunk_router_logits = torch.tensor_split(router_logits, - tp_size, - dim=0) - chunk_mc2_mask = torch.tensor_split(forward_context.mc2_mask, - tp_size, - dim=0) + chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0) + chunk_router_logits = torch.tensor_split(router_logits, tp_size, dim=0) + chunk_mc2_mask = torch.tensor_split( + forward_context.mc2_mask, tp_size, dim=0 
+ ) tp_rank = get_tensor_model_parallel_rank() hidden_states = chunk_hidden_states[tp_rank] router_logits = chunk_router_logits[tp_rank] @@ -1245,15 +1352,14 @@ def forward(self, if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: # NOTE: When in torchair graph, it has been padded in model_runner_v1 if not self.torchair_graph_enabled or is_prefill: - max_num_tokens_across_dp = get_forward_context( - ).max_tokens_across_dp + max_num_tokens_across_dp = get_forward_context().max_tokens_across_dp if num_tokens < max_num_tokens_across_dp: hidden_states = nn.functional.pad( - hidden_states, - (0, 0, 0, max_num_tokens_across_dp - num_tokens)) + hidden_states, (0, 0, 0, max_num_tokens_across_dp - num_tokens) + ) router_logits = nn.functional.pad( - router_logits, - (0, 0, 0, max_num_tokens_across_dp - num_tokens)) + router_logits, (0, 0, 0, max_num_tokens_across_dp - num_tokens) + ) hidden_states = get_dp_group().all_gather(hidden_states, 0) router_logits = get_dp_group().all_gather(router_logits, 0) @@ -1276,38 +1382,40 @@ def forward(self, enable_force_load_balance=enable_force_load_balance, log2phy=self.log2phy, global_redundant_expert_num=self.global_redundant_expert_num, - shared_experts=shared_experts if self.torchair_graph_enabled - and self.enable_multistream_moe and not is_prefill else None, + shared_experts=( + shared_experts + if self.torchair_graph_enabled + and self.enable_multistream_moe + and not is_prefill + else None + ), quantized_x_for_share=quantized_x_for_share, dynamic_scale_for_share=dynamic_scale_for_share, mc2_mask=mc2_mask, - token_dispatcher=self.token_dispatcher) + token_dispatcher=self.token_dispatcher, + ) if shared_experts: if isinstance(e_hidden_states, tuple): e_hidden_states, shared_hidden_states = e_hidden_states if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: - dist.all_gather(list(chunk_hidden_states), e_hidden_states, - self.tp_group) + dist.all_gather(list(chunk_hidden_states), e_hidden_states, self.tp_group) final_hidden_states = torch.cat(chunk_hidden_states, dim=0) if num_tokens < forward_context.padded_num_tokens: final_hidden_states = final_hidden_states[:num_tokens] dispose_tensor(e_hidden_states) elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: final_hidden_states = dist._functional_collectives.reduce_scatter_tensor( - e_hidden_states, - "sum", - scatter_dim=0, - group=get_dp_group().device_group) + e_hidden_states, "sum", scatter_dim=0, group=get_dp_group().device_group + ) final_hidden_states = final_hidden_states[:num_tokens] dispose_tensor(e_hidden_states) else: final_hidden_states = e_hidden_states if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) if shared_experts: return final_hidden_states, shared_hidden_states @@ -1339,7 +1447,8 @@ def _forward_ms_fused_moe_comp( scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, is_prefill=is_prefill, - enable_force_load_balance=enable_force_load_balance) + enable_force_load_balance=enable_force_load_balance, + ) return hidden_states @@ -1357,18 +1466,22 @@ def __init__( if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}.") + f"the number of experts {config.num_experts}." 
+ ) ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_multistream_moe = \ + self.enable_multistream_moe = ( ascend_config.torchair_graph_config.enable_multistream_moe + ) - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=None, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) self.experts = AscendFusedMoE( num_experts=config.num_experts, @@ -1378,7 +1491,8 @@ def __init__( reduce_results=False, renormalize=config.norm_topk_prob, quant_config=quant_config, - prefix=f"{prefix}.experts") + prefix=f"{prefix}.experts", + ) self.top_k = config.num_experts_per_tok @@ -1391,9 +1505,10 @@ def __init__( self.params_dtype = torch.get_default_dtype() def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: if attn_metadata is None: attn_metadata = get_forward_context().attn_metadata # when profile runs, force experts to load balanced tokens @@ -1413,4 +1528,4 @@ def forward( shared_experts=None, ) - return hidden_states \ No newline at end of file + return hidden_states From 80b1d0d257eb1f824bdffdb104e7697be8a30ba0 Mon Sep 17 00:00:00 2001 From: weijinqian Date: Fri, 11 Jul 2025 22:15:08 +0800 Subject: [PATCH 49/60] handle clean code Signed-off-by: weijinqian_v1 --- tests/ut/test_token_dispatcher.py | 10 +- vllm_ascend/models/qwen3_dbo.py | 241 ++++++------- vllm_ascend/ops/fused_moe.py | 563 ++++++++++++++---------------- 3 files changed, 372 insertions(+), 442 deletions(-) diff --git a/tests/ut/test_token_dispatcher.py b/tests/ut/test_token_dispatcher.py index 15b1ac15fb..18768a7fe8 100644 --- a/tests/ut/test_token_dispatcher.py +++ b/tests/ut/test_token_dispatcher.py @@ -15,17 +15,16 @@ # limitations under the License. # This file is a part of the vllm-ascend project. 
-import pytest import unittest + +import pytest from pytest_mock import MockerFixture from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) from vllm_ascend.utils import adapt_patch # noqa E402 -import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa - - +import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa adapt_patch(True) @@ -61,7 +60,8 @@ def dispatcher(self, config, mocker: MockerFixture): return MoEAlltoAllSeqOverLapDispatcher(config) def test_initialization(self, dispatcher, config): - self.assertEqual(dispatcher.num_local_experts, config.num_local_experts) + self.assertEqual(dispatcher.num_local_experts, + config.num_local_experts) self.assertEqual(dispatcher.num_experts, config.num_moe_experts) self.assertEqual(dispatcher.local_expert_indices, [0, 1]) self.assertEqual(dispatcher.ep_rank, 0) diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py index 4e7dc12df7..fa87fe81f2 100644 --- a/vllm_ascend/models/qwen3_dbo.py +++ b/vllm_ascend/models/qwen3_dbo.py @@ -27,48 +27,35 @@ from vllm.attention import AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import ( - get_pp_group, - get_tensor_model_parallel_world_size, - get_tp_group, -) +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + get_tp_group) from vllm.forward_context import get_forward_context, set_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) -from vllm.model_executor.models.qwen3_moe import ( - Qwen3MoeDecoderLayer, - Qwen3MoeForCausalLM, - Qwen3MoeModel, -) + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.qwen3_moe import (Qwen3MoeDecoderLayer, + Qwen3MoeForCausalLM, + Qwen3MoeModel) from vllm.model_executor.models.utils import ( - make_empty_intermediate_tensors_factory, - make_layers, - maybe_prefix, -) + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.sequence import IntermediateTensors import vllm_ascend.envs as envs_ascend -from vllm_ascend.distributed.tensor_parallel import gather_from_sequence_parallel_region +from vllm_ascend.distributed.tensor_parallel import \ + gather_from_sequence_parallel_region from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.context import ( - advance_step_multistream_layer_context, - get_multistream_layer_context, -) -from vllm_ascend.multistream.layers import ( - MultiStreamPostTransformerLayer, - MultiStreamPreTransformerLayer, -) -from vllm_ascend.multistream.metadata import ( - MultiStreamConfig, - MultiStreamStepMetadata, - make_multistream_metadata_ds, -) -from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock, apply_mlp, select_experts + advance_step_multistream_layer_context, get_multistream_layer_context) +from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, + MultiStreamPreTransformerLayer) +from vllm_ascend.multistream.metadata import (MultiStreamConfig, + MultiStreamStepMetadata, + make_multistream_metadata_ds) +from vllm_ascend.ops.fused_moe import (AscendSparseMoeBlock, 
apply_mlp, + select_experts) VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO @@ -82,19 +69,14 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: - super(Qwen3MoeDecoderLayerDBO, self).__init__( - config, cache_config, quant_config, prefix - ) + super(Qwen3MoeDecoderLayerDBO, self).__init__(config, cache_config, + quant_config, prefix) self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tp_group().rank_in_group self.tp_group = get_tp_group().device_group self.dummy_vllm_config = SimpleNamespace( - parallel_config=SimpleNamespace( - data_parallel_size=1, - ), - compilation_config=SimpleNamespace( - static_forward_context=None, - ), + parallel_config=SimpleNamespace(data_parallel_size=1, ), + compilation_config=SimpleNamespace(static_forward_context=None, ), other_setting="value", ) self.config = config @@ -112,7 +94,8 @@ def _forward_ms_op_input_layernorm( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states, residual = self.input_layernorm( + hidden_states, residual) return hidden_states, residual def _forward_ms_op_attn( @@ -124,8 +107,7 @@ def _forward_ms_op_attn( attn_metadata: Optional[AttentionMetadata] = None, ) -> tuple[torch.Tensor, torch.Tensor]: self.dummy_vllm_config.compilation_config.static_forward_context = ( - get_forward_context().no_compile_layers - ) + get_forward_context().no_compile_layers) with set_forward_context(attn_metadata, self.dummy_vllm_config): hidden_states = self.self_attn( positions=positions, @@ -147,7 +129,8 @@ def _forward_ms_op_post_attn_layernorm( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ): - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) return hidden_states, residual def _forward_op_gating( @@ -168,10 +151,13 @@ def _forward_op_gating( num_tokens, hidden_size = hidden_states.shape if num_tokens < self.tp_size: hidden_states = nn.functional.pad( - hidden_states, (0, 0, 0, self.tp_size - num_tokens) - ) - chunk_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0) - chunked_hidden_states_sizes = [x.shape[0] for x in chunk_hidden_states] + hidden_states, (0, 0, 0, self.tp_size - num_tokens)) + chunk_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + chunked_hidden_states_sizes = [ + x.shape[0] for x in chunk_hidden_states + ] local_hidden_states = chunk_hidden_states[self.tp_rank] else: local_hidden_states = hidden_states @@ -206,9 +192,9 @@ def _forward_op_gating( num_expert_group=getattr(mlp_config, "n_group", None), custom_routing_function=None, scoring_func=getattr(mlp_config, "scoring_func", "softmax"), - e_score_correction_bias=getattr( - self.mlp.gate, "e_score_correction_bias", None - ), + e_score_correction_bias=getattr(self.mlp.gate, + "e_score_correction_bias", + None), ) topk_weights = topk_weights.to(hidden_states.dtype) @@ -228,26 +214,27 @@ def _forward_op_grouped_mlp(self, dispatched_input, tokens_per_expert): tokens_per_expert, ) - def _forward_combine_comm( - self, hidden_states, microbatch_id, num_tokens, chunked_hidden_states_sizes - ): + def _forward_combine_comm(self, hidden_states, microbatch_id, num_tokens, + chunked_hidden_states_sizes): token_dispatcher = self.mlp.experts.token_dispatchers[microbatch_id] - final_hidden_states, _ = 
token_dispatcher.token_unpermutation(hidden_states) + final_hidden_states, _ = token_dispatcher.token_unpermutation( + hidden_states) if hasattr(self.mlp, "routed_scaling_factor"): final_hidden_states = final_hidden_states * self.mlp.routed_scaling_factor if self.tp_size > 1: final_hidden_states = gather_from_sequence_parallel_region( - final_hidden_states, self.tp_group, chunked_hidden_states_sizes - ) + final_hidden_states, self.tp_group, + chunked_hidden_states_sizes) if num_tokens < self.tp_size: final_hidden_states = final_hidden_states[:num_tokens] if hasattr(self.mlp, "shared_experts"): final_hidden_states = ( - final_hidden_states + token_dispatcher.cached_shared_expert_output - ) - token_dispatcher.cached_shared_expert_output.untyped_storage().resize_(0) + final_hidden_states + + token_dispatcher.cached_shared_expert_output) + token_dispatcher.cached_shared_expert_output.untyped_storage( + ).resize_(0) token_dispatcher.cached_shared_expert_output = None final_hidden_states = final_hidden_states.view(num_tokens, -1) @@ -262,7 +249,8 @@ def _forward_ms_layer_alltoallv_finegrained( attn_metadata: List[AttentionMetadata], kv_cache: Optional[torch.Tensor] = None, ): - layer_index, ms_metadata, attn_metadata = get_multistream_layer_context() + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) assert layer_index >= 0 and ms_metadata is not None num_micro_batchs = ms_metadata.ms_config.num_micro_batches assert len(positions) == num_micro_batchs @@ -271,7 +259,9 @@ def _forward_ms_layer_alltoallv_finegrained( assert attn_metadata is not None num_tokens = [None] * num_micro_batchs hidden_dims = [None] * num_micro_batchs - topk_weights, topk_ids = [None] * num_micro_batchs, [None] * num_micro_batchs + topk_weights, topk_ids = [None] * num_micro_batchs, [ + None + ] * num_micro_batchs tokens_per_expert = [None] * num_micro_batchs dispatched_input = [None] * num_micro_batchs router_expert_output = [None] * num_micro_batchs @@ -292,22 +282,24 @@ def discard_tensor(tensor): # can be overlapped with the attn communication of microbatch 1 for i in range(num_micro_batchs): forward_context = get_forward_context() - layer_index, ms_metadata, attn_metadata = get_multistream_layer_context() - ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH) + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) + ms_metadata.try_wait_event(layer_index - 1, i, + MSEventKey.FFN_AR_FINISH) forward_context.attn_metadata = attn_metadata[i] # input layernorm - hidden_states[i], residual[i] = self._forward_ms_op_input_layernorm( - hidden_states[i], residual[i] - ) + hidden_states[i], residual[ + i] = self._forward_ms_op_input_layernorm( + hidden_states[i], residual[i]) # attention and tp allreduce hidden_states[i], residual[i] = self._forward_ms_op_attn( - positions[i], hidden_states[i], residual[i], kv_cache, attn_metadata[i] - ) + positions[i], hidden_states[i], residual[i], kv_cache, + attn_metadata[i]) # post attention layer norm - hidden_states[i], residual[i] = self._forward_ms_op_post_attn_layernorm( - hidden_states[i], residual[i] - ) + hidden_states[i], residual[ + i] = self._forward_ms_op_post_attn_layernorm( + hidden_states[i], residual[i]) num_tokens[i], hidden_dims[i] = hidden_states[i].shape # If TP is enabled, hidden_states will be chunked. 
( @@ -327,51 +319,45 @@ def discard_tensor(tensor): dispatch_context = MultiStreamStepMetadata( comm_stream=ms_metadata.communicate_stream, before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_BEFORE_COMM - ], + MSEventKey.MOE_BEFORE_COMM], after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.MOE_AFTER_COMM - ], + MSEventKey.MOE_AFTER_COMM], ) dispatch_context.before_comm_event.record() # print_with_sync(f'begin token dispatch{i}...', torch.distributed.get_rank()) with torch.npu.stream(dispatch_context.comm_stream): dispatch_context.comm_stream.wait_event( - dispatch_context.before_comm_event - ) + dispatch_context.before_comm_event) token_dispatchers[i].dispatch_alltoall() dispatched_input[i], tokens_per_expert[i] = token_dispatchers[ - i - ].permute2() + i].permute2() dispatch_context.after_comm_event.record() # print_with_sync('begin experts...', torch.distributed.get_rank()) # block 4 : Router Experts Computation # block 5 : Token Combine Communication for i in range(num_micro_batchs): - ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_AFTER_COMM) + ms_metadata.try_wait_event(layer_index, i, + MSEventKey.MOE_AFTER_COMM) discard_tensor(hidden_states[i]) router_expert_output[i] = self._forward_op_grouped_mlp( - dispatched_input[i], tokens_per_expert[i] - ) + dispatched_input[i], tokens_per_expert[i]) discard_tensor(dispatched_input[i]) # Launch Combine Comm in a New Stream. combine_context = MultiStreamStepMetadata( comm_stream=ms_metadata.communicate_stream, before_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.FFN_COM_FINISH - ], + MSEventKey.FFN_COM_FINISH], after_comm_event=ms_metadata.ms_events[layer_index][i][ - MSEventKey.FFN_AR_FINISH - ], + MSEventKey.FFN_AR_FINISH], ) combine_context.before_comm_event.record() - ms_metadata.try_wait_event(layer_index, i, MSEventKey.MOE_SE_COMM_FINISH) + ms_metadata.try_wait_event(layer_index, i, + MSEventKey.MOE_SE_COMM_FINISH) with torch.npu.stream(combine_context.comm_stream): combine_context.comm_stream.wait_event( - combine_context.before_comm_event - ) + combine_context.before_comm_event) hidden_states[i] = self._forward_combine_comm( router_expert_output[i], i, @@ -379,8 +365,9 @@ def discard_tensor(tensor): chunked_hidden_states_sizes[i], ) ms_metadata.ms_events[layer_index][i][ - MSEventKey.FFN_AR_FINISH - ] = combine_context.comm_stream.record_event() + MSEventKey. 
+ FFN_AR_FINISH] = combine_context.comm_stream.record_event( + ) return hidden_states, residual @@ -399,8 +386,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vocab_size = config.vocab_size self.config = config self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, config.hidden_size, prefix=f"{prefix}.embed_tokens" - ) + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens") self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Qwen3MoeDecoderLayerDBO( @@ -413,8 +401,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size - ) + ["hidden_states", "residual"], config.hidden_size) # dbo related members if VLLM_ASCEND_ENABLE_DBO: @@ -426,8 +413,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): causal_lm=getattr(config, "causal_lm", True), multistream_config=self.multistream_config, ) - self.ms_pre_layer = MultiStreamPreTransformerLayer(multistream_metadata) - self.ms_post_layer = MultiStreamPostTransformerLayer(multistream_metadata) + self.ms_pre_layer = MultiStreamPreTransformerLayer( + multistream_metadata) + self.ms_post_layer = MultiStreamPostTransformerLayer( + multistream_metadata) def forward( self, @@ -447,11 +436,8 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - num_normal_layers = ( - 0 - if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms() - else self.end_layer - self.start_layer - ) + num_normal_layers = (0 if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms() + else self.end_layer - self.start_layer) moe_start_layer = self.start_layer + num_normal_layers for i in range(self.start_layer, min(moe_start_layer, self.end_layer)): @@ -468,9 +454,10 @@ def forward( ) if not get_pp_group().is_last_rank: - return IntermediateTensors( - {"hidden_states": hidden_states, "residual": residual} - ) + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -479,11 +466,8 @@ def can_run_ms(self): attn_metadata = get_forward_context().attn_metadata # enable prefill overlap with_prefill = get_forward_context().with_prefill - if ( - attn_metadata is None - or not with_prefill - or not attn_metadata.enable_dbo_across_dp - ): + if (attn_metadata is None or not with_prefill + or not attn_metadata.enable_dbo_across_dp): return False return True @@ -500,9 +484,9 @@ def _forward_ms_layers( if moe_start_layer == self.end_layer: return hidden_states, residual - attn_metadata, [positions, hidden_states, residual] = self.ms_pre_layer( - [positions, hidden_states, residual], - ) + attn_metadata, [positions, hidden_states, + residual] = self.ms_pre_layer( + [positions, hidden_states, residual], ) num_micro_batch = len(attn_metadata) # the rest layers for i in range(moe_start_layer, self.end_layer): @@ -517,13 +501,14 @@ def _forward_ms_layers( ) advance_step_multistream_layer_context() - layer_index, ms_metadata, attn_metadata = get_multistream_layer_context() + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) for i in range(num_micro_batch): - ms_metadata.try_wait_event(layer_index - 1, i, MSEventKey.FFN_AR_FINISH) + ms_metadata.try_wait_event(layer_index - 1, i, + 
MSEventKey.FFN_AR_FINISH) - [hidden_states, residual] = self.ms_post_layer( - [hidden_states, residual], - ) + [hidden_states, + residual] = self.ms_post_layer([hidden_states, residual], ) return hidden_states, residual @@ -538,7 +523,8 @@ class CustomQwen3MoeForCausalLMDBO(Qwen3MoeForCausalLM): "gate_proj", "up_proj", ], - "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], } qwen3.Qwen3MoeSparseMoeBlock = AscendSparseMoeBlock @@ -548,18 +534,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = CustomQwen3DBOMoEModel( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") - ) - self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config - ) + self.model = CustomQwen3DBOMoEModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors - ) + self.model.make_empty_intermediate_tensors) def forward(self, *args, **kwargs): if "graph_enable" in kwargs: diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 6ea5d796b6..d3983af245 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -26,23 +26,18 @@ from transformers import PretrainedConfig from vllm.attention import AttentionMetadata from vllm.config import get_current_vllm_config -from vllm.distributed import ( - GroupCoordinator, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, -) -from vllm.distributed.parallel_state import get_dp_group, get_ep_group, get_tp_group +from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, + get_tp_group) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, - FusedMoEParallelConfig, - MoEConfig, - UnquantizedFusedMoEMethod, - determine_expert_map, -) + FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod, + determine_expert_map) from vllm.model_executor.layers.linear import ReplicatedLinear -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.base_config import \ + QuantizationConfig import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config @@ -50,16 +45,10 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( - MoEAlltoAllSeqOverLapDispatcher, - MoEDispatcherConfig, -) -from vllm_ascend.utils import ( - AscendSocVersion, - dispose_tensor, - get_ascend_soc_version, - npu_stream_switch, - npu_wait_tensor, -) + MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) +from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, + get_ascend_soc_version, npu_stream_switch, 
+ npu_wait_tensor) VLLM_ASCEND_MOE_ALL2ALL_BUFFER: bool = envs_ascend.VLLM_ASCEND_MOE_ALL2ALL_BUFFER @@ -78,31 +67,33 @@ def process_topk_ids( if original_total_elements == 0: output_len = ep_size * max_row_per_ep_rank - topk_ids_pad = torch.full( - (output_len,), expert_num, dtype=original_dtype, device=device - ) - unpad_indices = torch.full( - (original_total_elements,), -1, dtype=torch.long, device=device - ) + topk_ids_pad = torch.full((output_len, ), + expert_num, + dtype=original_dtype, + device=device) + unpad_indices = torch.full((original_total_elements, ), + -1, + dtype=torch.long, + device=device) return topk_ids_pad, unpad_indices experts_per_ep_rank_val = expert_num // ep_size if experts_per_ep_rank_val == 0: raise ValueError( "expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. " - "Ensure expert_num >= ep_size." - ) + "Ensure expert_num >= ep_size.") - assigned_ep_rank = (topk_ids.float() / experts_per_ep_rank_val).to(original_dtype) + assigned_ep_rank = (topk_ids.float() / + experts_per_ep_rank_val).to(original_dtype) indices_arange = torch.arange(topk_ids.shape[0], device=device) - is_new_segment = torch.cat( - ( - torch.tensor([True], device=device), - assigned_ep_rank[1:] != assigned_ep_rank[:-1], - ) - ) - temp_start_markers = torch.full_like(indices_arange, -1, dtype=indices_arange.dtype) + is_new_segment = torch.cat(( + torch.tensor([True], device=device), + assigned_ep_rank[1:] != assigned_ep_rank[:-1], + )) + temp_start_markers = torch.full_like(indices_arange, + -1, + dtype=indices_arange.dtype) temp_start_markers[is_new_segment] = indices_arange[is_new_segment] start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0] token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token @@ -115,20 +106,22 @@ def process_topk_ids( torch.tensor(-1, device=device, dtype=torch.long), ) output_len = ep_size * max_row_per_ep_rank - topk_ids_pad = torch.full( - (output_len,), expert_num, dtype=original_dtype, device=device - ) + topk_ids_pad = torch.full((output_len, ), + expert_num, + dtype=original_dtype, + device=device) if topk_ids.shape[0] > 0: - all_destination_indices = ( - assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx - ) - temp_pad_buffer = torch.full( - (output_len + 1,), expert_num, dtype=original_dtype, device=device - ) - output_len_tensor = torch.tensor(output_len, dtype=torch.long, device=device) - scatter_indices = torch.where( - is_kept_mask, all_destination_indices, output_len_tensor - ) + all_destination_indices = (assigned_ep_rank * max_row_per_ep_rank + + token_intra_ep_rank_idx) + temp_pad_buffer = torch.full((output_len + 1, ), + expert_num, + dtype=original_dtype, + device=device) + output_len_tensor = torch.tensor(output_len, + dtype=torch.long, + device=device) + scatter_indices = torch.where(is_kept_mask, all_destination_indices, + output_len_tensor) temp_pad_buffer.scatter_(0, scatter_indices, topk_ids) topk_ids_pad = temp_pad_buffer[:output_len] return topk_ids_pad, unpad_indices @@ -156,12 +149,12 @@ def fused_experts_with_mc2( # NOTE: `global_bs` should be equal to `max_num_tokens_across_dp` * `ep_world_size`, # and `max_num_tokens_across_dp` has been split into `tp_world_size` parts before. 
global_bs = ( - math.ceil(get_forward_context().max_tokens_across_dp / tp_world_size) - * ep_world_size - ) + math.ceil(get_forward_context().max_tokens_across_dp / tp_world_size) * + ep_world_size) # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine - need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 or is_torchair + need_extra_args = get_ascend_soc_version( + ) == AscendSocVersion.A3 or is_torchair # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 @@ -184,25 +177,22 @@ def fused_experts_with_mc2( "ep_rank_id": ep_rank_id, } if need_extra_args: - stage1_kwargs.update( - { - "group_tp": moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - } - ) + stage1_kwargs.update({ + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) if a3_need_extra_args: - stage1_kwargs.update( - { - "x_active_mask": mc2_mask, - } - ) + stage1_kwargs.update({ + "x_active_mask": mc2_mask, + }) kwargs_mc2.update(stage1_kwargs) output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[0:5] + expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[ + 0:5] if shared_experts is not None: with npu_stream_switch("moe_secondary", 0): @@ -259,20 +249,16 @@ def fused_experts_with_mc2( "ep_rank_id": ep_rank_id, } if need_extra_args: - stage3_kwargs.update( - { - "tp_send_counts": tp_recv_counts, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - } - ) + stage3_kwargs.update({ + "tp_send_counts": tp_recv_counts, + "group_tp": moe_all_to_all_group_name, + "tp_world_size": 1, + "tp_rank_id": 0, + }) if a3_need_extra_args: - stage3_kwargs.update( - { - "x_active_mask": mc2_mask, - } - ) + stage3_kwargs.update({ + "x_active_mask": mc2_mask, + }) kwargs_mc2.update(stage3_kwargs) hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) @@ -365,66 +351,62 @@ def fused_experts_with_all2all( global_num_experts = len(expert_map) local_num_experts = global_num_experts // ep_group.world_size row_idx_len = num_tokens * top_k - row_idx = ( - torch.arange(0, row_idx_len, dtype=torch.int32, device=device) - .view(top_k, -1) - .permute(1, 0) - .contiguous() - ) + row_idx = (torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=device).view(top_k, -1).permute( + 1, 0).contiguous()) hidden_states, expanded_row_idx, expanded_expert_idx = ( torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens, - ) - ) + )) - global_expert_tokens = torch.bincount( - expanded_expert_idx, minlength=global_num_experts - ) - scatter_sizes = global_expert_tokens.view(ep_group.world_size, -1).sum(-1) + global_expert_tokens = torch.bincount(expanded_expert_idx, + minlength=global_num_experts) + scatter_sizes = global_expert_tokens.view(ep_group.world_size, + -1).sum(-1) gather_sizes = torch.empty_like(scatter_sizes) - dist.all_to_all_single(gather_sizes, scatter_sizes, group=ep_group.device_group) + dist.all_to_all_single(gather_sizes, + scatter_sizes, + group=ep_group.device_group) scatter_size_list = scatter_sizes.cpu().tolist() gather_size_list = gather_sizes.cpu().tolist() expanded_expert_idx = expanded_expert_idx % local_num_experts - hidden_states = ep_group.all_to_all( - 
hidden_states, 0, 0, scatter_size_list, gather_size_list - ) - local_expert_idx = ep_group.all_to_all( - expanded_expert_idx, 0, 0, scatter_size_list, gather_size_list - ) + hidden_states = ep_group.all_to_all(hidden_states, 0, 0, + scatter_size_list, + gather_size_list) + local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0, + scatter_size_list, + gather_size_list) sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx) expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - sorted_local_expert_idx, local_num_experts - ).to(torch.int64) + sorted_local_expert_idx, local_num_experts).to(torch.int64) hidden_states = hidden_states[sorted_idx] else: row_idx_len = num_tokens * top_k - row_idx = ( - torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device) - .view(top_k, -1) - .permute(1, 0) - .contiguous() - ) + row_idx = (torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=topk_weights.device).view( + top_k, -1).permute(1, 0).contiguous()) hidden_states, expanded_row_idx, expanded_expert_idx = ( torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens, - ) - ) + )) expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, num_experts - ) + expanded_expert_idx, num_experts) expert_tokens = expert_tokens.to(torch.int64) w1 = w1.transpose(1, 2) @@ -456,9 +438,9 @@ def fused_experts_with_all2all( if expert_map is not None: resorted_idx = torch.argsort(sorted_idx) hidden_states = hidden_states[resorted_idx] - hidden_states = ep_group.all_to_all( - hidden_states, 0, 0, gather_size_list, scatter_size_list - ) + hidden_states = ep_group.all_to_all(hidden_states, 0, 0, + gather_size_list, + scatter_size_list) final_hidden_states = torch_npu.npu_moe_finalize_routing( hidden_states, @@ -510,29 +492,18 @@ def fused_experts_with_all2all_buffer( global_num_experts = len(expert_map) local_num_experts = global_num_experts // ep_group.world_size row_idx_len = num_tokens * top_k - row_idx = ( - torch.arange(0, row_idx_len, dtype=torch.int32, device=device) - .view(top_k, -1) - .permute(1, 0) - .contiguous() - ) + row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32, + device=device).view(top_k, + -1).permute(1, 0).contiguous()) hidden_states, expanded_row_idx, expanded_expert_idx = ( - torch_npu.npu_moe_init_routing( - hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens - ) - ) + torch_npu.npu_moe_init_routing(hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens)) max_row_per_ep_rank = ( - ( - -(-global_batch_size // ep_group.world_size) - * max_model_len - * get_dp_group().world_size - // ep_group.world_size - + 1 - ) - * top_k - * 2 - ) + (-(-global_batch_size // ep_group.world_size) * max_model_len * + get_dp_group().world_size // ep_group.world_size + 1) * top_k * 2) expert_idx_buffer_scatter, unpad_indices = process_topk_ids( expanded_expert_idx, global_num_experts, @@ -546,16 +517,14 @@ def fused_experts_with_all2all_buffer( dtype=expert_idx_buffer_scatter.dtype, device=expert_idx_buffer_scatter.device, ) - non_pad_len = torch.sum( - (expert_idx_buffer_scatter != global_num_experts).to(torch.int32) - ) + non_pad_len = torch.sum((expert_idx_buffer_scatter + != global_num_experts).to(torch.int32)) hidden_states_pad_idx[expert_idx_buffer_scatter != global_num_experts] = ( torch.arange( non_pad_len, dtype=expert_idx_buffer_scatter.dtype, device=hidden_states.device, - ) - ) + )) hidden_states_buffer_scatter = 
hidden_states[hidden_states_pad_idx] expert_idx_buffer_gather = torch.empty_like( @@ -568,9 +537,9 @@ def fused_experts_with_all2all_buffer( dtype=hidden_states_buffer_scatter.dtype, device=hidden_states_buffer_scatter.device, ) - dist.all_to_all_single( - expert_idx_buffer_gather, expert_idx_buffer_scatter, group=ep_group.device_group - ) + dist.all_to_all_single(expert_idx_buffer_gather, + expert_idx_buffer_scatter, + group=ep_group.device_group) dist.all_to_all_single( hidden_states_buffer_gather, hidden_states_buffer_scatter, @@ -578,22 +547,22 @@ def fused_experts_with_all2all_buffer( ) mask = expert_idx_buffer_gather != global_num_experts local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( - global_num_experts // ep_group.world_size - ) + global_num_experts // ep_group.world_size) hidden_states = hidden_states_buffer_gather[mask] idx_type = local_expert_idx.dtype sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float()) sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type) expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - sorted_local_expert_idx, local_num_experts - ).to(torch.int64) + sorted_local_expert_idx, local_num_experts).to(torch.int64) hidden_states = hidden_states[sorted_idx] group_list_type = 0 - hidden_states = apply_mlp( - hidden_states, w1, w2, expert_tokens, group_list_type=group_list_type - ) + hidden_states = apply_mlp(hidden_states, + w1, + w2, + expert_tokens, + group_list_type=group_list_type) resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype) hidden_states = hidden_states[resorted_idx] @@ -608,12 +577,11 @@ def fused_experts_with_all2all_buffer( dtype=hidden_states_scatter.dtype, device=hidden_states_scatter.device, ) - dist.all_to_all_single( - hidden_states_gatter, hidden_states_scatter, group=ep_group.device_group - ) - hidden_states_gatter = hidden_states_gatter[ - expert_idx_buffer_scatter != global_num_experts - ] + dist.all_to_all_single(hidden_states_gatter, + hidden_states_scatter, + group=ep_group.device_group) + hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter != + global_num_experts] if hidden_states_gatter.shape[0] != row_idx_len: hidden_states = torch.zeros( (row_idx_len, hidden_states.shape[1]), @@ -648,9 +616,9 @@ def fused_experts_with_all2allv( w2: torch.Tensor, ): # Enable moe alltoallv, it's a balanced policy for precision and efficiency. 
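# --- Illustrative sketch (not part of the patch): the alltoallv path above is a
# permute -> grouped expert MLP -> unpermute pipeline. This toy single-rank version
# uses plain torch (no all-to-all, no NPU ops) and a trivial `expert_fn` stand-in;
# it also averages the top-k copies instead of applying routing probabilities, so it
# only demonstrates the permutation bookkeeping, not the full combine.
import torch


def toy_permute_mlp_unpermute(hidden_states, topk_ids, num_experts, expert_fn):
    top_k = topk_ids.shape[1]
    flat_expert = topk_ids.reshape(-1)                    # (tokens * top_k,)
    sort_idx = torch.argsort(flat_expert)                 # group token copies by expert
    permuted = hidden_states.repeat_interleave(top_k, dim=0)[sort_idx]
    tokens_per_expert = torch.bincount(flat_expert, minlength=num_experts)
    expert_out = expert_fn(permuted, tokens_per_expert)   # grouped MLP stand-in
    unpermuted = torch.empty_like(expert_out)
    unpermuted[sort_idx] = expert_out                     # undo the permutation
    return unpermuted.view(-1, top_k, hidden_states.shape[-1]).mean(dim=1)


out = toy_permute_mlp_unpermute(torch.randn(8, 16),
                                torch.randint(0, 4, (8, 2)),
                                num_experts=4,
                                expert_fn=lambda x, counts: 2.0 * x)
print(out.shape)  # torch.Size([8, 16])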
- (share_experts_output, dispatched_input, tokens_per_expert) = ( - token_dispatcher.token_permutation(hidden_states, probs, routing_map) - ) + (share_experts_output, dispatched_input, + tokens_per_expert) = (token_dispatcher.token_permutation( + hidden_states, probs, routing_map)) expert_output = apply_mlp(dispatched_input, w1, w2, tokens_per_expert) output, mlp_bias = token_dispatcher.token_unpermutation(expert_output) @@ -705,9 +673,8 @@ def fused_experts( # ], "Only float32, float16, and bfloat16 are supported" if apply_router_weight_on_input: - assert ( - topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" + assert (topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" _, topk = topk_weights.shape assert ( topk == 1 @@ -716,12 +683,10 @@ def fused_experts( if expert_map is not None: # Generate token indices and flatten - token_indices = ( - torch.arange(num_tokens, device=device, dtype=torch.int64) - .unsqueeze(1) - .expand(-1, top_k) - .reshape(-1) - ) + token_indices = (torch.arange(num_tokens, + device=device, + dtype=torch.int64).unsqueeze(1).expand( + -1, top_k).reshape(-1)) # Flatten token-to-expert mappings and map to local experts weights_flat = topk_weights.view(-1) @@ -731,11 +696,11 @@ def fused_experts( # Filter valid token-expert pairs mask = local_experts_flat != -1 filtered_weights = torch.where( - mask, weights_flat, torch.zeros_like(weights_flat) - ).to(dtype) + mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype) filtered_experts = torch.where( - mask, local_experts_flat, torch.full_like(local_experts_flat, num_experts) - ).to(topk_ids.dtype) + mask, local_experts_flat, + torch.full_like(local_experts_flat, + num_experts)).to(topk_ids.dtype) # Sort by local expert IDs sort_indices = torch.argsort(filtered_experts.view(torch.float32)) @@ -745,7 +710,9 @@ def fused_experts( # Compute token counts with minlength of num_experts # This is equivalent to but faster than: # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] - token_counts = torch.zeros(num_experts + 1, device=device, dtype=torch.int64) + token_counts = torch.zeros(num_experts + 1, + device=device, + dtype=torch.int64) ones = torch.ones_like(filtered_experts, dtype=torch.int64) token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) token_counts = token_counts[:num_experts] @@ -755,24 +722,21 @@ def fused_experts( sorted_hidden_states = hidden_states[sorted_token_indices] else: row_idx_len = num_tokens * top_k - row_idx = ( - torch.arange(0, row_idx_len, dtype=torch.int32, device=device) - .view(top_k, -1) - .permute(1, 0) - .contiguous() - ) + row_idx = (torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=device).view(top_k, -1).permute( + 1, 0).contiguous()) sorted_hidden_states, expanded_row_idx, expanded_expert_idx = ( torch_npu.npu_moe_init_routing( hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens, - ) - ) + )) expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, num_experts - ) + expanded_expert_idx, num_experts) expert_tokens = expert_tokens.to(torch.int64) w1 = w1.transpose(1, 2) @@ -804,28 +768,25 @@ def fused_experts( if expert_map is not None: weighted_down_out = down_out_list * sorted_weights.unsqueeze(1) - final_hidden_states = torch.zeros( - *original_shape, device=hidden_states.device, dtype=dtype - ) + final_hidden_states = torch.zeros(*original_shape, + device=hidden_states.device, + dtype=dtype) # TODO: 
npu_grouped_matmul output random values at [num_valid_tokens:, ...] # This created multiple NaN and index_add_ will mix them up which harms accuracy # remove this mask and filter after it being fixed num_valid_tokens = mask.sum() - valid_token_mask = ( - torch.arange(0, sorted_token_indices.shape[0], device=device).unsqueeze(1) - < num_valid_tokens - ) + valid_token_mask = (torch.arange(0, + sorted_token_indices.shape[0], + device=device).unsqueeze(1) + < num_valid_tokens) valid_output = torch.where( - valid_token_mask, weighted_down_out, torch.zeros_like(weighted_down_out) - ).to(dtype) + valid_token_mask, weighted_down_out, + torch.zeros_like(weighted_down_out)).to(dtype) final_hidden_states.index_add_(0, sorted_token_indices, valid_output) else: - scales = ( - torch.ones_like(topk_weights) - if apply_router_weight_on_input - else topk_weights - ) + scales = (torch.ones_like(topk_weights) + if apply_router_weight_on_input else topk_weights) # TODO: Reorder device memory 2 times here, replace the current # implementation here when suitable operators become available. final_hidden_states = torch_npu.npu_moe_finalize_routing( @@ -850,19 +811,17 @@ def native_grouped_topk( num_expert_group = 0 if num_expert_group is None else num_expert_group num_token = topk_weights.shape[0] - grouped_weights = ( - topk_weights.view(num_token, num_expert_group, -1).max(dim=-1).values - ) - topk_group_indices = torch.topk( - grouped_weights.to(torch.float32), k=topk_group, dim=-1, sorted=False - )[1] + grouped_weights = (topk_weights.view(num_token, num_expert_group, + -1).max(dim=-1).values) + topk_group_indices = torch.topk(grouped_weights.to(torch.float32), + k=topk_group, + dim=-1, + sorted=False)[1] topk_group_mask = torch.zeros_like(grouped_weights) topk_group_mask.scatter_(1, topk_group_indices, 1) - topk_weight_mask = ( - topk_group_mask.unsqueeze(-1) - .expand(num_token, num_expert_group, topk_weights.shape[-1] // num_expert_group) - .reshape(num_token, -1) - ) + topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)) topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0) return topk_weights @@ -923,18 +882,21 @@ def select_experts( # TODO: Change to npu_group_topk when the latest CANN and NNAL is available # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group) - topk_weights = native_grouped_topk(topk_weights, num_expert_group, topk_group) + topk_weights = native_grouped_topk(topk_weights, num_expert_group, + topk_group) # TODO bfloat16 is not supported in torch.topk with ge graph. 
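# --- Illustrative sketch (not part of the patch): the `e_score_correction_bias`
# branch below selects experts on bias-corrected scores but keeps the original,
# unbiased scores as routing weights. Shapes and values here are made up; the
# float() cast mirrors the bfloat16/topk workaround noted in the TODO above.
import torch

scores = torch.rand(4, 8)                  # (tokens, experts), unbiased scores
bias = 0.1 * torch.randn(8)                # per-expert correction bias
topk_ids = torch.topk((scores + bias).float(), k=2, dim=-1, sorted=False)[1]
topk_weights = scores.gather(1, topk_ids)  # weights come from the unbiased scores
print(topk_ids.shape, topk_weights.shape)  # torch.Size([4, 2]) torch.Size([4, 2])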
if e_score_correction_bias is not None: - topk_ids = torch.topk( - topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False - )[1] + topk_ids = torch.topk(topk_weights.to(torch.float32), + k=top_k, + dim=-1, + sorted=False)[1] # Use original unbiased scores for the routing weights topk_weights = original_weights.gather(1, topk_ids) else: - topk_weights, topk_ids = torch.topk( - topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False - ) + topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), + k=top_k, + dim=-1, + sorted=False) elif custom_routing_function is None: topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1) topk_weights = topk_weights.to(hidden_states.dtype) @@ -976,18 +938,20 @@ def __init__(self, moe: MoEConfig = None): # TODO: Try local_rank = ep_group.rank_in_group local_rank = torch.distributed.get_rank(group=device_group) backend = device_group._get_backend(torch.device("npu")) - self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) + self.moe_all_to_all_group_name = backend.get_hccl_comm_name( + local_rank) except AttributeError: self.moe_all_to_all_group_name = None def process_weights_after_loading(self, layer): - super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer) - layer.w13_weight = torch.nn.Parameter( - self._maybe_pad_weight(layer.w13_weight.data), requires_grad=False - ) - layer.w2_weight = torch.nn.Parameter( - self._maybe_pad_weight(layer.w2_weight.data), requires_grad=False - ) + super(UnquantizedFusedMoEMethod, + self).process_weights_after_loading(layer) + layer.w13_weight = torch.nn.Parameter(self._maybe_pad_weight( + layer.w13_weight.data), + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight( + layer.w2_weight.data), + requires_grad=False) def apply( self, @@ -1151,12 +1115,10 @@ def __init__( vllm_config = get_current_vllm_config() self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( - tp_size_=( - tp_size - if tp_size is not None - else get_tensor_model_parallel_world_size() - ), - dp_size_=(dp_size if dp_size is not None else get_dp_group().world_size), + tp_size_=(tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()), + dp_size_=(dp_size + if dp_size is not None else get_dp_group().world_size), vllm_parallel_config=vllm_config.parallel_config, ) @@ -1184,36 +1146,28 @@ def __init__( expert_map_path = ascend_config.expert_map_path if expert_map_path and os.path.exists(expert_map_path): # moe expert load balance - expert_load_balancer = ExpertLoadBalancer( - expert_map_path, self.global_num_experts - ) + expert_load_balancer = ExpertLoadBalancer(expert_map_path, + self.global_num_experts) self.local_num_experts, self.expert_map = ( expert_load_balancer.get_rank_placement_map( - self.moe_instance_id, self.ep_rank - ) - ) + self.moe_instance_id, self.ep_rank)) self.log2phy = expert_load_balancer.get_rank_log2phy_map( - self.moe_instance_id, self.ep_rank - ) + self.moe_instance_id, self.ep_rank) self.global_redundant_expert_num = ( - expert_load_balancer.get_global_redundant_expert_num() - ) + expert_load_balancer.get_global_redundant_expert_num()) else: # Create a tensor of size num_experts filled with -1 self.local_num_experts, self.expert_map = determine_expert_map( - self.ep_size, self.ep_rank, self.global_num_experts - ) + self.ep_size, self.ep_rank, self.global_num_experts) self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_multistream_moe = ( 
ascend_config.torchair_graph_config.enable_multistream_moe - and self.torchair_graph_enabled - ) + and self.torchair_graph_enabled) if self.scoring_func != "softmax" and not self.use_grouped_topk: - raise ValueError( - "Only softmax scoring function is supported for " "non-grouped topk." - ) + raise ValueError("Only softmax scoring function is supported for " + "non-grouped topk.") moe = MoEConfig( num_experts=self.global_num_experts, @@ -1232,23 +1186,21 @@ def __init__( assert self.quant_method is not None - local_num_experts = ( - torch.sum(self.expert_map != -1) - if self.expert_map is not None - else num_experts - ) + local_num_experts = (torch.sum(self.expert_map != -1) + if self.expert_map is not None else num_experts) moe_quant_params = { "num_experts": local_num_experts, "hidden_size": hidden_size, - "intermediate_size_per_partition": self.intermediate_size_per_partition, + "intermediate_size_per_partition": + self.intermediate_size_per_partition, "params_dtype": params_dtype, "weight_loader": self.weight_loader, } # need full intermediate size pre-sharding for WNA16 act order if self.quant_method.__class__.__name__ in ( - "GPTQMarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod", + "GPTQMarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod", ): moe_quant_params["intermediate_size_full"] = intermediate_size @@ -1258,28 +1210,23 @@ def __init__( self.quant_method.create_weights(layer=self, **moe_quant_params) self.token_dispatcher = None if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance( - self.quant_method, AscendUnquantizedFusedMoEMethod - ): + self.quant_method, AscendUnquantizedFusedMoEMethod): self.reduce_results = False moe_dispatcher_config = ( - MoEDispatcherConfig() - .set_num_moe_experts(self.global_num_experts) - .set_num_local_experts(self.local_num_experts) - .set_moe_router_topk(top_k) - .set_group_topk(topk_group) - .set_num_groups(num_expert_group) - .set_expert_bias(e_score_correction_bias) - .set_scaling_factor(1.0) - .build() - ) + MoEDispatcherConfig().set_num_moe_experts( + self.global_num_experts).set_num_local_experts( + self.local_num_experts).set_moe_router_topk( + top_k).set_group_topk(topk_group). + set_num_groups(num_expert_group).set_expert_bias( + e_score_correction_bias).set_scaling_factor(1.0).build()) self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher( - moe_dispatcher_config - ) + moe_dispatcher_config) if envs_ascend.VLLM_ASCEND_ENABLE_DBO: token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher( - moe_dispatcher_config - ) - self.token_dispatchers = [self.token_dispatcher, token_dispatcher1] + moe_dispatcher_config) + self.token_dispatchers = [ + self.token_dispatcher, token_dispatcher1 + ] def forward( self, @@ -1304,23 +1251,18 @@ def forward( fused_moe_state = get_forward_context().fused_moe_state # For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel. 
quantized_x_for_share, dynamic_scale_for_share = None, None - from vllm_ascend.quantization.w8a8_dynamic import ( - AscendW8A8DynamicFusedMoEMethod, - ) + from vllm_ascend.quantization.w8a8_dynamic import \ + AscendW8A8DynamicFusedMoEMethod if self.enable_multistream_moe: assert gate is not None router_logits, _ = gate(hidden_states) - if ( - isinstance( - self.quant_method.quant_method, AscendW8A8DynamicFusedMoEMethod - ) - and fused_moe_state == FusedMoEState.MC2 - ): + if (isinstance(self.quant_method.quant_method, + AscendW8A8DynamicFusedMoEMethod) + and fused_moe_state == FusedMoEState.MC2): with npu_stream_switch("moe_secondary", 0): quantized_x_for_share, dynamic_scale_for_share = ( - torch_npu.npu_dynamic_quant(hidden_states) - ) + torch_npu.npu_dynamic_quant(hidden_states)) if shared_experts: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: @@ -1339,11 +1281,15 @@ def forward( (0, 0, 0, forward_context.padded_num_tokens - num_tokens), ) if tp_size > 1: - chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0) - chunk_router_logits = torch.tensor_split(router_logits, tp_size, dim=0) - chunk_mc2_mask = torch.tensor_split( - forward_context.mc2_mask, tp_size, dim=0 - ) + chunk_hidden_states = torch.tensor_split(hidden_states, + tp_size, + dim=0) + chunk_router_logits = torch.tensor_split(router_logits, + tp_size, + dim=0) + chunk_mc2_mask = torch.tensor_split(forward_context.mc2_mask, + tp_size, + dim=0) tp_rank = get_tensor_model_parallel_rank() hidden_states = chunk_hidden_states[tp_rank] router_logits = chunk_router_logits[tp_rank] @@ -1352,14 +1298,15 @@ def forward( if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: # NOTE: When in torchair graph, it has been padded in model_runner_v1 if not self.torchair_graph_enabled or is_prefill: - max_num_tokens_across_dp = get_forward_context().max_tokens_across_dp + max_num_tokens_across_dp = get_forward_context( + ).max_tokens_across_dp if num_tokens < max_num_tokens_across_dp: hidden_states = nn.functional.pad( - hidden_states, (0, 0, 0, max_num_tokens_across_dp - num_tokens) - ) + hidden_states, + (0, 0, 0, max_num_tokens_across_dp - num_tokens)) router_logits = nn.functional.pad( - router_logits, (0, 0, 0, max_num_tokens_across_dp - num_tokens) - ) + router_logits, + (0, 0, 0, max_num_tokens_across_dp - num_tokens)) hidden_states = get_dp_group().all_gather(hidden_states, 0) router_logits = get_dp_group().all_gather(router_logits, 0) @@ -1382,13 +1329,9 @@ def forward( enable_force_load_balance=enable_force_load_balance, log2phy=self.log2phy, global_redundant_expert_num=self.global_redundant_expert_num, - shared_experts=( - shared_experts - if self.torchair_graph_enabled - and self.enable_multistream_moe - and not is_prefill - else None - ), + shared_experts=(shared_experts if self.torchair_graph_enabled + and self.enable_multistream_moe and not is_prefill + else None), quantized_x_for_share=quantized_x_for_share, dynamic_scale_for_share=dynamic_scale_for_share, mc2_mask=mc2_mask, @@ -1400,22 +1343,26 @@ def forward( e_hidden_states, shared_hidden_states = e_hidden_states if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: - dist.all_gather(list(chunk_hidden_states), e_hidden_states, self.tp_group) + dist.all_gather(list(chunk_hidden_states), e_hidden_states, + self.tp_group) final_hidden_states = torch.cat(chunk_hidden_states, dim=0) if num_tokens < forward_context.padded_num_tokens: final_hidden_states = final_hidden_states[:num_tokens] 
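# --- Illustrative sketch (not part of the patch): the pad -> tensor_split ->
# all_gather -> unpad round trip used around the MC2 path, run single-process with
# torch.cat standing in for dist.all_gather. All sizes are made up.
import torch
import torch.nn.functional as F

tp_size, num_tokens, hidden = 4, 6, 8
padded = tp_size * ((num_tokens + tp_size - 1) // tp_size)           # 8
x = F.pad(torch.randn(num_tokens, hidden), (0, 0, 0, padded - num_tokens))
chunks = torch.tensor_split(x, tp_size, dim=0)                       # one chunk per TP rank
gathered = torch.cat(chunks, dim=0)                                  # stand-in for all_gather
assert torch.equal(gathered[:num_tokens], x[:num_tokens])            # unpad recovers the input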
dispose_tensor(e_hidden_states) elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: final_hidden_states = dist._functional_collectives.reduce_scatter_tensor( - e_hidden_states, "sum", scatter_dim=0, group=get_dp_group().device_group - ) + e_hidden_states, + "sum", + scatter_dim=0, + group=get_dp_group().device_group) final_hidden_states = final_hidden_states[:num_tokens] dispose_tensor(e_hidden_states) else: final_hidden_states = e_hidden_states if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather: - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) if shared_experts: return final_hidden_states, shared_hidden_states @@ -1466,14 +1413,12 @@ def __init__( if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {config.num_experts}." - ) + f"the number of experts {config.num_experts}.") ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_multistream_moe = ( - ascend_config.torchair_graph_config.enable_multistream_moe - ) + ascend_config.torchair_graph_config.enable_multistream_moe) self.gate = ReplicatedLinear( config.hidden_size, From e87df1151a891b6eaf72da0794a09bd9027d4e42 Mon Sep 17 00:00:00 2001 From: weijinqian Date: Fri, 11 Jul 2025 22:59:01 +0800 Subject: [PATCH 50/60] handle clean code Signed-off-by: weijinqian_v1 --- vllm_ascend/ascend_forward_context.py | 57 +++++++------------ vllm_ascend/models/__init__.py | 2 +- vllm_ascend/models/deepseek_dbo.py | 12 ++-- vllm_ascend/models/qwen3_moe.py | 3 +- vllm_ascend/multistream/ms_split.py | 4 +- .../ops/moe_dispatcher/token_dispatcher.py | 6 +- 6 files changed, 37 insertions(+), 47 deletions(-) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index c2d81037ff..e4a9b5adce 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -28,11 +28,8 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool): return FusedMoEState.AllGather elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage. - return ( - FusedMoEState.All2AllSeq - if (ep_size < 16 or with_prefill) - else FusedMoEState.MC2 - ) + return (FusedMoEState.All2AllSeq if + (ep_size < 16 or with_prefill) else FusedMoEState.MC2) elif ep_size >= 16 and with_prefill and enable_chunk_mc2: return FusedMoEState.MC2_PREFILL # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. @@ -58,19 +55,16 @@ def set_ascend_forward_context( We add some additional param into forward_context. 
""" with set_forward_context( - attn_metadata, - vllm_config, - virtual_engine=virtual_engine, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp, + attn_metadata, + vllm_config, + virtual_engine=virtual_engine, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, ): forward_context = get_forward_context() forward_context.with_prefill = with_prefill - ep_size = ( - torch.distributed.get_world_size() - if vllm_config.parallel_config.enable_expert_parallel - else 1 - ) + ep_size = (torch.distributed.get_world_size() if + vllm_config.parallel_config.enable_expert_parallel else 1) fused_moe_state = get_fused_moe_state(ep_size, with_prefill) @@ -88,9 +82,8 @@ def set_ascend_forward_context( num_tokens = attn_metadata.num_actual_tokens else: # for v0 engine - num_tokens = ( - attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens - ) + num_tokens = (attn_metadata.num_prefill_tokens + + attn_metadata.num_decode_tokens) if num_actual_tokens is None: num_actual_tokens = num_tokens @@ -98,8 +91,7 @@ def set_ascend_forward_context( dp_world_size = get_dp_group().world_size if dp_world_size > 1 and forward_context.dp_metadata is not None: max_tokens_across_dp = ( - forward_context.dp_metadata.max_tokens_across_dp_cpu.item() - ) + forward_context.dp_metadata.max_tokens_across_dp_cpu.item()) else: max_tokens_across_dp = num_tokens @@ -110,31 +102,26 @@ def set_ascend_forward_context( world_size = torch.distributed.get_world_size() # NOTE: token num which need to pad to when mc2 forward_context.padded_num_tokens = ( - math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size - ) + math.ceil(max_tokens_across_dp / tp_world_size) * + tp_world_size) # NOTE: mc2 op's param `global_bs`, add `world_size` to make `global_bs` absolutely larger than actual global_bs. 
forward_context.global_bs = ( - math.ceil(max_tokens_across_dp / tp_world_size) * world_size - ) + math.ceil(max_tokens_across_dp / tp_world_size) * world_size) if fused_moe_state == FusedMoEState.MC2_PREFILL: chunk_size = envs.VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE forward_context.max_num_chunks = math.ceil( - math.ceil(max_tokens_across_dp / tp_world_size) / chunk_size - ) + math.ceil(max_tokens_across_dp / tp_world_size) / + chunk_size) - forward_context.global_bs = ( - math.ceil( - math.ceil(max_tokens_across_dp / tp_world_size) - / forward_context.max_num_chunks - ) - * world_size - ) + forward_context.global_bs = (math.ceil( + math.ceil(max_tokens_across_dp / tp_world_size) / + forward_context.max_num_chunks) * world_size) min_num_tokens = forward_context.max_num_chunks * tp_world_size forward_context.padded_num_tokens = ( - math.ceil(max_tokens_across_dp / min_num_tokens) * min_num_tokens - ) + math.ceil(max_tokens_across_dp / min_num_tokens) * + min_num_tokens) mc2_mask = torch.zeros( forward_context.padded_num_tokens, diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index c525849cb8..b5c9c9597b 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -58,6 +58,6 @@ def register_model(): ModelRegistry.register_model( "Qwen3MoeForCausalLM", "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM") - + ModelRegistry.register_model( "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM") diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 6562bb46bd..20dafdf7ac 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -147,7 +147,8 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=not envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ, # shared experts tp comm is separated in alltoallv for better overlap. + reduce_results=not envs_ascend. + VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ, # shared experts tp comm is separated in alltoallv for better overlap. prefix=f"{prefix}.shared_experts", ) CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok @@ -232,7 +233,9 @@ def _forward_op_gating( chunk_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0) - chunked_hidden_states_sizes = [x.shape[0] for x in chunk_hidden_states] + chunked_hidden_states_sizes = [ + x.shape[0] for x in chunk_hidden_states + ] local_hidden_states = chunk_hidden_states[self.tp_rank] else: local_hidden_states = hidden_states @@ -245,7 +248,7 @@ def _forward_op_gating( if self.config.n_routed_experts == 256: topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( router_logits, - k=self.config.num_experts_per_tok, + k=self.config.num_experts_per_tok, bias=self.gate.e_score_correction_bias, k_group=self.config.topk_group, # fix: 4 group_count=self.config.n_group, # fix 8 @@ -273,7 +276,8 @@ def _forward_op_gating( # to avoid accumulating too much tokens on a single rank. # currently it is only activated when doing profile runs. 
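# --- Illustrative sketch (not part of the patch): the forced load balancing below
# simply replaces the routed expert ids with uniformly random ones of the same shape
# and dtype, so profile runs spread tokens roughly evenly across experts. The tensor
# values and the expert count of 8 are made up.
import torch

topk_ids = torch.tensor([[0, 0], [0, 1], [0, 2]])           # heavily skewed routing
balanced_ids = torch.randint_like(topk_ids, 0, 8)            # uniform over 8 experts
print(balanced_ids.shape)                                     # torch.Size([3, 2])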
if enable_force_load_balance: - topk_ids = torch.randint_like(topk_ids, 0, self.config.n_routed_experts) + topk_ids = torch.randint_like(topk_ids, 0, + self.config.n_routed_experts) return topk_weights, topk_ids, local_hidden_states, chunked_hidden_states_sizes diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index af09eb01cb..485e5ca92f 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -33,6 +33,7 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): "gate_proj", "up_proj", ], - "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], } qwen3.Qwen3MoeSparseMoeBlock = AscendSparseMoeBlock diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 605e6065c2..0ddf11e50f 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -294,8 +294,8 @@ def model_input_split_v1_attn( token_index) is_only_prefill_pre = is_only_prefill_post = attn_metadata.is_only_prefill - has_prefill_pre, _ = torch.any( - query_lens_pre > 1).item(), torch.any(query_lens_post > 1).item() + has_prefill_pre, _ = torch.any(query_lens_pre > 1).item(), torch.any( + query_lens_post > 1).item() if not attn_metadata.is_only_prefill: is_only_prefill_post = torch.all(query_lens_post > 1).item() diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index 1e18900870..85234cd390 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -232,7 +232,6 @@ def preprocess(self, ep_size = self.ep_size - # Dropless self.num_out_tokens = indices.numel() if self.ep_size > 1 or self.num_local_experts > 1: @@ -408,7 +407,6 @@ def preprocess_and_permtute1(self, shared_output = shared_experts(shared_experts_input) self.cached_shared_expert_output = shared_output - hidden_states, self.reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( tokens=hidden_states, indices=self.top_indices, @@ -542,8 +540,8 @@ def alltoall_token_unpermutation2(permutated_local_input_tokens): output = torch_npu.npu_moe_token_unpermute( permuted_tokens=permutated_local_input_tokens, - sorted_indices=self. - reversed_local_input_permutation_mapping.to(torch.int32), + sorted_indices=self.reversed_local_input_permutation_mapping. 
+ to(torch.int32), probs=self.probs, restore_shape=self.hidden_shape_before_permute) From 267db60143d30d44435023ff8dabc0036713115a Mon Sep 17 00:00:00 2001 From: weijinqian Date: Fri, 11 Jul 2025 23:04:45 +0800 Subject: [PATCH 51/60] handle clean code Signed-off-by: weijinqian_v1 --- vllm_ascend/ops/fused_moe.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index d3983af245..7f4b7a14d0 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -517,8 +517,8 @@ def fused_experts_with_all2all_buffer( dtype=expert_idx_buffer_scatter.dtype, device=expert_idx_buffer_scatter.device, ) - non_pad_len = torch.sum((expert_idx_buffer_scatter - != global_num_experts).to(torch.int32)) + non_pad_len = torch.sum( + (expert_idx_buffer_scatter != global_num_experts).to(torch.int32)) hidden_states_pad_idx[expert_idx_buffer_scatter != global_num_experts] = ( torch.arange( non_pad_len, @@ -580,8 +580,8 @@ def fused_experts_with_all2all_buffer( dist.all_to_all_single(hidden_states_gatter, hidden_states_scatter, group=ep_group.device_group) - hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter != - global_num_experts] + hidden_states_gatter = hidden_states_gatter[ + expert_idx_buffer_scatter != global_num_experts] if hidden_states_gatter.shape[0] != row_idx_len: hidden_states = torch.zeros( (row_idx_len, hidden_states.shape[1]), @@ -776,10 +776,9 @@ def fused_experts( # This created multiple NaN and index_add_ will mix them up which harms accuracy # remove this mask and filter after it being fixed num_valid_tokens = mask.sum() - valid_token_mask = (torch.arange(0, - sorted_token_indices.shape[0], - device=device).unsqueeze(1) - < num_valid_tokens) + valid_token_mask = (torch.arange( + 0, sorted_token_indices.shape[0], device=device).unsqueeze(1) < + num_valid_tokens) valid_output = torch.where( valid_token_mask, weighted_down_out, torch.zeros_like(weighted_down_out)).to(dtype) From b0572c89335786002e868908fd7b40b02b75623f Mon Sep 17 00:00:00 2001 From: weijinqian Date: Fri, 11 Jul 2025 23:05:21 +0800 Subject: [PATCH 52/60] handle clean code Signed-off-by: weijinqian_v1 --- tests/ut/test_distributed_tensor_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ut/test_distributed_tensor_parallel.py b/tests/ut/test_distributed_tensor_parallel.py index 5792fb6df5..5a438e0cdf 100644 --- a/tests/ut/test_distributed_tensor_parallel.py +++ b/tests/ut/test_distributed_tensor_parallel.py @@ -72,7 +72,7 @@ def test_gather_along_first_dim_unequal_split(self, test_tensor, output_split_sizes = [5, 10, 15, 2] result = _gather_along_first_dim(test_tensor, mock_group, output_split_sizes) - self.assertEqual(result.shape, (32, 16)) # 5+10+15+2=32 + self.assertEqual(result.shape, (32, 16)) # 5+10+15+2=32 @pytest.mark.parametrize("world_size", [1, 4]) def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group, From e4f1050a3afb26f7a00f5fdafc81773982bbcf27 Mon Sep 17 00:00:00 2001 From: weijinqian Date: Fri, 11 Jul 2025 23:43:43 +0800 Subject: [PATCH 53/60] handle clean code Signed-off-by: weijinqian_v1 --- vllm_ascend/models/__init__.py | 1 - vllm_ascend/multistream/ms_split.py | 2 ++ .../ops/moe_dispatcher/token_dispatcher.py | 16 ++++++++++++++-- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index b5c9c9597b..b2da242106 100644 --- a/vllm_ascend/models/__init__.py +++ 
b/vllm_ascend/models/__init__.py @@ -8,7 +8,6 @@ def register_model(): from .deepseek_mtp import CustomDeepSeekMTP # noqa: F401 from .deepseek_v2 import CustomDeepseekV2ForCausalLM # noqa: F401 from .deepseek_v2 import CustomDeepseekV3ForCausalLM # noqa: F401 - from .moe_block import AscendSparseMoeBlock # noqa: F401 from .qwen2_5_vl import \ AscendQwen2_5_VLForConditionalGeneration # noqa: F401 from .qwen2_vl import AscendQwen2VLForConditionalGeneration # noqa: F401 diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 0ddf11e50f..eda00f401a 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -304,12 +304,14 @@ def model_input_split_v1_attn( # the attn_mla kernel in torch npu only accept 128*128 attn mask attn_mask_pre = attn_mask_post = attn_metadata.attn_mask attn_state_pre = attn_state_post = attn_metadata.attn_state + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # should be none in decode only state attn_mask_pre = attn_mask_post = attn_metadata.attn_mask attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly else: # chunked prefill + assert attn_metadata.attn_mask is not None if has_prefill_pre: attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill attn_mask_pre = attn_metadata.attn_mask[:token_index, :max( diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index 85234cd390..e631fcbf22 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -24,6 +24,7 @@ import torch import torch_npu +from torch import Tensor from vllm.distributed.parallel_state import get_ep_group from vllm_ascend.distributed.tensor_parallel import ( @@ -279,7 +280,7 @@ def preprocess(self, "num_global_tokens_per_local_expert must be set before operations." ) self.device_sync_point = "no_sync" - self.global_input_tokens_local_experts_indices = torch.repeat_interleave( + self.global_input_tokens_local_experts_indices: Tensor = torch.repeat_interleave( self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()) @@ -314,6 +315,7 @@ def token_permutation( # Permutation 1: input to AlltoAll input def alltoall_token_permutation1(hidden_states, routing_map): + assert self.hidden_shape is not None hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) tokens_per_expert = self.preprocess(routing_map) if self.tp_ep_size > 1: @@ -390,6 +392,7 @@ def preprocess_and_permtute1(self, self.top_indices = routing_map assert probs.dim() == 2, "Expected 2D tensor for probs" assert routing_map.dim() == 2, "Expected 2D tensor for routing map" + assert self.hidden_shape is not None hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) tokens_per_expert = self.preprocess(routing_map, with_sync=False) @@ -401,6 +404,7 @@ def preprocess_and_permtute1(self, event = torch.npu.current_stream().record_event() self.perm1_finish_event = torch.npu.Event() with torch.npu.stream(self.overlap_stream): + assert self.overlap_stream is not None self.overlap_stream.wait_event(event) if shared_experts is not None: @@ -418,7 +422,11 @@ def preprocess_and_permtute1(self, # repeat interleve will launch a sync on current_stream. 
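# --- Illustrative sketch (not part of the patch): what the repeat_interleave below
# produces. With 2 local experts and a (1 EP rank x 2 experts) token count of [3, 1],
# every incoming token gets tagged with the local expert index it belongs to. The
# concrete counts are made up.
import torch

expert_ids_per_ep_rank = torch.tensor([0, 1], dtype=torch.int32)
num_global_tokens_per_local_expert = torch.tensor([[3, 1]])  # (ep_ranks, local_experts)
indices = torch.repeat_interleave(expert_ids_per_ep_rank,
                                  num_global_tokens_per_local_expert.ravel())
print(indices)  # tensor([0, 0, 0, 1], dtype=torch.int32)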
if self.num_local_experts > 1: self.device_sync_point = "no_sync" - self.global_input_tokens_local_experts_indices = torch.repeat_interleave( + if self.num_global_tokens_per_local_expert is None: + raise ValueError( + "num_global_tokens_per_local_expert must be set before operations." + ) + self.global_input_tokens_local_experts_indices: Tensor = torch.repeat_interleave( self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()) @@ -441,6 +449,10 @@ def dispatch_alltoall(self): ep_group, ) permute1_ep_all_to_all_handle.wait() + if self.cached_permutated_local_input_tokens is None: + raise ValueError( + "cached_permutated_local_input_tokens must be set before operations." + ) self.cached_permutated_local_input_tokens.untyped_storage().resize_(0) self.cached_permutated_local_input_tokens = None From eaed83df5d5d93e57ba96668509fb3cb871a1f7e Mon Sep 17 00:00:00 2001 From: weijinqian Date: Sat, 12 Jul 2025 00:08:52 +0800 Subject: [PATCH 54/60] handle clean code Signed-off-by: weijinqian_v1 --- vllm_ascend/multistream/ms_split.py | 10 +++++----- vllm_ascend/ops/moe_dispatcher/token_dispatcher.py | 7 +++++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index eda00f401a..61e20ed14b 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -308,21 +308,21 @@ def model_input_split_v1_attn( elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # should be none in decode only state attn_mask_pre = attn_mask_post = attn_metadata.attn_mask - attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly + attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly # noqa else: # chunked prefill assert attn_metadata.attn_mask is not None if has_prefill_pre: - attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill + attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill # noqa attn_mask_pre = attn_metadata.attn_mask[:token_index, :max( seq_lens_pre)].contiguous() - attn_state_post = AscendAttentionState.ChunkedPrefill + attn_state_post = AscendAttentionState.ChunkedPrefill # noqa attn_mask_post = attn_metadata.attn_mask[ token_index:, :max(seq_lens_post)].contiguous() else: - attn_state_pre = AscendAttentionState.DecodeOnly + attn_state_pre = AscendAttentionState.DecodeOnly # noqa attn_mask_pre = None - attn_state_post = AscendAttentionState.ChunkedPrefill + attn_state_post = AscendAttentionState.ChunkedPrefill # noqa attn_mask_post = attn_metadata.attn_mask[ token_index:, :max(seq_lens_post)].contiguous() diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index e631fcbf22..dca878078f 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -201,6 +201,8 @@ def __init__(self, config: MoEDispatcherConfig): self.cached_global_input_tokens = None self.cached_shared_expert_output = None self.tokens_per_expert = None + self.perm1_finish_event = None + self.global_input_tokens_local_experts_indices = None if MoEAlltoAllSeqOverLapDispatcher.overlap_stream is None: MoEAlltoAllSeqOverLapDispatcher.overlap_stream = torch.npu.Stream() @@ -280,7 +282,7 @@ def preprocess(self, "num_global_tokens_per_local_expert must be set before operations." 
) self.device_sync_point = "no_sync" - self.global_input_tokens_local_experts_indices: Tensor = torch.repeat_interleave( + self.global_input_tokens_local_experts_indices = torch.repeat_interleave( self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()) @@ -426,7 +428,7 @@ def preprocess_and_permtute1(self, raise ValueError( "num_global_tokens_per_local_expert must be set before operations." ) - self.global_input_tokens_local_experts_indices: Tensor = torch.repeat_interleave( + self.global_input_tokens_local_experts_indices = torch.repeat_interleave( self.expert_ids_per_ep_rank, self.num_global_tokens_per_local_expert.ravel()) @@ -462,6 +464,7 @@ def permute2(self): global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( self.cached_global_input_tokens, self.global_input_tokens_local_experts_indices) + assert self.cached_global_input_tokens is not None self.cached_global_input_tokens.untyped_storage().resize_(0) self.cached_global_input_tokens = None From b97baf4d2f18955176e582ed4d35eef7d81a2a8e Mon Sep 17 00:00:00 2001 From: weijinqian Date: Sat, 12 Jul 2025 00:10:17 +0800 Subject: [PATCH 55/60] handle clean code Signed-off-by: weijinqian_v1 --- vllm_ascend/ops/moe_dispatcher/token_dispatcher.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index dca878078f..91118e296d 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -24,7 +24,6 @@ import torch import torch_npu -from torch import Tensor from vllm.distributed.parallel_state import get_ep_group from vllm_ascend.distributed.tensor_parallel import ( From d232d4951db926861ef52314d5e153660df3b066 Mon Sep 17 00:00:00 2001 From: weijinqian Date: Sat, 12 Jul 2025 00:12:07 +0800 Subject: [PATCH 56/60] handle clean code Signed-off-by: weijinqian_v1 --- vllm_ascend/multistream/ms_split.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 61e20ed14b..16e50468ac 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -313,7 +313,7 @@ def model_input_split_v1_attn( # chunked prefill assert attn_metadata.attn_mask is not None if has_prefill_pre: - attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill # noqa + attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill # noqa attn_mask_pre = attn_metadata.attn_mask[:token_index, :max( seq_lens_pre)].contiguous() attn_state_post = AscendAttentionState.ChunkedPrefill # noqa From 1e435e610b0b4bb1242c7e3f286f3cad21cfc327 Mon Sep 17 00:00:00 2001 From: weijinqian Date: Sat, 12 Jul 2025 00:33:33 +0800 Subject: [PATCH 57/60] handle clean code Signed-off-by: weijinqian_v1 --- vllm_ascend/multistream/ms_split.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index 16e50468ac..9ceae10537 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -304,25 +304,24 @@ def model_input_split_v1_attn( # the attn_mla kernel in torch npu only accept 128*128 attn mask attn_mask_pre = attn_mask_post = attn_metadata.attn_mask attn_state_pre = attn_state_post = attn_metadata.attn_state - elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # should be none in decode only 
state attn_mask_pre = attn_mask_post = attn_metadata.attn_mask - attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly # noqa + attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly # type: ignore else: # chunked prefill assert attn_metadata.attn_mask is not None if has_prefill_pre: - attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill # noqa + attn_state_pre = attn_state_post = AscendAttentionState.ChunkedPrefill # type: ignore attn_mask_pre = attn_metadata.attn_mask[:token_index, :max( seq_lens_pre)].contiguous() - attn_state_post = AscendAttentionState.ChunkedPrefill # noqa + attn_state_post = AscendAttentionState.ChunkedPrefill # type: ignore attn_mask_post = attn_metadata.attn_mask[ token_index:, :max(seq_lens_post)].contiguous() else: - attn_state_pre = AscendAttentionState.DecodeOnly # noqa + attn_state_pre = AscendAttentionState.DecodeOnly # type: ignore attn_mask_pre = None - attn_state_post = AscendAttentionState.ChunkedPrefill # noqa + attn_state_post = AscendAttentionState.ChunkedPrefill # type: ignore attn_mask_post = attn_metadata.attn_mask[ token_index:, :max(seq_lens_post)].contiguous() From 8effdd0300d98066107ee1fdeb79c538f8731973 Mon Sep 17 00:00:00 2001 From: weijinqian Date: Sat, 12 Jul 2025 00:44:35 +0800 Subject: [PATCH 58/60] handle clean code Signed-off-by: weijinqian_v1 --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 6d84ec658c..be00f01991 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,3 +28,4 @@ torch-npu==2.5.1.post1.dev20250619 # Remove after https://github.com/vllm-project/vllm-ascend/issues/1470 transformers<4.53.0 +pytest_mock From a8136b75b277c698e95f5fc6a79a96ff9ceaae9d Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Sat, 12 Jul 2025 23:44:50 +0800 Subject: [PATCH 59/60] handle code clean Signed-off-by: weijinqian_v1 --- .../test_offline_inference_distributed.py | 21 ++++++ tests/multicard/test_qwen3_moe.py | 74 ------------------- 2 files changed, 21 insertions(+), 74 deletions(-) delete mode 100644 tests/multicard/test_qwen3_moe.py diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py index d4af282efe..3f4364ef24 100644 --- a/tests/multicard/test_offline_inference_distributed.py +++ b/tests/multicard/test_offline_inference_distributed.py @@ -154,6 +154,27 @@ def test_models_distributed_DeepSeekV3_dbo(): vllm_model.generate(example_prompts, sampling_params) +@pytest.mark.skip(reason="Due to OOM,waiting for 1311pr to merge in") +@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1", "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1"}) +def test_models_distributed_DeepSeekV3_alltoallv_dbo(): + example_prompts = ["The president of the United States is"] * 10 + dtype = "half" + sampling_params = SamplingParams(max_tokens=30, temperature=0.0) + with VllmRunner( + "vllm-ascend/DeepSeek-V3-Pruning", + dtype=dtype, + tensor_parallel_size=4, + distributed_executor_backend="mp", + ) as vllm_model: + model_arch = 'DeepseekV3ForCausalLM' + registed_models = ModelRegistry.models + assert registed_models[ + model_arch].module_name == "vllm_ascend.models.deepseek_dbo" + assert registed_models[ + model_arch].class_name == "CustomDeepseekDBOForCausalLM" + vllm_model.generate(example_prompts, sampling_params) + + def test_models_distributed_DeepSeek_W8A8(): example_prompts = [ "Hello, my name is", diff --git a/tests/multicard/test_qwen3_moe.py b/tests/multicard/test_qwen3_moe.py deleted 
file mode 100644 index e24770b792..0000000000 --- a/tests/multicard/test_qwen3_moe.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# -""" -Compare the outputs of vLLM with and without aclgraph. -Run `pytest tests/multicard/test_data_parallel.py`. -""" - -import os -import subprocess -import sys -from unittest.mock import patch - -import pytest - -MODELS = ["vllm-ascend/Qwen3-30B-A3B-Puring"] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("max_tokens", [32]) -@patch.dict( - os.environ, { - "ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3", - "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1", - "VLLM_ASCEND_ENABLE_DBO": "1" - }) -def test_qwen3_moe_inference(model, max_tokens): - script = "examples/dp_offline/data_parallel.py" - - env = os.environ.copy() - - cmd = [ - sys.executable, - script, - "--model", - model, - "--dp-size", - "2", - "--tp-size", - "2", - "--node-size", - "1", - "--node-rank", - "0", - "--trust-remote-code", - ] - - print(f"Running subprocess: {' '.join(cmd)}") - proc = subprocess.run(cmd, - env=env, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - timeout=600) - output = proc.stdout.decode() - - print(output) - - assert "DP rank 0 needs to process" in output - assert "DP rank 1 needs to process" in output - assert "Generated text:" in output - assert proc.returncode == 0 From 94b7b5ba02b83e3f196f6a237cdd4bd698837912 Mon Sep 17 00:00:00 2001 From: weijinqian_v1 Date: Sat, 12 Jul 2025 23:52:01 +0800 Subject: [PATCH 60/60] handle code clean Signed-off-by: weijinqian_v1 --- tests/multicard/test_offline_inference_distributed.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py index 3f4364ef24..ec01f6db3a 100644 --- a/tests/multicard/test_offline_inference_distributed.py +++ b/tests/multicard/test_offline_inference_distributed.py @@ -155,7 +155,10 @@ def test_models_distributed_DeepSeekV3_dbo(): @pytest.mark.skip(reason="Due to OOM,waiting for 1311pr to merge in") -@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1", "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1"}) +@patch.dict(os.environ, { + "VLLM_ASCEND_ENABLE_DBO": "1", + "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1" +}) def test_models_distributed_DeepSeekV3_alltoallv_dbo(): example_prompts = ["The president of the United States is"] * 10 dtype = "half"