diff --git a/requirements.txt b/requirements.txt
index 6d84ec658..be00f0199 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -28,3 +28,4 @@
 torch-npu==2.5.1.post1.dev20250619
 # Remove after https://github.com/vllm-project/vllm-ascend/issues/1470
 transformers<4.53.0
+pytest_mock
diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py
index d4af282ef..ec01f6db3 100644
--- a/tests/multicard/test_offline_inference_distributed.py
+++ b/tests/multicard/test_offline_inference_distributed.py
@@ -154,6 +154,30 @@ def test_models_distributed_DeepSeekV3_dbo():
         vllm_model.generate(example_prompts, sampling_params)
 
 
+@pytest.mark.skip(reason="Due to OOM, waiting for PR 1311 to be merged")
+@patch.dict(os.environ, {
+    "VLLM_ASCEND_ENABLE_DBO": "1",
+    "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": "1"
+})
+def test_models_distributed_DeepSeekV3_alltoallv_dbo():
+    example_prompts = ["The president of the United States is"] * 10
+    dtype = "half"
+    sampling_params = SamplingParams(max_tokens=30, temperature=0.0)
+    with VllmRunner(
+            "vllm-ascend/DeepSeek-V3-Pruning",
+            dtype=dtype,
+            tensor_parallel_size=4,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        model_arch = 'DeepseekV3ForCausalLM'
+        registered_models = ModelRegistry.models
+        assert registered_models[
+            model_arch].module_name == "vllm_ascend.models.deepseek_dbo"
+        assert registered_models[
+            model_arch].class_name == "CustomDeepseekDBOForCausalLM"
+        vllm_model.generate(example_prompts, sampling_params)
+
+
 def test_models_distributed_DeepSeek_W8A8():
     example_prompts = [
         "Hello, my name is",
diff --git a/tests/ut/test_distributed_tensor_parallel.py b/tests/ut/test_distributed_tensor_parallel.py
new file mode 100644
index 000000000..5a438e0cd
--- /dev/null
+++ b/tests/ut/test_distributed_tensor_parallel.py
@@ -0,0 +1,139 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
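+
+# Unit tests for the communication helpers in
+# vllm_ascend.distributed.tensor_parallel (gather / reduce-scatter along the
+# first and last dims, plus the sp2hp / hp2sp all-to-all wrappers).
+# torch.distributed is mocked here, so only output shapes are verified.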
+ +import importlib +import unittest +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from vllm_ascend.distributed.tensor_parallel import ( + _gather_along_first_dim, _gather_along_last_dim, + _reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim, + all_to_all_hp2sp, all_to_all_sp2hp) + + +@pytest.fixture +def test_tensor(): + return torch.randn(8, 16) + + +@pytest.fixture +def test_tensor_last_dim(): + return torch.randn(8, 16, 32) + + +@pytest.fixture +def mock_group(): + return MagicMock() + + +@pytest.fixture(autouse=True) +def mock_dist(): + with patch("torch.distributed") as mock: + mock.get_world_size.return_value = 4 + mock.get_rank.return_value = 0 + yield mock + + +class TestDistributedCommunication(unittest.TestCase): + + @pytest.mark.parametrize("world_size", [1, 4]) + def test_gather_along_first_dim(self, test_tensor, mock_group, mock_dist, + world_size): + """test _gather_along_first_dim""" + mock_dist.get_world_size.return_value = world_size + + result = _gather_along_first_dim(test_tensor, mock_group) + + if world_size == 1: + self.assertEqual(result.shape, (8, 16)) + else: + self.assertEqual(result.shape, (32, 16)) # 8*4=32 + + def test_gather_along_first_dim_unequal_split(self, test_tensor, + mock_group): + """test unequal split""" + output_split_sizes = [5, 10, 15, 2] + result = _gather_along_first_dim(test_tensor, mock_group, + output_split_sizes) + self.assertEqual(result.shape, (32, 16)) # 5+10+15+2=32 + + @pytest.mark.parametrize("world_size", [1, 4]) + def test_gather_along_last_dim(self, test_tensor_last_dim, mock_group, + mock_dist, world_size): + """test _gather_along_last_dim""" + mock_dist.get_world_size.return_value = world_size + + result = _gather_along_last_dim(test_tensor_last_dim, mock_group) + + self.assertEqual(result.shape, (8, 16, 32 * world_size)) + + @pytest.mark.parametrize("input_shape,expected_shape", [ + ((32, 16), (8, 16)), + ((40, 10), (10, 10)), + ]) + def test_reduce_scatter_along_first_dim(self, mock_group, input_shape, + expected_shape): + input_tensor = torch.randn(*input_shape) + result = _reduce_scatter_along_first_dim(input_tensor, mock_group) + self.assertEqual(result.shape, expected_shape) + + def test_reduce_scatter_along_last_dim(self, mock_group): + input_tensor = torch.randn(8, 16, 32) + result = _reduce_scatter_along_last_dim(input_tensor, mock_group) + self.assertEqual(result.shape, (8, 16, 8)) + + @pytest.mark.parametrize("func,input_shape,expected_shape", [ + ("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32), + (8, 16, 128)), + ("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)), + ("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32), + (8, 16, 8)), + ("gather_from_sequence_parallel_region", (8, 16), (32, 16)), + ]) + def test_wrapper_functions(self, mock_group, func, input_shape, + expected_shape): + """test wrapper funcs""" + mod = importlib.import_module( + 'vllm_ascend.distributed.tensor_parallel') + globals = mod.__dict__ + test_func = globals[func] + input_tensor = torch.randn(*input_shape) + result = test_func(input_tensor, mock_group) + self.assertEqual(result.shape, expected_shape) + + @pytest.mark.parametrize( + "input_shape,output_shape", + [ + ((8, 16), (32, 4)), # [num_tokens/TP, H] -> [num_tokens, H/TP] + ]) + def test_all_to_all_sp2hp(self, mock_group, input_shape, output_shape): + input_tensor = torch.randn(*input_shape) + result = all_to_all_sp2hp(input_tensor, mock_group) + self.assertEqual(result.shape, output_shape) + + 
@pytest.mark.parametrize( + "input_shape,output_shape", + [ + ((32, 4), (8, 16)), # [num_tokens, H/TP] -> [num_tokens/TP, H] + ]) + def test_all_to_all_hp2sp(self, mock_group, input_shape, output_shape): + input_tensor = torch.randn(*input_shape) + result = all_to_all_hp2sp(input_tensor, mock_group) + self.assertEqual(result.shape, output_shape) diff --git a/tests/ut/test_token_dispatcher.py b/tests/ut/test_token_dispatcher.py new file mode 100644 index 000000000..18768a7fe --- /dev/null +++ b/tests/ut/test_token_dispatcher.py @@ -0,0 +1,69 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. + +import unittest + +import pytest +from pytest_mock import MockerFixture + +from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( + MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) +from vllm_ascend.utils import adapt_patch # noqa E402 + +import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa + +adapt_patch(True) + + +class TestMoEAlltoAllSeqOverLapDispatcher(unittest.TestCase): + + @pytest.fixture + def config(self): + config = MoEDispatcherConfig() + config.set_num_local_experts(2) + config.set_num_moe_experts(4) + config.set_moe_pad_expert_input_to_capacity(False) + config.set_moe_expert_capacity_factor(None) + config.set_moe_router_topk(2) + config.set_moe_grouped_gemm(False) + config.set_group_topk(0) + config.set_num_groups(1) + config.set_is_fused(False) + return config.build() + + def mock_ep_group(self, mocker): + mock_group = mocker.MagicMock() + mock_group.rank_in_group = 0 + mock_group.world_size = 2 + mock_group.device_group = "mock_group" + return mock_group + + @pytest.fixture + def dispatcher(self, config, mocker: MockerFixture): + mocker.patch( + "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ep_group", + return_value=self.mock_ep_group(mocker)) + return MoEAlltoAllSeqOverLapDispatcher(config) + + def test_initialization(self, dispatcher, config): + self.assertEqual(dispatcher.num_local_experts, + config.num_local_experts) + self.assertEqual(dispatcher.num_experts, config.num_moe_experts) + self.assertEqual(dispatcher.local_expert_indices, [0, 1]) + self.assertEqual(dispatcher.ep_rank, 0) + self.assertEqual(dispatcher.ep_size, 2) + self.assertIsNotNone(dispatcher.overlap_stream) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 009300f3d..e4a9b5adc 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -10,6 +10,7 @@ from vllm.platforms import current_platform import vllm_ascend.envs as envs +import vllm_ascend.envs as envs_ascend class FusedMoEState(Enum): @@ -17,6 +18,7 @@ class FusedMoEState(Enum): All2All = 1 MC2 = 2 MC2_PREFILL = 3 + All2AllSeq = 4 # TODO(zzzzwwjj): add soc_version to choose branch @@ -24,6 +26,10 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool): 
enable_chunk_mc2 = envs.VLLM_ASCEND_ENABLE_CHUNK_MC2 if ep_size == 1: return FusedMoEState.AllGather + elif envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: + # MC2 Dispatch/Combine performs better than alltoall_seq in decoding stage. + return (FusedMoEState.All2AllSeq if + (ep_size < 16 or with_prefill) else FusedMoEState.MC2) elif ep_size >= 16 and with_prefill and enable_chunk_mc2: return FusedMoEState.MC2_PREFILL # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. @@ -35,27 +41,30 @@ def get_fused_moe_state(ep_size: int, with_prefill: bool): @contextmanager def set_ascend_forward_context( - attn_metadata: Any, - vllm_config: VllmConfig, - virtual_engine: int = 0, - num_tokens: Optional[int] = None, - num_tokens_across_dp: Optional[torch.Tensor] = None, - with_prefill: bool = True, - in_profile_run: bool = False, - num_actual_tokens: Optional[int] = None): + attn_metadata: Any, + vllm_config: VllmConfig, + virtual_engine: int = 0, + num_tokens: Optional[int] = None, + num_tokens_across_dp: Optional[torch.Tensor] = None, + with_prefill: bool = True, + in_profile_run: bool = False, + num_actual_tokens: Optional[int] = None, +): """A context manager that stores the current forward context, can be attention metadata, etc. We add some additional param into forward_context. """ - with set_forward_context(attn_metadata, - vllm_config, - virtual_engine=virtual_engine, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp): + with set_forward_context( + attn_metadata, + vllm_config, + virtual_engine=virtual_engine, + num_tokens=num_tokens, + num_tokens_across_dp=num_tokens_across_dp, + ): forward_context = get_forward_context() forward_context.with_prefill = with_prefill - ep_size = torch.distributed.get_world_size( - ) if vllm_config.parallel_config.enable_expert_parallel else 1 + ep_size = (torch.distributed.get_world_size() if + vllm_config.parallel_config.enable_expert_parallel else 1) fused_moe_state = get_fused_moe_state(ep_size, with_prefill) @@ -68,20 +77,21 @@ def set_ascend_forward_context( forward_context.capturing = False if num_tokens is None and attn_metadata is not None: - if hasattr(attn_metadata, 'num_actual_tokens'): + if hasattr(attn_metadata, "num_actual_tokens"): # for v1 engine num_tokens = attn_metadata.num_actual_tokens else: # for v0 engine - num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens + num_tokens = (attn_metadata.num_prefill_tokens + + attn_metadata.num_decode_tokens) if num_actual_tokens is None: num_actual_tokens = num_tokens dp_world_size = get_dp_group().world_size if dp_world_size > 1 and forward_context.dp_metadata is not None: - max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item( - ) + max_tokens_across_dp = ( + forward_context.dp_metadata.max_tokens_across_dp_cpu.item()) else: max_tokens_across_dp = num_tokens @@ -91,11 +101,12 @@ def set_ascend_forward_context( tp_world_size = get_tp_group().world_size world_size = torch.distributed.get_world_size() # NOTE: token num which need to pad to when mc2 - forward_context.padded_num_tokens = math.ceil( - max_tokens_across_dp / tp_world_size) * tp_world_size + forward_context.padded_num_tokens = ( + math.ceil(max_tokens_across_dp / tp_world_size) * + tp_world_size) # NOTE: mc2 op's param `global_bs`, add `world_size` to make `global_bs` absolutely larger than actual global_bs. 
- forward_context.global_bs = math.ceil( - max_tokens_across_dp / tp_world_size) * world_size + forward_context.global_bs = ( + math.ceil(max_tokens_across_dp / tp_world_size) * world_size) if fused_moe_state == FusedMoEState.MC2_PREFILL: chunk_size = envs.VLLM_ASCEND_FUSED_MOE_MC2_CHUNK_SIZE @@ -103,17 +114,20 @@ def set_ascend_forward_context( math.ceil(max_tokens_across_dp / tp_world_size) / chunk_size) - forward_context.global_bs = math.ceil( + forward_context.global_bs = (math.ceil( math.ceil(max_tokens_across_dp / tp_world_size) / - forward_context.max_num_chunks) * world_size + forward_context.max_num_chunks) * world_size) min_num_tokens = forward_context.max_num_chunks * tp_world_size - forward_context.padded_num_tokens = math.ceil( - max_tokens_across_dp / min_num_tokens) * min_num_tokens - - mc2_mask = torch.zeros(forward_context.padded_num_tokens, - dtype=torch.bool, - device=current_platform.device_type) + forward_context.padded_num_tokens = ( + math.ceil(max_tokens_across_dp / min_num_tokens) * + min_num_tokens) + + mc2_mask = torch.zeros( + forward_context.padded_num_tokens, + dtype=torch.bool, + device=current_platform.device_type, + ) mc2_mask[:num_actual_tokens] = True forward_context.mc2_mask = mc2_mask diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index f03d9f88e..629fe73d5 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -32,6 +32,7 @@ from vllm_ascend.attention.utils import \ AscendCommonAttentionMetadata as CommonAttentionMetadata +from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig from vllm_ascend.ops.attention import vanilla_chunked_prefill from vllm_ascend.utils import get_graph_params @@ -140,6 +141,18 @@ class AscendMetadata: enable_dbo_across_dp: bool = False + def split_metadata_for_multistream( + self, + ms_split_config: MSAttentionMetadataSplitConfig, + ) -> list["AscendMetadata"]: + """Split metadata for multi-stream with AscendMetadata""" + from vllm_ascend.multistream.ms_split import model_input_split_v1_attn + return model_input_split_v1_attn( + ms_split_config=ms_split_config, + attn_metadata=self, + _metadata_cls=AscendMetadata, + ) + class AscendAttentionMetadataBuilder: diff --git a/vllm_ascend/distributed/tensor_parallel.py b/vllm_ascend/distributed/tensor_parallel.py new file mode 100644 index 000000000..a9e432455 --- /dev/null +++ b/vllm_ascend/distributed/tensor_parallel.py @@ -0,0 +1,247 @@ +# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +import torch + + +def _gather_along_first_dim(input_, group, output_split_sizes=None): + """Gather tensors and concatenate along the first dimension. + + Args: + input_tensor (torch.Tensor): + A tensor to be gathered. 
+ output_split_sizes (List[int], optional): + A list specifying the sizes of the output splits along the first dimension. + If None, equal splitting is assumed. Default: None. + + Returns: + torch.Tensor: Gathered tensor. + """ + world_size = torch.distributed.get_world_size(group) + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + if output_split_sizes is None: + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + torch.distributed.all_gather_into_tensor(output, + input_.contiguous(), + group=group) + else: + dim_size[0] = sum(output_split_sizes) + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + output_tensor_list = list( + torch.split(output, output_split_sizes, dim=0)) + torch.distributed.all_gather(output_tensor_list, input_, group=group) + + return output + + +def _gather_along_last_dim(input_, group): + """Gather tensors and concatenate along the last dimension.""" + + world_size = torch.distributed.get_world_size(group) + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] * world_size + + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + torch.distributed.all_gather_into_tensor(output, + input_.contiguous(), + group=group) + tensor_list = output.chunk(world_size, dim=0) + output = torch.cat(tensor_list, dim=-1).contiguous() + + return output + + +def _reduce_scatter_along_first_dim(input_, + group, + input_split_sizes=None, + use_global_buffer=False): + """Reduce-scatter the input tensor across model parallel group. + + Args: + input_ (torch.Tensor): The input tensor to be reduce-scattered. + input_split_sizes (List[int], optional): A list specifying the sizes of + the input splits along the first dimension for each rank. If None, + equal splitting is assumed. Default: None. + """ + world_size = torch.distributed.get_world_size(group) + # Bypass the function if we are using only 1 GPU. 
+ if world_size == 1: + return input_ + + if input_split_sizes is None: + dim_size = list(input_.size()) + assert ( + dim_size[0] % world_size == 0 + ), "First dimension of the tensor should be divisible by tensor parallel size" + + dim_size[0] = dim_size[0] // world_size + + output = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + torch.distributed.reduce_scatter_tensor(output, + input_.contiguous(), + group=group) + else: + rank = torch.distributed.get_rank(group) + input_tensor_list = list(torch.split(input_, input_split_sizes, dim=0)) + + output = torch.empty_like(input_tensor_list[rank]) + torch.distributed.reduce_scatter(output, + input_tensor_list, + group=group) + return output + + +def _reduce_scatter_along_last_dim(input_, group): + """Reduce-scatter tensors on the last dimension.""" + world_size = torch.distributed.get_world_size(group) + target_shape = list(input_.size()) + target_shape[-1] = target_shape[-1] // world_size + input_ = input_.reshape(-1, input_.shape[-1]) + split_tensors = torch.split(input_, + split_size_or_sections=input_.shape[-1] // + world_size, + dim=1) + concat_tensor = torch.cat(split_tensors, dim=0) + output = _reduce_scatter_along_first_dim(concat_tensor, + group).reshape(target_shape) + return output + + +def all_gather_last_dim_from_tensor_parallel_region(input_, group): + """Wrapper for autograd function: forward: AG, backward RS """ + return _gather_along_last_dim(input_, group) + + +def reduce_scatter_to_sequence_parallel_region(input_, + group, + input_split_sizes=None): + """Wrapper for autograd function: forward: RS, backward AG """ + return _reduce_scatter_along_first_dim(input_, group, input_split_sizes) + + +def reduce_scatter_last_dim_to_tensor_parallel_region(input_, group): + """Wrapper for autograd function: forward: RS, backward AG: AG """ + return _reduce_scatter_along_last_dim(input_, group) + + +def gather_from_sequence_parallel_region( + input_, + group, + output_split_sizes=None, +): + """Wrapper for autograd function: forward: AG, backward: RS """ + return _gather_along_first_dim(input_, group, output_split_sizes) + + +def all_to_all(group, input, output_split_sizes=None, input_split_sizes=None): + world_size = torch.distributed.get_world_size(group=group) + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input + + input = input.contiguous() + if output_split_sizes is None: + # Equal split (all2all) + output = torch.empty_like(input) + else: + # Unequal split (all2all-v) + output = input.new_empty( + size=[sum(output_split_sizes)] + list(input.size()[1:]), + dtype=input.dtype, + device=torch.npu.current_device(), + ) + torch.distributed.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + ) + return output + + +def all_to_all_sp2hp(input_, group): + """ + Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape + [num_tokens/TP, H] to [num_tokens, H/TP]. + + Args: + input_ (torch.Tensor): + The input tensor which has been distributed along the sequence + dimension. + + Returns: + torch.Tensor: The output tensor with shape [num_tokens, H/TP]. 
+ + """ + if group is None: + return input_ + world_size = torch.distributed.get_world_size(group=group) + tp_group = group + input_ = input_.reshape(-1, input_.shape[-1]) + split_tensors = torch.split(input_, + split_size_or_sections=input_.shape[-1] // + world_size, + dim=1) + concat_tensor = torch.cat(split_tensors, dim=0) + output = all_to_all(tp_group, concat_tensor) + return output + + +def all_to_all_hp2sp(input_, group): + """ + Perform AlltoAll communication on tensor parallel group, transform the input tensor from shape + [num_tokens, H/TP] to [num_tokens/TP, H]. + + Args: + input_ (torch.Tensor): + The input tensor which has been distributed along the hidden + dimension. + + Returns: + torch.Tensor: The output tensor with shape [num_tokens/TP, H]. + """ + if group is None: + return input_ + world_size = torch.distributed.get_world_size(group=group) + input_ = input_.reshape(-1, input_.shape[-1]) + tp_group = group + input_exchanged = all_to_all(tp_group, input_) + input_reshaped = input_exchanged.reshape(-1, input_exchanged.shape[-1]) + split_tensors = torch.split( + input_reshaped, + split_size_or_sections=input_reshaped.shape[0] // world_size, + dim=0) + output = torch.cat(split_tensors, dim=-1) + return output diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index a5ba5e0e3..2977d6763 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -106,11 +106,11 @@ "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0')) ), - # MOE_ALL2ALL_BUFFER: + # VLLM_ASCEND_MOE_ALL2ALL_BUFFER: # 0: default, normal init. # 1: enable moe_all2all_buffer. - "MOE_ALL2ALL_BUFFER": - lambda: bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))), + "VLLM_ASCEND_MOE_ALL2ALL_BUFFER": + lambda: bool(int(os.getenv("VLLM_ASCEND_MOE_ALL2ALL_BUFFER", '0'))), # Some models are optimized by vllm ascend. While in some case, e.g. rlhf # training, the optimized model may not be suitable. In this case, set this # value to False to disable the optimized model. @@ -137,6 +137,11 @@ # and the mla_pa will be the default path of deepseek decode path. "VLLM_ASCEND_MLA_PA": lambda: int(os.getenv("VLLM_ASCEND_MLA_PA", 0)), + # VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ: + # 0: default, normal init. + # 1: enable moe all2all seq. 
+ "VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ": + lambda: bool(int(os.getenv('VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ', '0'))), # ENABLE chunk mc2 "VLLM_ASCEND_ENABLE_CHUNK_MC2": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CHUNK_MC2", "0"))), diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 324a31b21..b2da24210 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -41,6 +41,10 @@ def register_model(): "DeepseekV3ForCausalLM", "vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM") + ModelRegistry.register_model( + "Qwen3MoeForCausalLM", + "vllm_ascend.models.qwen3_dbo:CustomQwen3MoeForCausalLMDBO") + else: ModelRegistry.register_model( "DeepseekV2ForCausalLM", @@ -50,9 +54,9 @@ def register_model(): "DeepseekV3ForCausalLM", "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM") - ModelRegistry.register_model( - "Qwen3MoeForCausalLM", - "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM") + ModelRegistry.register_model( + "Qwen3MoeForCausalLM", + "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM") ModelRegistry.register_model( "Qwen3ForCausalLM", "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM") diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 2107c3c81..20dafdf7a 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -34,11 +34,10 @@ from transformers import PretrainedConfig from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, ModelConfig, VllmConfig -from vllm.distributed import (get_ep_group, get_pp_group, - get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import get_dp_group +from vllm.distributed.parallel_state import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import UnquantizedLinearMethod @@ -56,6 +55,8 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_forward_context import FusedMoEState +from vllm_ascend.distributed.tensor_parallel import \ + gather_from_sequence_parallel_region from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2DecoderLayer, CustomDeepseekV2MLP, CustomDeepseekV2MoE) @@ -68,6 +69,7 @@ from vllm_ascend.multistream.metadata import (MultiStreamConfig, MultiStreamStepMetadata, make_multistream_metadata_ds) +from vllm_ascend.ops.fused_moe import select_experts from vllm_ascend.quantization.w8a8_dynamic import ( AscendW8A8DynamicLinearMethod, apply_mlp) from vllm_ascend.utils import dispose_tensor @@ -90,7 +92,8 @@ def __init__( intermediate_size=intermediate_size, hidden_act=hidden_act, quant_config=quant_config, - prefix=prefix) + prefix=prefix, + reduce_results=reduce_results) self.is_dynamic_quant = not isinstance( self.gate_up_proj.quant_method, UnquantizedLinearMethod) and isinstance( @@ -144,10 +147,53 @@ def __init__( intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, - reduce_results=True, + reduce_results=not envs_ascend. + VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ, # shared experts tp comm is separated in alltoallv for better overlap. 
prefix=f"{prefix}.shared_experts", ) CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok + self.config = config + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + forward_context = get_forward_context() + if attn_metadata is None: + attn_metadata = forward_context.attn_metadata + + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + enable_force_load_balance = forward_context.in_profile_run + + is_prefill = forward_context.with_prefill + # If this node is kv_consumer, we force the moe always runs in decode path to make sure + # the behaviour aligned between dummy_run and normal model_execute. + if self.kv_consumer: + is_prefill = False + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + experts_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=CustomDeepseekDBOMoE.top_k, + enable_force_load_balance=enable_force_load_balance, + shared_experts=self.shared_experts) + + shared_experts_hidden = experts_hidden_states[1] + if not (self.shared_experts.down_proj.reduce_results + and self.shared_experts.down_proj.tp_size > 1): + shared_experts_hidden = tensor_model_parallel_all_reduce( + shared_experts_hidden) + + hidden_states = ( + experts_hidden_states[0] * self.routed_scaling_factor + + shared_experts_hidden) + + return hidden_states # ----------------------------------------- TBO-related -------------------------------------------- def _forward_ms_op_shared_expert( @@ -165,6 +211,112 @@ def _forward_ms_op_gate( router_logits, _ = self.gate(hidden_states) return router_logits + def _forward_op_gating( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + # TODO: need a better flag to indicate whether in profile run or not. 
+ enable_force_load_balance = get_forward_context().in_profile_run + + num_tokens, hidden_dim = hidden_states.shape + + if self.tp_size > 1: + # pass + num_tokens, hidden_size = hidden_states.shape + if num_tokens < self.tp_size: + hidden_states = nn.functional.pad( + hidden_states, (0, 0, 0, self.tp_size - num_tokens)) + chunk_hidden_states = torch.tensor_split(hidden_states, + self.tp_size, + dim=0) + chunked_hidden_states_sizes = [ + x.shape[0] for x in chunk_hidden_states + ] + local_hidden_states = chunk_hidden_states[self.tp_rank] + else: + local_hidden_states = hidden_states + chunked_hidden_states_sizes = None + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(local_hidden_states) + + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern + if self.config.n_routed_experts == 256: + topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( + router_logits, + k=self.config.num_experts_per_tok, + bias=self.gate.e_score_correction_bias, + k_group=self.config.topk_group, # fix: 4 + group_count=self.config.n_group, # fix 8 + group_select_mode=1, # 0: max in group; 1: topk2.sum(fix) + renorm=0, # 0: softmax->topk(fix); 1: topk->softmax + norm_type=1, # 0: softmax; 1: sigmoid(fix) + routed_scaling_factor=1, + eps=float(1e-20)) + else: + topk_weights, topk_ids = select_experts( + hidden_states=local_hidden_states, + router_logits=router_logits, + top_k=self.config.num_experts_per_tok, + use_grouped_topk=True, + renormalize=self.config.norm_topk_prob, + topk_group=self.config.topk_group, + num_expert_group=self.config.n_group, + custom_routing_function=None, + scoring_func=self.config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias, + ) + + topk_weights = topk_weights.to(hidden_states.dtype) + # this is a naive implementation for experts load balance so as + # to avoid accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. 
+ if enable_force_load_balance: + topk_ids = torch.randint_like(topk_ids, 0, + self.config.n_routed_experts) + + return topk_weights, topk_ids, local_hidden_states, chunked_hidden_states_sizes + + def _forward_op_shared_experts(self, hidden_states): + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + + return shared_output + + def _forward_op_grouped_mlp(self, dispatched_input, tokens_per_expert): + from vllm_ascend.ops.fused_moe import apply_mlp + return apply_mlp(dispatched_input, self.experts.w13_weight, + self.experts.w2_weight, tokens_per_expert) + + def _forward_combine_comm(self, hidden_states, microbatch_id, num_tokens, + chunked_hidden_states_sizes): + token_dispatcher = self.experts.token_dispatchers[microbatch_id] + final_hidden_states, _ = token_dispatcher.token_unpermutation( + hidden_states) + if hasattr(self, 'routed_scaling_factor'): + final_hidden_states = final_hidden_states * self.routed_scaling_factor + + if self.tp_size > 1: + final_hidden_states = gather_from_sequence_parallel_region( + final_hidden_states, self.tp_group, + chunked_hidden_states_sizes) + if num_tokens < self.tp_size: + final_hidden_states = final_hidden_states[:num_tokens] + + if self.shared_experts is not None: + final_hidden_states = final_hidden_states + token_dispatcher.cached_shared_expert_output + token_dispatcher.cached_shared_expert_output.untyped_storage( + ).resize_(0) + token_dispatcher.cached_shared_expert_output = None + + final_hidden_states = final_hidden_states.view(num_tokens, -1) + + return final_hidden_states + class CustomDeepseekDBODecoderLayer(CustomDeepseekV2DecoderLayer): @@ -587,6 +739,142 @@ def _forward_ms_layer( context.after_comm_event.record() return hidden_states, residual + # ----------------------------------------- TBO-related -------------------------------------------- + def _forward_ms_layer_alltoallv_finegrained( + self, + positions: List[torch.Tensor], + hidden_states: List[torch.Tensor], + residual: List[torch.Tensor], + attn_metadata: List[AttentionMetadata], + kv_cache: Optional[torch.Tensor] = None, + is_prefill: bool = False, + ) -> tuple[List[torch.Tensor], List[torch.Tensor]]: + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) + assert layer_index >= 0 and ms_metadata is not None + num_micro_batchs = ms_metadata.ms_config.num_micro_batches + assert isinstance(self.mlp, CustomDeepseekDBOMoE) + assert len(positions) == num_micro_batchs + assert len(hidden_states) == num_micro_batchs + assert residual is not None + assert attn_metadata is not None + num_tokens = [None] * num_micro_batchs + hidden_dims = [None] * num_micro_batchs + topk_weights, topk_ids = [None] * num_micro_batchs, [ + None + ] * num_micro_batchs + tokens_per_expert = [None] * num_micro_batchs + dispatched_input = [None] * num_micro_batchs + router_expert_output = [None] * num_micro_batchs + chunked_hidden_states_sizes = [None] * num_micro_batchs + token_dispatchers = self.mlp.experts.token_dispatchers + + def discard_tensor(tensor): + if isinstance(tensor, torch.Tensor): + tensor = [tensor] + for t in tensor: + t.untyped_storage().resize_(0) + + # block 1 : attention + # block 2 : Router Gating + # block 3 : Token DisPatch + # the attn computation of microbatch 1 can be overlapped with the moe + # communication in the previous layer, and the attn computation of microbatch 2 + # can be overlapped with the attn communication of microbatch 1 + for i in range(num_micro_batchs): + # wait last layer moe finishing communication 
+ + forward_context = get_forward_context() + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) + ms_metadata.try_wait_event(layer_index - 1, i, + MSEventKey.FFN_AR_FINISH) + forward_context.attn_metadata = attn_metadata[i] + + # input layernorm + hidden_states[i], residual[ + i] = self._forward_ms_op_input_layernorm( + hidden_states[i], residual[i]) + # attention and tp allreduce + hidden_states[i], residual[i] = self._forward_ms_op_attn( + positions[i], hidden_states[i], residual[i], kv_cache, + attn_metadata[i]) + # post attention layer norm + hidden_states[i], residual[ + i] = self._forward_ms_op_post_attn_layernorm( + hidden_states[i], residual[i]) + num_tokens[i], hidden_dims[i] = hidden_states[i].shape + # If TP is enabled, hidden_states will be chunked. + topk_weights[i], topk_ids[i], dispatched_input[ + i], chunked_hidden_states_sizes[ + i] = self.mlp._forward_op_gating(hidden_states[i], + attn_metadata[i]) + token_dispatchers[i].preprocess_and_permtute1( + dispatched_input[i], + topk_weights[i], + topk_ids[i], + self.mlp.shared_experts, + shared_experts_input=hidden_states[i] + if self.mlp.n_shared_experts else None) + # Launch DisPatch Comm in a New Stream. + dispatch_context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_BEFORE_COMM], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_AFTER_COMM], + ) + dispatch_context.before_comm_event.record() + # print_with_sync(f'begin token dispatch{i}...', torch.distributed.get_rank()) + with torch.npu.stream(dispatch_context.comm_stream): + dispatch_context.comm_stream.wait_event( + dispatch_context.before_comm_event) + token_dispatchers[i].dispatch_alltoall() + dispatched_input[i], tokens_per_expert[i] = token_dispatchers[ + i].permute2() + dispatch_context.after_comm_event.record() + + if self.mlp.n_shared_experts and self.tp_size > 1: + token_dispatchers[ + i].cached_shared_expert_output = tensor_model_parallel_all_reduce( + token_dispatchers[i].cached_shared_expert_output) + ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_SE_COMM_FINISH].record() + + # print_with_sync('begin experts...', torch.distributed.get_rank()) + # block 4 : Router Experts Computation + # block 5 : Token Combine Communication + for i in range(num_micro_batchs): + + ms_metadata.try_wait_event(layer_index, i, + MSEventKey.MOE_AFTER_COMM) + discard_tensor(hidden_states[i]) + + router_expert_output[i] = self.mlp._forward_op_grouped_mlp( + dispatched_input[i], tokens_per_expert[i]) + discard_tensor(dispatched_input[i]) + + # Launch Combine Comm in a New Stream. 
+ combine_context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.FFN_COM_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.FFN_AR_FINISH], + ) + combine_context.before_comm_event.record() + ms_metadata.try_wait_event(layer_index, i, + MSEventKey.MOE_SE_COMM_FINISH) + with torch.npu.stream(combine_context.comm_stream): + combine_context.comm_stream.wait_event( + combine_context.before_comm_event) + hidden_states[i] = self.mlp._forward_combine_comm( + router_expert_output[i], i, num_tokens[i], + chunked_hidden_states_sizes[i]) + combine_context.after_comm_event.record() + + return hidden_states, residual + # should split ops in Decoder Layer def _forward_ms_op_input_layernorm( self, @@ -721,7 +1009,6 @@ def forward( if VLLM_ASCEND_ENABLE_DBO and not graph_enable and self.can_run_ms() else self.end_layer - self.start_layer) - moe_start_layer = self.start_layer + num_normal_layers for i in range(self.start_layer, min(moe_start_layer, self.end_layer)): layer = self.layers[i] @@ -768,13 +1055,17 @@ def _forward_ms_layers(self, if moe_start_layer == self.end_layer: return hidden_states, residual + fused_moe_state = get_forward_context().fused_moe_state attn_metadata, [positions, hidden_states, residual] = self.ms_pre_layer( [positions, hidden_states, residual], ) # the rest layers for i in range(moe_start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer._forward_ms_layer( + ms_layer_forward_func = layer._forward_ms_layer + if fused_moe_state == FusedMoEState.All2AllSeq: + ms_layer_forward_func = layer._forward_ms_layer_alltoallv_finegrained + hidden_states, residual = ms_layer_forward_func( positions=positions, hidden_states=hidden_states, residual=residual, diff --git a/vllm_ascend/models/qwen3_dbo.py b/vllm_ascend/models/qwen3_dbo.py new file mode 100644 index 000000000..fa87fe81f --- /dev/null +++ b/vllm_ascend/models/qwen3_dbo.py @@ -0,0 +1,552 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
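+
+# Qwen3-MoE with dual-batch overlap (DBO): the decoder layers below split each
+# batch into two micro-batches and overlap the attention/gating compute of one
+# micro-batch with the MoE all-to-all dispatch/combine communication of the
+# other on a separate stream.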
+ +# """Inference-only Qwen3 model.""" +from types import SimpleNamespace +from typing import List, Optional, Union + +import torch +import torch_npu +import vllm.model_executor.models.qwen3_moe as qwen3 +from torch import nn +from transformers import PretrainedConfig +from vllm.attention import AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + get_tp_group) +from vllm.forward_context import get_forward_context, set_forward_context +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.qwen3_moe import (Qwen3MoeDecoderLayer, + Qwen3MoeForCausalLM, + Qwen3MoeModel) +from vllm.model_executor.models.utils import ( + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.sequence import IntermediateTensors + +import vllm_ascend.envs as envs_ascend +from vllm_ascend.distributed.tensor_parallel import \ + gather_from_sequence_parallel_region +from vllm_ascend.multistream.base import MSEventKey +from vllm_ascend.multistream.context import ( + advance_step_multistream_layer_context, get_multistream_layer_context) +from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer, + MultiStreamPreTransformerLayer) +from vllm_ascend.multistream.metadata import (MultiStreamConfig, + MultiStreamStepMetadata, + make_multistream_metadata_ds) +from vllm_ascend.ops.fused_moe import (AscendSparseMoeBlock, apply_mlp, + select_experts) + +VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO + + +class Qwen3MoeDecoderLayerDBO(Qwen3MoeDecoderLayer): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super(Qwen3MoeDecoderLayerDBO, self).__init__(config, cache_config, + quant_config, prefix) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tp_group().rank_in_group + self.tp_group = get_tp_group().device_group + self.dummy_vllm_config = SimpleNamespace( + parallel_config=SimpleNamespace(data_parallel_size=1, ), + compilation_config=SimpleNamespace(static_forward_context=None, ), + other_setting="value", + ) + self.config = config + + def forward(self, *args, **kwargs): + return super().forward(*args, **kwargs) + + # should split ops in Decoder Layer + def _forward_ms_op_input_layernorm( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + return hidden_states, residual + + def _forward_ms_op_attn( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + self.dummy_vllm_config.compilation_config.static_forward_context = ( + get_forward_context().no_compile_layers) + with set_forward_context(attn_metadata, self.dummy_vllm_config): + 
hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+            )
+        if hidden_states.dtype == torch.float16:
+            # Fix FP16 overflow:
+            # we scale both hidden_states and residual before rmsnorm,
+            # and the rmsnorm result is not affected by the scale.
+            hidden_states *= 1.0 / self.routed_scaling_factor
+            if self.layer_idx == 0:
+                # The residual is shared by all layers; we only scale it on
+                # the first layer.
+                residual *= 1.0 / self.routed_scaling_factor
+        return hidden_states, residual
+
+    def _forward_ms_op_post_attn_layernorm(
+        self,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+    ):
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        return hidden_states, residual
+
+    def _forward_op_gating(
+        self,
+        hidden_states: torch.Tensor,
+        attn_metadata: Optional[AttentionMetadata] = None,
+    ) -> torch.Tensor:
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # when profile runs, force experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        enable_force_load_balance = get_forward_context().in_profile_run
+
+        num_tokens, hidden_dim = hidden_states.shape
+
+        if self.tp_size > 1:
+            # pass
+            num_tokens, hidden_size = hidden_states.shape
+            if num_tokens < self.tp_size:
+                hidden_states = nn.functional.pad(
+                    hidden_states, (0, 0, 0, self.tp_size - num_tokens))
+            chunk_hidden_states = torch.tensor_split(hidden_states,
+                                                     self.tp_size,
+                                                     dim=0)
+            chunked_hidden_states_sizes = [
+                x.shape[0] for x in chunk_hidden_states
+            ]
+            local_hidden_states = chunk_hidden_states[self.tp_rank]
+        else:
+            local_hidden_states = hidden_states
+            chunked_hidden_states_sizes = None
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.mlp.gate(local_hidden_states)
+
+        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+        mlp_config = self.config
+        if mlp_config.num_experts == 256:
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                router_logits,
+                k=mlp_config.num_experts_per_tok,  # topk is currently set to 8
+                bias=self.mlp.gate.e_score_correction_bias,
+                k_group=mlp_config.topk_group,  # fix: 4
+                group_count=mlp_config.n_group,  # fix 8
+                group_select_mode=1,  # 0: max in group; 1: topk2.sum(fix)
+                renorm=0,  # 0: softmax->topk(fix); 1: topk->softmax
+                norm_type=1,  # 0: softmax; 1: sigmoid(fix)
+                routed_scaling_factor=1,
+                eps=float(1e-20),
+            )
+        else:
+            topk_weights, topk_ids = select_experts(
+                hidden_states=local_hidden_states,
+                router_logits=router_logits,
+                top_k=mlp_config.num_experts_per_tok,
+                use_grouped_topk=False,
+                renormalize=mlp_config.norm_topk_prob,
+                topk_group=getattr(mlp_config, "topk_group", None),
+                num_expert_group=getattr(mlp_config, "n_group", None),
+                custom_routing_function=None,
+                scoring_func=getattr(mlp_config, "scoring_func", "softmax"),
+                e_score_correction_bias=getattr(self.mlp.gate,
+                                                "e_score_correction_bias",
+                                                None),
+            )
+
+        topk_weights = topk_weights.to(hidden_states.dtype)
+        # this is a naive implementation for experts load balance so as
+        # to avoid accumulating too many tokens on a single rank.
+        # currently it is only activated when doing profile runs.
+ if enable_force_load_balance: + topk_ids = torch.randint_like(topk_ids, 0, self.config.num_experts) + + return topk_weights, topk_ids, local_hidden_states, chunked_hidden_states_sizes + + def _forward_op_grouped_mlp(self, dispatched_input, tokens_per_expert): + return apply_mlp( + dispatched_input, + self.mlp.experts.w13_weight, + self.mlp.experts.w2_weight, + tokens_per_expert, + ) + + def _forward_combine_comm(self, hidden_states, microbatch_id, num_tokens, + chunked_hidden_states_sizes): + token_dispatcher = self.mlp.experts.token_dispatchers[microbatch_id] + final_hidden_states, _ = token_dispatcher.token_unpermutation( + hidden_states) + if hasattr(self.mlp, "routed_scaling_factor"): + final_hidden_states = final_hidden_states * self.mlp.routed_scaling_factor + + if self.tp_size > 1: + final_hidden_states = gather_from_sequence_parallel_region( + final_hidden_states, self.tp_group, + chunked_hidden_states_sizes) + if num_tokens < self.tp_size: + final_hidden_states = final_hidden_states[:num_tokens] + + if hasattr(self.mlp, "shared_experts"): + final_hidden_states = ( + final_hidden_states + + token_dispatcher.cached_shared_expert_output) + token_dispatcher.cached_shared_expert_output.untyped_storage( + ).resize_(0) + token_dispatcher.cached_shared_expert_output = None + + final_hidden_states = final_hidden_states.view(num_tokens, -1) + + return final_hidden_states + + def _forward_ms_layer_alltoallv_finegrained( + self, + positions: List[torch.Tensor], + hidden_states: List[torch.Tensor], + residual: List[torch.Tensor], + attn_metadata: List[AttentionMetadata], + kv_cache: Optional[torch.Tensor] = None, + ): + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) + assert layer_index >= 0 and ms_metadata is not None + num_micro_batchs = ms_metadata.ms_config.num_micro_batches + assert len(positions) == num_micro_batchs + assert len(hidden_states) == num_micro_batchs + assert residual is not None + assert attn_metadata is not None + num_tokens = [None] * num_micro_batchs + hidden_dims = [None] * num_micro_batchs + topk_weights, topk_ids = [None] * num_micro_batchs, [ + None + ] * num_micro_batchs + tokens_per_expert = [None] * num_micro_batchs + dispatched_input = [None] * num_micro_batchs + router_expert_output = [None] * num_micro_batchs + chunked_hidden_states_sizes = [None] * num_micro_batchs + token_dispatchers = self.mlp.experts.token_dispatchers + + def discard_tensor(tensor): + if isinstance(tensor, torch.Tensor): + tensor = [tensor] + for t in tensor: + t.untyped_storage().resize_(0) + + # block 1 : attention + # block 2 : Router Gating + # block 3 : Token DisPatch + # the attn computation of microbatch 1 can be overlapped with the moe + # communication in the previous layer, and the attn computation of microbatch 2 + # can be overlapped with the attn communication of microbatch 1 + for i in range(num_micro_batchs): + forward_context = get_forward_context() + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) + ms_metadata.try_wait_event(layer_index - 1, i, + MSEventKey.FFN_AR_FINISH) + forward_context.attn_metadata = attn_metadata[i] + + # input layernorm + hidden_states[i], residual[ + i] = self._forward_ms_op_input_layernorm( + hidden_states[i], residual[i]) + # attention and tp allreduce + hidden_states[i], residual[i] = self._forward_ms_op_attn( + positions[i], hidden_states[i], residual[i], kv_cache, + attn_metadata[i]) + # post attention layer norm + hidden_states[i], residual[ + i] = 
self._forward_ms_op_post_attn_layernorm( + hidden_states[i], residual[i]) + num_tokens[i], hidden_dims[i] = hidden_states[i].shape + # If TP is enabled, hidden_states will be chunked. + ( + topk_weights[i], + topk_ids[i], + dispatched_input[i], + chunked_hidden_states_sizes[i], + ) = self._forward_op_gating(hidden_states[i], attn_metadata[i]) + token_dispatchers[i].preprocess_and_permtute1( + dispatched_input[i], + topk_weights[i], + topk_ids[i], + shared_experts=None, + shared_experts_input=None, + ) + # Launch DisPatch Comm in a New Stream. + dispatch_context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_BEFORE_COMM], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.MOE_AFTER_COMM], + ) + dispatch_context.before_comm_event.record() + # print_with_sync(f'begin token dispatch{i}...', torch.distributed.get_rank()) + with torch.npu.stream(dispatch_context.comm_stream): + dispatch_context.comm_stream.wait_event( + dispatch_context.before_comm_event) + token_dispatchers[i].dispatch_alltoall() + dispatched_input[i], tokens_per_expert[i] = token_dispatchers[ + i].permute2() + dispatch_context.after_comm_event.record() + + # print_with_sync('begin experts...', torch.distributed.get_rank()) + # block 4 : Router Experts Computation + # block 5 : Token Combine Communication + for i in range(num_micro_batchs): + ms_metadata.try_wait_event(layer_index, i, + MSEventKey.MOE_AFTER_COMM) + discard_tensor(hidden_states[i]) + router_expert_output[i] = self._forward_op_grouped_mlp( + dispatched_input[i], tokens_per_expert[i]) + discard_tensor(dispatched_input[i]) + + # Launch Combine Comm in a New Stream. + combine_context = MultiStreamStepMetadata( + comm_stream=ms_metadata.communicate_stream, + before_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.FFN_COM_FINISH], + after_comm_event=ms_metadata.ms_events[layer_index][i][ + MSEventKey.FFN_AR_FINISH], + ) + combine_context.before_comm_event.record() + ms_metadata.try_wait_event(layer_index, i, + MSEventKey.MOE_SE_COMM_FINISH) + with torch.npu.stream(combine_context.comm_stream): + combine_context.comm_stream.wait_event( + combine_context.before_comm_event) + hidden_states[i] = self._forward_combine_comm( + router_expert_output[i], + i, + num_tokens[i], + chunked_hidden_states_sizes[i], + ) + ms_metadata.ms_events[layer_index][i][ + MSEventKey. 
+ FFN_AR_FINISH] = combine_context.comm_stream.record_event( + ) + + return hidden_states, residual + + +@support_torch_compile +class CustomQwen3DBOMoEModel(Qwen3MoeModel): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + prefix=f"{prefix}.embed_tokens") + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Qwen3MoeDecoderLayerDBO( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size) + + # dbo related members + if VLLM_ASCEND_ENABLE_DBO: + self.use_mla = False + self.multistream_config = MultiStreamConfig() + multistream_metadata = make_multistream_metadata_ds( + start_layer=self.start_layer, + end_layer=self.end_layer, + causal_lm=getattr(config, "causal_lm", True), + multistream_config=self.multistream_config, + ) + self.ms_pre_layer = MultiStreamPreTransformerLayer( + multistream_metadata) + self.ms_post_layer = MultiStreamPostTransformerLayer( + multistream_metadata) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + num_normal_layers = (0 if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms() + else self.end_layer - self.start_layer) + + moe_start_layer = self.start_layer + num_normal_layers + for i in range(self.start_layer, min(moe_start_layer, self.end_layer)): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual) + + if moe_start_layer < self.end_layer: + # if we enable multistream/dbo, process sparse layers here + hidden_states, residual = self._forward_ms_layers( + positions=positions, + hidden_states=hidden_states, + residual=residual, + moe_start_layer=moe_start_layer, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def can_run_ms(self): + attn_metadata = get_forward_context().attn_metadata + # enable prefill overlap + with_prefill = get_forward_context().with_prefill + if (attn_metadata is None or not with_prefill + or not attn_metadata.enable_dbo_across_dp): + return False + + return True + + def _forward_ms_layers( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor, + moe_start_layer: int, + kv_caches: Optional[List[torch.Tensor]] = None, + ): + + if moe_start_layer == self.end_layer: + return hidden_states, residual + + 
attn_metadata, [positions, hidden_states, + residual] = self.ms_pre_layer( + [positions, hidden_states, residual], ) + num_micro_batch = len(attn_metadata) + # the rest layers + for i in range(moe_start_layer, self.end_layer): + layer = self.layers[i] + ms_layer_forward_func = layer._forward_ms_layer_alltoallv_finegrained + # print("get_called......") + hidden_states, residual = ms_layer_forward_func( + positions=positions, + hidden_states=hidden_states, + residual=residual, + attn_metadata=attn_metadata, + ) + advance_step_multistream_layer_context() + + layer_index, ms_metadata, attn_metadata = get_multistream_layer_context( + ) + for i in range(num_micro_batch): + ms_metadata.try_wait_event(layer_index - 1, i, + MSEventKey.FFN_AR_FINISH) + + [hidden_states, + residual] = self.ms_post_layer([hidden_states, residual], ) + return hidden_states, residual + + +class CustomQwen3MoeForCausalLMDBO(Qwen3MoeForCausalLM): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + } + qwen3.Qwen3MoeSparseMoeBlock = AscendSparseMoeBlock + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = CustomQwen3DBOMoEModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward(self, *args, **kwargs): + if "graph_enable" in kwargs: + kwargs.pop("graph_enable") + return super().forward(*args, **kwargs) diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index 8ff1b52a7..485e5ca92 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -16,8 +16,11 @@ # Adapted from vllm/model_executor/models/qwen3_moe.py # This file is a part of the vllm-ascend project. 
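For reference, a minimal self-contained sketch (not from the patch) of the module-attribute rebinding used above and in the qwen3_moe.py hunk below (`qwen3.Qwen3MoeSparseMoeBlock = AscendSparseMoeBlock`): presumably because the upstream decoder layer looks the sparse-MoE block up as a module-level global at construction time, rebinding that attribute swaps in the Ascend block for every layer built afterwards. The module and class names in the sketch are hypothetical stand-ins.

import types

# hypothetical stand-in for vllm.model_executor.models.qwen3_moe
fake_qwen3 = types.ModuleType("fake_qwen3")

class OriginalSparseMoeBlock:      # stand-in for the upstream Qwen3MoeSparseMoeBlock
    pass

class PatchedSparseMoeBlock:       # stand-in for AscendSparseMoeBlock
    pass

fake_qwen3.SparseMoeBlock = OriginalSparseMoeBlock

def build_block():
    # mirrors code that resolves the class through the module attribute at call time
    return fake_qwen3.SparseMoeBlock()

assert isinstance(build_block(), OriginalSparseMoeBlock)
fake_qwen3.SparseMoeBlock = PatchedSparseMoeBlock   # the rebinding done by the patch
assert isinstance(build_block(), PatchedSparseMoeBlock)

Note that only constructions happening after the rebinding pick up the replacement, which is why the assignment sits at class-definition (import) time.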
+import vllm.model_executor.models.qwen3_moe as qwen3 from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM +from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock + class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): packed_modules_mapping = { @@ -33,3 +36,4 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM): "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], } + qwen3.Qwen3MoeSparseMoeBlock = AscendSparseMoeBlock diff --git a/vllm_ascend/multistream/ms_split.py b/vllm_ascend/multistream/ms_split.py index b80446e32..9ceae1053 100644 --- a/vllm_ascend/multistream/ms_split.py +++ b/vllm_ascend/multistream/ms_split.py @@ -4,7 +4,8 @@ import numpy as np import torch -from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.attention_v1 import (AscendAttentionState, + AscendMetadata) from .base import MSAttentionMetadataSplitConfig @@ -246,3 +247,115 @@ def model_input_split_v1_mla_attn( enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, ) return [attention_metadata_pre, attention_metadata_post] + + +def model_input_split_v1_attn( + attn_metadata: AscendMetadata, + _metadata_cls, + ms_split_config: MSAttentionMetadataSplitConfig, +) -> List[Any]: + assert 0 < ms_split_config.num_micro_batches < 3 + if attn_metadata is None: + return [attn_metadata] + [token_index, + seq_index] = compute_split_seq_index(attn_metadata.query_lens, + attn_metadata.attn_state, + attn_metadata.num_actual_tokens) + if token_index == 0 or seq_index == 0 or seq_index == len( + attn_metadata.query_lens): + return [attn_metadata] + + # split attn metadata + + [block_table_pre, + block_table_post] = split_attn_tensor_type(attn_metadata.block_tables, + seq_index) + + query_start_loc_pre = query_start_loc_post = None + if attn_metadata.query_start_loc is not None: + query_start_loc_pre = attn_metadata.query_start_loc[:seq_index + 1] + query_start_loc_post = deepcopy( + attn_metadata.query_start_loc[seq_index:] + ) - attn_metadata.query_start_loc[seq_index] + + [query_lens_pre, + query_lens_post] = split_attn_tensor_type(attn_metadata.query_lens, + seq_index) + [seq_lens_pre, + seq_lens_post] = split_attn_tensor_type(attn_metadata.seq_lens, seq_index) + + max_query_len_pre = max_query_len_post = None + if attn_metadata.max_query_len is not None: + max_query_len_pre, max_query_len_post = max(query_lens_pre), max( + query_lens_post) + + [slot_mapping_pre, + slot_mapping_post] = split_attn_tensor_type(attn_metadata.slot_mapping, + token_index) + + is_only_prefill_pre = is_only_prefill_post = attn_metadata.is_only_prefill + has_prefill_pre, _ = torch.any(query_lens_pre > 1).item(), torch.any( + query_lens_post > 1).item() + + if not attn_metadata.is_only_prefill: + is_only_prefill_post = torch.all(query_lens_post > 1).item() + + if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit: + # the attn_mla kernel in torch npu only accept 128*128 attn mask + attn_mask_pre = attn_mask_post = attn_metadata.attn_mask + attn_state_pre = attn_state_post = attn_metadata.attn_state + elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: + # should be none in decode only state + attn_mask_pre = attn_mask_post = attn_metadata.attn_mask + attn_state_pre = attn_state_post = AscendAttentionState.DecodeOnly # type: ignore + else: + # chunked prefill + assert attn_metadata.attn_mask is not None + if has_prefill_pre: + attn_state_pre = attn_state_post = 
AscendAttentionState.ChunkedPrefill # type: ignore + attn_mask_pre = attn_metadata.attn_mask[:token_index, :max( + seq_lens_pre)].contiguous() + attn_state_post = AscendAttentionState.ChunkedPrefill # type: ignore + attn_mask_post = attn_metadata.attn_mask[ + token_index:, :max(seq_lens_post)].contiguous() + else: + attn_state_pre = AscendAttentionState.DecodeOnly # type: ignore + attn_mask_pre = None + attn_state_post = AscendAttentionState.ChunkedPrefill # type: ignore + attn_mask_post = attn_metadata.attn_mask[ + token_index:, :max(seq_lens_post)].contiguous() + + # construct metadata + attention_metadata_pre = _metadata_cls( + num_actual_tokens=token_index, + block_tables=block_table_pre, + query_start_loc=query_start_loc_pre, + query_lens=query_lens_pre, + seq_lens=seq_lens_pre, + seq_lens_list=seq_lens_pre.tolist(), + max_query_len=max_query_len_pre, + slot_mapping=slot_mapping_pre, + is_only_prefill=is_only_prefill_pre, + attn_state=attn_state_pre, + attn_mask=attn_mask_pre, + num_input_tokens=token_index, + enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, + ) + + attention_metadata_post = _metadata_cls( + num_actual_tokens=attn_metadata.num_actual_tokens - token_index, + block_tables=block_table_post, + query_start_loc=query_start_loc_post, + query_lens=query_lens_post, + seq_lens=seq_lens_post, + seq_lens_list=seq_lens_post.tolist(), + max_query_len=max_query_len_post, + slot_mapping=slot_mapping_post, + is_only_prefill=is_only_prefill_post, + attn_state=attn_state_post, + attn_mask=attn_mask_post, + num_input_tokens=attn_metadata.num_input_tokens - token_index, + enable_dbo_across_dp=attn_metadata.enable_dbo_across_dp, + ) + + return [attention_metadata_pre, attention_metadata_post] diff --git a/vllm_ascend/ops/comm_utils.py b/vllm_ascend/ops/comm_utils.py new file mode 100644 index 000000000..6c4377330 --- /dev/null +++ b/vllm_ascend/ops/comm_utils.py @@ -0,0 +1,127 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
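To make the split above easier to follow, here is a minimal CPU-only sketch (not from the patch) of what `model_input_split_v1_attn` does to the flattened metadata: sequence-indexed fields are cut at `seq_index`, token-indexed fields at `token_index`, and `query_start_loc` of the second micro-batch is rebased to start at zero. In the real code `token_index`/`seq_index` come from `compute_split_seq_index`; here the concrete numbers are chosen by hand for illustration.

import torch

query_lens = torch.tensor([3, 2, 4, 1])            # tokens per request
query_start_loc = torch.tensor([0, 3, 5, 9, 10])   # cumulative offsets, len = num_seqs + 1
slot_mapping = torch.arange(10)                    # one entry per token
seq_index = 2                                      # split after the first two requests
token_index = int(query_start_loc[seq_index])      # 5 tokens go to the first micro-batch

query_lens_pre, query_lens_post = query_lens[:seq_index], query_lens[seq_index:]
slot_mapping_pre, slot_mapping_post = slot_mapping[:token_index], slot_mapping[token_index:]
query_start_loc_pre = query_start_loc[:seq_index + 1]
# rebase the second half so its offsets start at zero, as in the function above
query_start_loc_post = query_start_loc[seq_index:].clone() - query_start_loc[seq_index]

assert query_start_loc_post[0] == 0
assert query_start_loc_post[-1] == len(slot_mapping_post)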
+import torch +import torch.distributed +import torch.distributed as dist +import torch_npu + +COMM_STREAM = None + + +def async_all_gather(input_, + group, + event=None, + is_use_get_global_memory_buffer=False): + world_size = torch.distributed.get_world_size(group) + dim_size = list(input_.size()) + new_dim_size = dim_size[0] * world_size + dim_size[0] = new_dim_size + + ag_out = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + if event: + # multi stream wait event + global COMM_STREAM + if COMM_STREAM is None: + COMM_STREAM = torch_npu.npu.Stream( + device=torch.npu.current_device()) + with torch_npu.npu.stream(COMM_STREAM): + event.wait() + handle = torch.distributed._all_gather_base(ag_out, + input_.contiguous(), + group=group, + async_op=True) + else: + handle = torch.distributed._all_gather_base(ag_out, + input_.contiguous(), + group=group, + async_op=True) + return input_, ag_out, handle + + +def async_reduce_scatter(input_, + group, + event=None, + stream=None, + is_use_get_global_memory_buffer=False): + world_size = dist.get_world_size(group) + dim_size = list(input_.size()) + dim_size[0] = dim_size[0] // world_size + + rs_out = torch.empty(dim_size, + dtype=input_.dtype, + device=torch.npu.current_device()) + if event or stream: + # multi stream wait event + global COMM_STREAM + if COMM_STREAM is None: + COMM_STREAM = torch_npu.npu.Stream( + device=torch.npu.current_device()) + with torch_npu.npu.stream(COMM_STREAM): + if event: + event.wait() + if stream: + torch.npu.current_stream().wait_stream(stream) + handle = torch.distributed.reduce_scatter_tensor( + rs_out, input_.contiguous(), group=group, async_op=True) + else: + handle = torch.distributed.reduce_scatter_tensor(rs_out, + input_.contiguous(), + group=group, + async_op=True) + return input_, rs_out, handle + + +def async_all_to_all(input_, + output_split_sizes, + input_split_sizes, + group, + event=None): + if output_split_sizes is None: + # Equal split (all2all) + a2a_out = torch.empty_like(input_) + else: + # Unequal split (all2all-v) + a2a_out = input_.new_empty( + size=[sum(output_split_sizes)] + list(input_.size()[1:]), + dtype=input_.dtype, + device=torch.npu.current_device(), + ) + + if event: + # multi stream wait event + global COMM_STREAM + if COMM_STREAM is None: + COMM_STREAM = torch_npu.npu.Stream( + device=torch.npu.current_device()) + with torch_npu.npu.stream(COMM_STREAM): + event.wait() + handle = dist.all_to_all_single( + a2a_out, + input_.contiguous(), + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=True) + else: + handle = dist.all_to_all_single(a2a_out, + input_.contiguous(), + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=True) + return input_, a2a_out, handle diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 14ff9ef70..3fa9c8be7 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -17,12 +17,14 @@ import math import os -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union import torch import torch.distributed as dist import torch_npu from torch import nn +from transformers import PretrainedConfig +from vllm.attention import AttentionMetadata from vllm.config import get_current_vllm_config from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -33,6 +35,7 @@ from 
vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod, determine_expert_map) +from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.quantization.base_config import \ QuantizationConfig @@ -41,16 +44,23 @@ from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer +from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( + MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, get_ascend_soc_version, npu_stream_switch, npu_wait_tensor) -MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER +VLLM_ASCEND_MOE_ALL2ALL_BUFFER: bool = envs_ascend.VLLM_ASCEND_MOE_ALL2ALL_BUFFER -def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, - max_row_per_ep_rank: int, num_tokens: int, - top_k: int) -> tuple[torch.Tensor, torch.Tensor]: +def process_topk_ids( + topk_ids: torch.Tensor, + expert_num: int, + ep_size: int, + max_row_per_ep_rank: int, + num_tokens: int, + top_k: int, +) -> tuple[torch.Tensor, torch.Tensor]: original_total_elements = num_tokens * top_k device = topk_ids.device original_dtype = topk_ids.dtype @@ -77,8 +87,10 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, experts_per_ep_rank_val).to(original_dtype) indices_arange = torch.arange(topk_ids.shape[0], device=device) - is_new_segment = torch.cat((torch.tensor([True], device=device), - assigned_ep_rank[1:] != assigned_ep_rank[:-1])) + is_new_segment = torch.cat(( + torch.tensor([True], device=device), + assigned_ep_rank[1:] != assigned_ep_rank[:-1], + )) temp_start_markers = torch.full_like(indices_arange, -1, dtype=indices_arange.dtype) @@ -89,15 +101,18 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long) indices_in_rec_cond_list_for_all = cumsum_kept - 1 unpad_indices = torch.where( - is_kept_mask, indices_in_rec_cond_list_for_all, - torch.tensor(-1, device=device, dtype=torch.long)) + is_kept_mask, + indices_in_rec_cond_list_for_all, + torch.tensor(-1, device=device, dtype=torch.long), + ) output_len = ep_size * max_row_per_ep_rank topk_ids_pad = torch.full((output_len, ), expert_num, dtype=original_dtype, device=device) if topk_ids.shape[0] > 0: - all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx + all_destination_indices = (assigned_ep_rank * max_row_per_ep_rank + + token_intra_ep_rank_idx) temp_pad_buffer = torch.full((output_len + 1, ), expert_num, dtype=original_dtype, @@ -133,12 +148,13 @@ def fused_experts_with_mc2( # NOTE: `global_bs` should be equal to `max_num_tokens_across_dp` * `ep_world_size`, # and `max_num_tokens_across_dp` has been split into `tp_world_size` parts before. 
- global_bs = math.ceil(get_forward_context().max_tokens_across_dp / - tp_world_size) * ep_world_size + global_bs = ( + math.ceil(get_forward_context().max_tokens_across_dp / tp_world_size) * + ep_world_size) # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine - need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 - or is_torchair) + need_extra_args = get_ascend_soc_version( + ) == AscendSocVersion.A3 or is_torchair # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 @@ -256,11 +272,13 @@ def fused_experts_with_mc2( return hidden_states, shared_hidden_states -def apply_mlp(hidden_states_wrapper: List[torch.Tensor], - w1: torch.Tensor, - w2: torch.Tensor, - group_list: torch.Tensor, - group_list_type: int = 1) -> torch.Tensor: +def apply_mlp( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1, +) -> torch.Tensor: """ apply MLP: gate_up_proj -> swiglu -> down_proj @@ -282,9 +300,6 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor], hidden_states: output hidden states after MLP. """ - assert len(hidden_states_wrapper) == 1 - hidden_states = hidden_states_wrapper.pop() - w1 = w1.transpose(1, 2) hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], @@ -312,6 +327,8 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor], return hidden_states +# currently expert parallelism implemented with all2all +# is under-optimized. def fused_experts_with_all2all( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -339,11 +356,13 @@ def fused_experts_with_all2all( dtype=torch.int32, device=device).view(top_k, -1).permute( 1, 0).contiguous()) - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) + hidden_states, expanded_row_idx, expanded_expert_idx = ( + torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens, + )) global_expert_tokens = torch.bincount(expanded_expert_idx, minlength=global_num_experts) @@ -373,16 +392,18 @@ def fused_experts_with_all2all( hidden_states = hidden_states[sorted_idx] else: row_idx_len = num_tokens * top_k - row_idx = torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=topk_weights.device).view( - top_k, -1).permute(1, 0).contiguous() - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) + row_idx = (torch.arange(0, + row_idx_len, + dtype=torch.int32, + device=topk_weights.device).view( + top_k, -1).permute(1, 0).contiguous()) + hidden_states, expanded_row_idx, expanded_expert_idx = ( + torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens, + )) expert_tokens = torch_npu.npu_moe_compute_expert_tokens( expanded_expert_idx, num_experts) @@ -474,45 +495,56 @@ def fused_experts_with_all2all_buffer( row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32, device=device).view(top_k, -1).permute(1, 0).contiguous()) - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) - - max_row_per_ep_rank = (-(-global_batch_size // 
ep_group.world_size) * - max_model_len // ep_group.world_size + - 1) * top_k * 2 + hidden_states, expanded_row_idx, expanded_expert_idx = ( + torch_npu.npu_moe_init_routing(hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens)) + + max_row_per_ep_rank = ( + (-(-global_batch_size // ep_group.world_size) * max_model_len * + get_dp_group().world_size // ep_group.world_size + 1) * top_k * 2) expert_idx_buffer_scatter, unpad_indices = process_topk_ids( - expanded_expert_idx, global_num_experts, ep_group.world_size, - max_row_per_ep_rank, num_tokens, top_k) + expanded_expert_idx, + global_num_experts, + ep_group.world_size, + max_row_per_ep_rank, + num_tokens, + top_k, + ) hidden_states_pad_idx = torch.zeros( expert_idx_buffer_scatter.shape, dtype=expert_idx_buffer_scatter.dtype, - device=expert_idx_buffer_scatter.device) + device=expert_idx_buffer_scatter.device, + ) non_pad_len = torch.sum( (expert_idx_buffer_scatter != global_num_experts).to(torch.int32)) - hidden_states_pad_idx[ - expert_idx_buffer_scatter != global_num_experts] = torch.arange( + hidden_states_pad_idx[expert_idx_buffer_scatter != global_num_experts] = ( + torch.arange( non_pad_len, dtype=expert_idx_buffer_scatter.dtype, - device=hidden_states.device) + device=hidden_states.device, + )) hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] expert_idx_buffer_gather = torch.empty_like( expert_idx_buffer_scatter, dtype=expert_idx_buffer_scatter.dtype, - device=expert_idx_buffer_scatter.device) + device=expert_idx_buffer_scatter.device, + ) hidden_states_buffer_gather = torch.empty_like( hidden_states_buffer_scatter, dtype=hidden_states_buffer_scatter.dtype, - device=hidden_states_buffer_scatter.device) + device=hidden_states_buffer_scatter.device, + ) dist.all_to_all_single(expert_idx_buffer_gather, expert_idx_buffer_scatter, group=ep_group.device_group) - dist.all_to_all_single(hidden_states_buffer_gather, - hidden_states_buffer_scatter, - group=ep_group.device_group) + dist.all_to_all_single( + hidden_states_buffer_gather, + hidden_states_buffer_scatter, + group=ep_group.device_group, + ) mask = expert_idx_buffer_gather != global_num_experts local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( global_num_experts // ep_group.world_size) @@ -526,10 +558,7 @@ def fused_experts_with_all2all_buffer( hidden_states = hidden_states[sorted_idx] group_list_type = 0 - hidden_states_wrapper = [hidden_states] - del hidden_states - - hidden_states = apply_mlp(hidden_states_wrapper, + hidden_states = apply_mlp(hidden_states, w1, w2, expert_tokens, @@ -540,21 +569,25 @@ def fused_experts_with_all2all_buffer( hidden_states_scatter = torch.zeros( (mask.shape[0], hidden_states.shape[1]), dtype=hidden_states.dtype, - device=hidden_states.device) + device=hidden_states.device, + ) hidden_states_scatter[mask] = hidden_states hidden_states_gatter = torch.empty_like( hidden_states_scatter, dtype=hidden_states_scatter.dtype, - device=hidden_states_scatter.device) + device=hidden_states_scatter.device, + ) dist.all_to_all_single(hidden_states_gatter, hidden_states_scatter, group=ep_group.device_group) hidden_states_gatter = hidden_states_gatter[ expert_idx_buffer_scatter != global_num_experts] if hidden_states_gatter.shape[0] != row_idx_len: - hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]), - dtype=hidden_states.dtype, - device=hidden_states.device) + hidden_states = torch.zeros( + (row_idx_len, hidden_states.shape[1]), + dtype=hidden_states.dtype, + 
device=hidden_states.device, + ) hidden_states[unpad_indices != -1] = hidden_states_gatter else: # TODO: Reorder device memory 2 times here, replace the current @@ -574,6 +607,24 @@ def fused_experts_with_all2all_buffer( return final_hidden_states +def fused_experts_with_all2allv( + token_dispatcher, + probs, + routing_map, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, +): + # Enable moe alltoallv, it's a balanced policy for precision and efficiency. + (share_experts_output, dispatched_input, + tokens_per_expert) = (token_dispatcher.token_permutation( + hidden_states, probs, routing_map)) + + expert_output = apply_mlp(dispatched_input, w1, w2, tokens_per_expert) + output, mlp_bias = token_dispatcher.token_unpermutation(expert_output) + return output + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -676,11 +727,13 @@ def fused_experts( dtype=torch.int32, device=device).view(top_k, -1).permute( 1, 0).contiguous()) - sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) + sorted_hidden_states, expanded_row_idx, expanded_expert_idx = ( + torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens, + )) expert_tokens = torch_npu.npu_moe_compute_expert_tokens( expanded_expert_idx, num_experts) @@ -723,16 +776,16 @@ def fused_experts( # This created multiple NaN and index_add_ will mix them up which harms accuracy # remove this mask and filter after it being fixed num_valid_tokens = mask.sum() - valid_token_mask = torch.arange( - 0, sorted_token_indices.shape[0], - device=device).unsqueeze(1) < num_valid_tokens + valid_token_mask = (torch.arange( + 0, sorted_token_indices.shape[0], device=device).unsqueeze(1) < + num_valid_tokens) valid_output = torch.where( valid_token_mask, weighted_down_out, torch.zeros_like(weighted_down_out)).to(dtype) final_hidden_states.index_add_(0, sorted_token_indices, valid_output) else: - scales = torch.ones_like( - topk_weights) if apply_router_weight_on_input else topk_weights + scales = (torch.ones_like(topk_weights) + if apply_router_weight_on_input else topk_weights) # TODO: Reorder device memory 2 times here, replace the current # implementation here when suitable operators become available. 
final_hidden_states = torch_npu.npu_moe_finalize_routing( @@ -757,8 +810,8 @@ def native_grouped_topk( num_expert_group = 0 if num_expert_group is None else num_expert_group num_token = topk_weights.shape[0] - grouped_weights = topk_weights.view(num_token, num_expert_group, - -1).max(dim=-1).values + grouped_weights = (topk_weights.view(num_token, num_expert_group, + -1).max(dim=-1).values) topk_group_indices = torch.topk(grouped_weights.to(torch.float32), k=topk_group, dim=-1, @@ -851,7 +904,8 @@ def select_experts( hidden_states=hidden_states, gating_output=router_logits, topk=top_k, - renormalize=renormalize) + renormalize=renormalize, + ) # Required by npu_moe_init_routing topk_ids = topk_ids.to(torch.int32) return topk_weights, topk_ids @@ -933,7 +987,8 @@ def apply( # out_flag=False, # todo new api; 第三个输出是否输出 # y2_flag=False, # old api; 第三个输出是否输出 routed_scaling_factor=1, - eps=float(1e-20)) + eps=float(1e-20), + ) else: topk_weights, topk_ids = select_experts( hidden_states=x, @@ -969,16 +1024,19 @@ def apply( moe_all_to_all_group_name=self.moe_all_to_all_group_name, shared_experts=shared_experts, is_torchair=self.torchair_graph_enabled, - mc2_mask=mc2_mask) + mc2_mask=mc2_mask, + ) elif fused_moe_state == FusedMoEState.AllGather: - return fused_experts(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map) - elif MOE_ALL2ALL_BUFFER: + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + ) + elif VLLM_ASCEND_MOE_ALL2ALL_BUFFER: return fused_experts_with_all2all_buffer( hidden_states=x, w1=layer.w13_weight, @@ -989,16 +1047,29 @@ def apply( max_model_len=self.max_model_len, global_batch_size=self.global_batch_size, expert_map=expert_map, - ep_group=get_ep_group()) + ep_group=get_ep_group(), + ) + elif fused_moe_state == FusedMoEState.All2AllSeq: + token_dispatcher = kwargs.get("token_dispatcher") + return fused_experts_with_all2allv( + token_dispatcher=token_dispatcher, + probs=topk_weights, + routing_map=topk_ids, + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + ) else: - return fused_experts_with_all2all(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - ep_group=get_ep_group()) + return fused_experts_with_all2all( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map, + ep_group=get_ep_group(), + ) class AscendFusedMoE(FusedMoE): @@ -1042,13 +1113,13 @@ def __init__( vllm_config = get_current_vllm_config() - self.moe_parallel_config: FusedMoEParallelConfig = ( - FusedMoEParallelConfig.make( - tp_size_=(tp_size if tp_size is not None else - get_tensor_model_parallel_world_size()), - dp_size_=(dp_size if dp_size is not None else - get_dp_group().world_size), - vllm_parallel_config=vllm_config.parallel_config)) + self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( + tp_size_=(tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()), + dp_size_=(dp_size + if dp_size is not None else get_dp_group().world_size), + vllm_parallel_config=vllm_config.parallel_config, + ) self.top_k = top_k self.num_experts = num_experts @@ -1076,23 +1147,22 @@ def __init__( # moe expert load balance expert_load_balancer = 
ExpertLoadBalancer(expert_map_path, self.global_num_experts) - self.local_num_experts, self.expert_map = \ - expert_load_balancer.get_rank_placement_map( - self.moe_instance_id, - self.ep_rank) + self.local_num_experts, self.expert_map = ( + expert_load_balancer.get_rank_placement_map( + self.moe_instance_id, self.ep_rank)) self.log2phy = expert_load_balancer.get_rank_log2phy_map( self.moe_instance_id, self.ep_rank) - self.global_redundant_expert_num = \ - expert_load_balancer.get_global_redundant_expert_num() + self.global_redundant_expert_num = ( + expert_load_balancer.get_global_redundant_expert_num()) else: # Create a tensor of size num_experts filled with -1 self.local_num_experts, self.expert_map = determine_expert_map( self.ep_size, self.ep_rank, self.global_num_experts) self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_multistream_moe = \ - ascend_config.torchair_graph_config.enable_multistream_moe and \ - self.torchair_graph_enabled + self.enable_multistream_moe = ( + ascend_config.torchair_graph_config.enable_multistream_moe + and self.torchair_graph_enabled) if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " @@ -1115,8 +1185,8 @@ def __init__( assert self.quant_method is not None - local_num_experts = torch.sum(self.expert_map != -1) \ - if self.expert_map is not None else num_experts + local_num_experts = (torch.sum(self.expert_map != -1) + if self.expert_map is not None else num_experts) moe_quant_params = { "num_experts": local_num_experts, @@ -1127,22 +1197,46 @@ def __init__( "weight_loader": self.weight_loader, } # need full intermediate size pre-sharding for WNA16 act order - if (self.quant_method.__class__.__name__ - in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): + if self.quant_method.__class__.__name__ in ( + "GPTQMarlinMoEMethod", + "CompressedTensorsWNA16MoEMethod", + ): moe_quant_params["intermediate_size_full"] = intermediate_size + self.ep_group = get_ep_group() # NOTE: self.tp_group is not expert_tp_group self.tp_group = get_tp_group().device_group self.quant_method.create_weights(layer=self, **moe_quant_params) - - def forward(self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - is_prefill: bool, - enable_force_load_balance: bool = False, - top_k: Optional[int] = None, - shared_experts: Optional[Any] = None, - gate: Optional[Any] = None): + self.token_dispatcher = None + if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance( + self.quant_method, AscendUnquantizedFusedMoEMethod): + self.reduce_results = False + moe_dispatcher_config = ( + MoEDispatcherConfig().set_num_moe_experts( + self.global_num_experts).set_num_local_experts( + self.local_num_experts).set_moe_router_topk( + top_k).set_group_topk(topk_group). 
+ set_num_groups(num_expert_group).set_expert_bias( + e_score_correction_bias).set_scaling_factor(1.0).build()) + self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher( + moe_dispatcher_config) + if envs_ascend.VLLM_ASCEND_ENABLE_DBO: + token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher( + moe_dispatcher_config) + self.token_dispatchers = [ + self.token_dispatcher, token_dispatcher1 + ] + + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_prefill: bool, + enable_force_load_balance: bool = False, + top_k: Optional[int] = None, + shared_experts: Optional[Any] = None, + gate: Optional[Any] = None, + ): assert self.quant_method is not None if top_k: @@ -1158,15 +1252,16 @@ def forward(self, quantized_x_for_share, dynamic_scale_for_share = None, None from vllm_ascend.quantization.w8a8_dynamic import \ AscendW8A8DynamicFusedMoEMethod + if self.enable_multistream_moe: assert gate is not None router_logits, _ = gate(hidden_states) - if isinstance(self.quant_method.quant_method, - AscendW8A8DynamicFusedMoEMethod - ) and fused_moe_state == FusedMoEState.MC2: + if (isinstance(self.quant_method.quant_method, + AscendW8A8DynamicFusedMoEMethod) + and fused_moe_state == FusedMoEState.MC2): with npu_stream_switch("moe_secondary", 0): - quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant( - hidden_states) + quantized_x_for_share, dynamic_scale_for_share = ( + torch_npu.npu_dynamic_quant(hidden_states)) if shared_experts: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: @@ -1178,10 +1273,12 @@ def forward(self, if num_tokens < forward_context.padded_num_tokens: hidden_states = nn.functional.pad( hidden_states, - (0, 0, 0, forward_context.padded_num_tokens - num_tokens)) + (0, 0, 0, forward_context.padded_num_tokens - num_tokens), + ) router_logits = nn.functional.pad( router_logits, - (0, 0, 0, forward_context.padded_num_tokens - num_tokens)) + (0, 0, 0, forward_context.padded_num_tokens - num_tokens), + ) if tp_size > 1: chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, @@ -1231,11 +1328,13 @@ def forward(self, enable_force_load_balance=enable_force_load_balance, log2phy=self.log2phy, global_redundant_expert_num=self.global_redundant_expert_num, - shared_experts=shared_experts if self.torchair_graph_enabled - and self.enable_multistream_moe and not is_prefill else None, + shared_experts=(shared_experts if self.torchair_graph_enabled + and self.enable_multistream_moe and not is_prefill + else None), quantized_x_for_share=quantized_x_for_share, dynamic_scale_for_share=dynamic_scale_for_share, mc2_mask=mc2_mask, + token_dispatcher=self.token_dispatcher, ) if shared_experts: @@ -1296,6 +1395,83 @@ def _forward_ms_fused_moe_comp( scoring_func=self.scoring_func, e_score_correction_bias=self.e_score_correction_bias, is_prefill=is_prefill, - enable_force_load_balance=enable_force_load_balance) + enable_force_load_balance=enable_force_load_balance, + ) + + return hidden_states + + +class AscendSparseMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + 
self.enable_multistream_moe = ( + ascend_config.torchair_graph_config.enable_multistream_moe) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + + self.experts = AscendFusedMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + ) + + self.top_k = config.num_experts_per_tok + + self.dp_size = get_dp_group().world_size + + self.tp_group = get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group + self.ep_group = get_ep_group() + + self.params_dtype = torch.get_default_dtype() + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None, + ) -> torch.Tensor: + if attn_metadata is None: + attn_metadata = get_forward_context().attn_metadata + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + enable_force_load_balance = get_forward_context().in_profile_run + is_prefill = get_forward_context().with_prefill + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=self.top_k, + enable_force_load_balance=enable_force_load_balance, + shared_experts=None, + ) return hidden_states diff --git a/vllm_ascend/ops/moe_dispatcher/__init__.py b/vllm_ascend/ops/moe_dispatcher/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py new file mode 100644 index 000000000..91118e296 --- /dev/null +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -0,0 +1,578 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2024; NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
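As a functional reference for the routed-expert path implemented by `AscendSparseMoeBlock`/`AscendFusedMoE` above, the following dense, single-device sketch (not from the patch) roughly mirrors softmax gating, top-k selection with renormalization, and a SwiGLU expert MLP, ignoring expert/tensor parallelism, quantization, token permutation and the NPU grouped-matmul kernels. All sizes and tensor names are illustrative.

import torch
import torch.nn.functional as F

num_tokens, hidden, inter, n_experts, top_k = 8, 16, 32, 4, 2
x = torch.randn(num_tokens, hidden)
w_gate_up = torch.randn(n_experts, hidden, 2 * inter)  # fused gate/up projection per expert
w_down = torch.randn(n_experts, inter, hidden)
router_logits = torch.randn(num_tokens, n_experts)     # what self.gate(hidden_states) yields

topk_weights, topk_ids = torch.topk(F.softmax(router_logits, dim=-1), top_k, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize

out = torch.zeros_like(x)
for e in range(n_experts):
    token_idx, slot = torch.where(topk_ids == e)        # tokens routed to expert e
    if token_idx.numel() == 0:
        continue
    gate, up = (x[token_idx] @ w_gate_up[e]).chunk(2, dim=-1)
    expert_out = (F.silu(gate) * up) @ w_down[e]        # gate_up_proj -> swiglu -> down_proj
    out[token_idx] += topk_weights[token_idx, slot].unsqueeze(-1) * expert_out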
+from typing import Optional + +import torch +import torch_npu +from vllm.distributed.parallel_state import get_ep_group + +from vllm_ascend.distributed.tensor_parallel import ( + all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp, + all_to_all_sp2hp, gather_from_sequence_parallel_region, + reduce_scatter_last_dim_to_tensor_parallel_region) +from vllm_ascend.ops.comm_utils import async_all_to_all + + +class MoEDispatcherConfig: + + def __init__(self): + self.num_local_experts: int = 0 + self.num_moe_experts: int = 0 + self.moe_pad_expert_input_to_capacity: bool = False + self.moe_expert_capacity_factor: Optional[float] = None + self.moe_router_topk: int = 2 + self.moe_grouped_gemm: bool = False + self.group_topk: int = 0 + self.num_groups: int = 1 + self.expert_bias: torch.Tensor = None + self.scaling_factor: Optional[float] = None + self.is_fused: bool = True + + def set_num_local_experts(self, num_local_experts): + self.num_local_experts = num_local_experts + return self + + def set_num_moe_experts(self, num_moe_experts): + self.num_moe_experts = num_moe_experts + return self + + def set_moe_pad_expert_input_to_capacity(self, + moe_pad_expert_input_to_capacity): + self.moe_pad_expert_input_to_capacity = moe_pad_expert_input_to_capacity + return self + + def set_moe_expert_capacity_factor(self, moe_expert_capacity_factor): + self.moe_expert_capacity_factor = moe_expert_capacity_factor + return self + + def set_moe_router_topk(self, moe_router_topk): + self.moe_router_topk = moe_router_topk + return self + + def set_moe_grouped_gemm(self, moe_grouped_gemm): + self.moe_grouped_gemm = moe_grouped_gemm + return self + + def set_group_topk(self, group_topk): + self.group_topk = group_topk + return self + + def set_num_groups(self, num_groups): + self.num_groups = num_groups + return self + + def set_expert_bias(self, expert_bias): + self.expert_bias = expert_bias + return self + + def set_scaling_factor(self, scaling_factor): + self.scaling_factor = scaling_factor + return self + + def set_is_fused(self, is_fused): + self.is_fused = is_fused + return self + + def build(self): + return self + + +class MoEDispatcher: + + def __init__(self, config: MoEDispatcherConfig) -> None: + """ + Initialize the MoE Token Dispatcher. + """ + self.config = config + self.shared_experts = None + + def set_shared_experts(self, shared_experts): + self.shared_experts = shared_experts + + @property + def ep_group(self): + """Get expert model parallel group.""" + return get_ep_group().device_group + + @property + def ep_rank(self): + return get_ep_group().rank_in_group + + @property + def ep_size(self): + return get_ep_group().world_size + + @property + def tp_ep_group(self): + """Get expert tensor and model parallel group.""" + return None + + @property + def tp_ep_size(self): + return 1 + + +class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher): + overlap_stream = None + """ + The implementation of the AlltoAll-based token dispatcher, which handles token + dispatching on the sequence level instead of token level. The core of this implementation + lies in each device dispatching on the entire sequence, with the hidden state being partitioned. + + """ + + def __init__(self, config: MoEDispatcherConfig): + """ + Initialize the AlltoAllSeq token dispatcher. + + Args: + config (MoEDispatcherConfig): Configuration for the transformer model. 
+ """ + super().__init__(config) + self.num_local_experts = config.num_local_experts + self.config = config + # use MOEAlltoAllSEQTokenDispatcher to init + + self.hidden_shape = None + self.num_input_tokens = None + self.num_experts = config.num_moe_experts + assert self.num_local_experts > 0, "Expected at least one expert" + if self.num_local_experts > 1: + self.expert_ids_per_ep_rank = torch.tensor( + [i % self.num_local_experts for i in range(self.num_experts)], + dtype=torch.int32, + device=torch.npu.current_device(), + ) + + local_expert_indices_offset = (self.ep_rank * self.num_local_experts) + + self.local_expert_indices = [ + local_expert_indices_offset + i + for i in range(self.num_local_experts) + ] + assert (len(self.local_expert_indices) == self.num_local_experts + ), "Invalid local expert indices" + for i in range(len(self.local_expert_indices) - 1): + assert (self.local_expert_indices[i] == + self.local_expert_indices[i + 1] - + 1), "local_expert_indices must be continuous" + self.probs = None + self.input_splits = None + self.output_splits = None + self.routing_map = None + self.hidden_shape_before_permute = None + + # [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent + # to each local expert by all ranks. + self.num_global_tokens_per_local_expert_cpu = None + self.num_global_tokens_per_local_expert = None + + # A cuda stream synchronization is needed in self.token_permutation() + # in some cases, because there are several non-blocking DtoH data + # transfers called in self.preprocess(). The synchronization happens + # at different points based on MoE settings as late as possible. + # Valid sync points are "before_permutation_1", "before_ep_alltoall", + # "before_finish", and "no_sync". + self.device_sync_point = "no_sync" + + # cached intermediate tensors. + self.cached_permutated_local_input_tokens = None + self.cached_global_input_tokens = None + self.cached_shared_expert_output = None + self.tokens_per_expert = None + self.perm1_finish_event = None + self.global_input_tokens_local_experts_indices = None + + if MoEAlltoAllSeqOverLapDispatcher.overlap_stream is None: + MoEAlltoAllSeqOverLapDispatcher.overlap_stream = torch.npu.Stream() + + self.overlap_stream = MoEAlltoAllSeqOverLapDispatcher.overlap_stream + + def preprocess(self, + indices: torch.Tensor, + with_sync=True) -> torch.Tensor: + """ + Preprocess routing map for AlltoAll communication and token permutation. + This method computes the number of tokens assigned to each expert based on + the routing map. It also initializes the necessary data structures for + AlltoAll communication, such as input and output splits, and the mapping + between global tokens and local experts. + + Args: + routing_map (torch.Tensor): The mapping of tokens to experts, with shape + [num_tokens, num_experts]. + + Returns: + torch.Tensor: Tensor containing the number of tokens assigned to local expert. + """ + num_local_tokens_per_expert = torch.histc(indices, + bins=self.num_experts, + min=0, + max=self.num_experts) + + # num_local_tokens_per_expert: [num_experts] + + ep_size = self.ep_size + + # Dropless + self.num_out_tokens = indices.numel() + if self.ep_size > 1 or self.num_local_experts > 1: + # Token dropless and enable ep. A synchronization is needed before expert parallel + # AlltoAll communication to get the `input_splits` and `output_splits` CPU values. + self.device_sync_point = "before_ep_alltoall" + else: + # Token dropless and no ep. 
A synchronization is needed to get the + # `tokens_per_expert` CPU value. + self.device_sync_point = "before_finish" + + if ep_size > 1: + # =================================================== + # Calculate input_splits, output_splits for alltoall-v. + # =================================================== + self.input_splits = (num_local_tokens_per_expert.reshape( + ep_size, self.num_local_experts).sum(axis=1).to( + torch.device("cpu"), non_blocking=True).numpy()) + num_global_tokens_per_expert = gather_from_sequence_parallel_region( + num_local_tokens_per_expert, + group=self.ep_group).reshape(ep_size, self.num_experts) + self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[ + 0]:self.local_expert_indices[-1] + 1] + if self.num_global_tokens_per_local_expert is None: + raise ValueError( + "num_global_tokens_per_local_expert must be set before sum." + ) + self.output_splits = (self.num_global_tokens_per_local_expert.sum( + axis=-1).to(torch.device("cpu"), non_blocking=True).numpy()) + num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum( + axis=0) + # =================================================== + # num_global_tokens_per_expert: [ep_size, num_experts] + # num_global_tokens_per_local_expert: [ep_size, num_local_experts] + # num_tokens_per_local_expert: [num_local_experts] + # =================================================== + else: + self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape( + -1, self.num_experts) + num_tokens_per_local_expert = num_local_tokens_per_expert + + if self.num_local_experts > 1 and with_sync: + if self.num_global_tokens_per_local_expert is None: + raise ValueError( + "num_global_tokens_per_local_expert must be set before operations." + ) + self.device_sync_point = "no_sync" + self.global_input_tokens_local_experts_indices = torch.repeat_interleave( + self.expert_ids_per_ep_rank, + self.num_global_tokens_per_local_expert.ravel()) + + return num_tokens_per_local_expert + + def token_permutation( + self, + hidden_states: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + ): + """ + Dispatch tokens to local experts using AlltoAllSeq communication. + + Args: + hidden_states (torch.Tensor): Input token embeddings. + probs (torch.Tensor): Probs of tokens assigned to experts. + Shape: [num_tokens, num_experts]. + routing_map (torch.Tensor): Mapping of tokens assigned to experts. + Shape: [num_tokens, num_experts]. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - Permuted token embeddings for local experts. + - Number of tokens per expert. 
+ """ + self.hidden_shape = hidden_states.shape + self.probs = probs + self.top_indices = routing_map + assert probs.dim() == 2, "Expected 2D tensor for probs" + assert routing_map.dim() == 2, "Expected 2D tensor for routing map" + + # Permutation 1: input to AlltoAll input + def alltoall_token_permutation1(hidden_states, routing_map): + assert self.hidden_shape is not None + hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) + tokens_per_expert = self.preprocess(routing_map) + if self.tp_ep_size > 1: + hidden_states = all_to_all_sp2hp(hidden_states, + group=self.tp_ep_group) + self.hidden_shape_before_permute = hidden_states.shape + + if self.device_sync_point == "before_permutation_1": + torch.npu.current_stream().synchronize() + + permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( + tokens=hidden_states, + indices=self.top_indices, + num_out_tokens=self.num_out_tokens, + ) + return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert + + permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = alltoall_token_permutation1( + hidden_states, routing_map) + self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping + # permute 1 + + ep_group = self.ep_group + + # Perform expert parallel AlltoAll communication + if self.device_sync_point == "before_ep_alltoall": + torch.npu.current_stream().synchronize() + _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all( + permutated_local_input_tokens, + self.output_splits, + self.input_splits, + ep_group, + ) + + # shared experts compute + if self.shared_experts is not None: + (share_experts_output), *_ = self.shared_experts(hidden_states) + else: + share_experts_output = None + + permute1_ep_all_to_all_handle.wait() + permutated_local_input_tokens.untyped_storage().resize_(0) + + def alltoall_token_permutation2(global_input_tokens): + # Permutation 2: Sort tokens by local expert. + if self.num_local_experts > 1: + global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( + global_input_tokens, + self.global_input_tokens_local_experts_indices) + + # Perform tensor parallel AllGather on the hidden dimension to obtain the input tokens. 
+ # global_input_tokens: [SEQL, H/TP] -> [SEQL, H] + if self.tp_ep_size > 1 and self.config.moe_grouped_gemm: + global_input_tokens = all_gather_last_dim_from_tensor_parallel_region( + global_input_tokens, self.tp_ep_group) + if self.device_sync_point == "before_finish": + torch.npu.current_stream().synchronize() + + return global_input_tokens + + # token premute2 input + global_input_tokens = alltoall_token_permutation2(global_input_tokens) + + return share_experts_output, global_input_tokens, tokens_per_expert + + def preprocess_and_permtute1(self, + hidden_states: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + shared_experts=None, + shared_experts_input: torch.Tensor = None): + self.hidden_shape = hidden_states.shape + self.probs = probs + self.top_indices = routing_map + assert probs.dim() == 2, "Expected 2D tensor for probs" + assert routing_map.dim() == 2, "Expected 2D tensor for routing map" + assert self.hidden_shape is not None + + hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) + tokens_per_expert = self.preprocess(routing_map, with_sync=False) + self.hidden_shape_before_permute = hidden_states.shape + + if self.device_sync_point == "before_permutation_1": + torch.npu.current_stream().synchronize() + + event = torch.npu.current_stream().record_event() + self.perm1_finish_event = torch.npu.Event() + with torch.npu.stream(self.overlap_stream): + assert self.overlap_stream is not None + self.overlap_stream.wait_event(event) + + if shared_experts is not None: + shared_output = shared_experts(shared_experts_input) + self.cached_shared_expert_output = shared_output + + hidden_states, self.reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( + tokens=hidden_states, + indices=self.top_indices, + num_out_tokens=self.num_out_tokens, + ) + + self.perm1_finish_event.record() + + # repeat interleve will launch a sync on current_stream. + if self.num_local_experts > 1: + self.device_sync_point = "no_sync" + if self.num_global_tokens_per_local_expert is None: + raise ValueError( + "num_global_tokens_per_local_expert must be set before operations." + ) + self.global_input_tokens_local_experts_indices = torch.repeat_interleave( + self.expert_ids_per_ep_rank, + self.num_global_tokens_per_local_expert.ravel()) + + self.cached_permutated_local_input_tokens = hidden_states + self.tokens_per_expert = tokens_per_expert + + def dispatch_alltoall(self): + ep_group = self.ep_group + + # Perform expert parallel AlltoAll communication + if self.device_sync_point == "before_ep_alltoall": + torch.npu.current_stream().synchronize() + + torch.npu.current_stream().wait_event(self.perm1_finish_event) + self.perm1_finish_event = None + _, self.cached_global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all( + self.cached_permutated_local_input_tokens, + self.output_splits, + self.input_splits, + ep_group, + ) + permute1_ep_all_to_all_handle.wait() + if self.cached_permutated_local_input_tokens is None: + raise ValueError( + "cached_permutated_local_input_tokens must be set before operations." 
+ ) + self.cached_permutated_local_input_tokens.untyped_storage().resize_(0) + self.cached_permutated_local_input_tokens = None + + def permute2(self): + global_input_tokens = self.cached_global_input_tokens + if self.num_local_experts > 1: + global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( + self.cached_global_input_tokens, + self.global_input_tokens_local_experts_indices) + assert self.cached_global_input_tokens is not None + self.cached_global_input_tokens.untyped_storage().resize_(0) + self.cached_global_input_tokens = None + + return global_input_tokens, self.tokens_per_expert + + def unpermute1(self, hidden_states: torch.Tensor): + # Unpermutation 2: expert output to AlltoAll input + if hidden_states.shape[0] > 0 and self.num_local_experts > 1: + hidden_states = torch_npu.npu_moe_token_unpermute( + hidden_states, self.reversed_global_input_permutation_mapping) + self.cached_global_output_tokens = hidden_states + self.reversed_global_input_permutation_mapping = None + + def combine_alltoall(self): + ep_group = self.ep_group + # Perform expert parallel AlltoAll communication + # hidden_states: [SEQL, H] -> [SEQL, H/TP] + _, self.cached_local_output_tokens, handle = async_all_to_all( + self.cached_global_output_tokens, self.input_splits, + self.output_splits, ep_group) + handle.wait() + self.cached_global_output_tokens.untyped_storage().resize_(0) + self.cached_global_output_tokens = None + self.input_splits = None + self.output_splits = None + + def unpermute2(self): + output = torch_npu.npu_moe_token_unpermute( + permuted_tokens=self.cached_local_output_tokens, + sorted_indices=self.reversed_local_input_permutation_mapping.to( + torch.int32), + probs=self.probs, + restore_shape=self.hidden_shape_before_permute) + + output = output.view(self.hidden_shape) + + self.probs = None + self.reversed_local_input_permutation_mapping = None + self.cached_local_output_tokens.untyped_storage().resize_(0) + self.cached_local_output_tokens = None + + return output + + def token_unpermutation(self, + hidden_states: torch.Tensor, + bias: torch.Tensor = None): + """ + Reverse the token permutation to restore the original order. + + Args: + hidden_states (torch.Tensor): Output from local experts. + bias (torch.Tensor, optional): Bias tensor (not supported). + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - Unpermuted token embeddings in the original order. + - None (bias is not supported). 
+ """ + + def alltoall_token_unpermutation1(hidden_states): + assert bias is None, "Bias is not supported in MoEAlltoAllSeqTokenDispatcher" + # Perform tensor parallel Reduce-Scatter + # hidden_states: [SEQL, H] -> [SEQL, H/TP] + if self.tp_ep_size > 1: + hidden_states = reduce_scatter_last_dim_to_tensor_parallel_region( + hidden_states, group=self.tp_ep_group) + + # Unpermutation 2: expert output to AlltoAll input + if hidden_states.shape[0] > 0 and self.num_local_experts > 1: + hidden_states = torch_npu.npu_moe_token_unpermute( + hidden_states, + self.reversed_global_input_permutation_mapping) + + return hidden_states + + hidden_states = alltoall_token_unpermutation1(hidden_states) + + ep_group = self.ep_group + # Perform expert parallel AlltoAll communication + # hidden_states: [SEQL, H] -> [SEQL, H/TP] + _, permutated_local_input_tokens, handle = async_all_to_all( + hidden_states, self.input_splits, self.output_splits, ep_group) + handle.wait() + hidden_states.untyped_storage().resize_(0) + + def alltoall_token_unpermutation2(permutated_local_input_tokens): + # Unpermutation 1: AlltoAll output to output + + output = torch_npu.npu_moe_token_unpermute( + permuted_tokens=permutated_local_input_tokens, + sorted_indices=self.reversed_local_input_permutation_mapping. + to(torch.int32), + probs=self.probs, + restore_shape=self.hidden_shape_before_permute) + + # Perform tensor parallel AlltoAll communication + # output: [S*B, H/TP] -> [S*B/TP, H] + if self.tp_ep_size > 1: + output = all_to_all_hp2sp(output, self.tp_ep_group) + + # Reshape the output tensor + output = output.view(self.hidden_shape) + return output + + output = alltoall_token_unpermutation2(permutated_local_input_tokens) + + self.input_splits = None + self.output_splits = None + self.num_global_tokens_per_local_expert = None + self.num_global_tokens_per_local_expert_cpu = None + + return output, None diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 728d4b191..120761870 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -653,7 +653,7 @@ def _check_dbo_is_valid(self, query_lens: torch.Tensor, ]: return False # considering the case that one dp rank may enable dbo while others may not - if not self.vllm_config.model_config.use_mla or not envs_ascend.VLLM_ASCEND_ENABLE_DBO: + if not envs_ascend.VLLM_ASCEND_ENABLE_DBO: return False # TODO: remove it if token-level microbatch is enabled [token_index,