From 9e099a531662975d35ff21ed2e1c1507b58ef9b4 Mon Sep 17 00:00:00 2001 From: sdmyzlp <117554856+sdmyzlp@users.noreply.github.com> Date: Thu, 26 Jun 2025 09:32:07 +0800 Subject: [PATCH 1/3] Handle with_prefill_across_dp for multistream mla (#1322) Signed-off-by: sdmyzlp --- tests/multicard/test_torchair_graph_mode.py | 115 ++++++++++++-------- vllm_ascend/attention/mla_v1.py | 16 +-- vllm_ascend/models/deepseek_v2.py | 10 +- 3 files changed, 85 insertions(+), 56 deletions(-) diff --git a/tests/multicard/test_torchair_graph_mode.py b/tests/multicard/test_torchair_graph_mode.py index 96fa92ef4b..d06a0872bc 100644 --- a/tests/multicard/test_torchair_graph_mode.py +++ b/tests/multicard/test_torchair_graph_mode.py @@ -20,6 +20,7 @@ Run `pytest tests/multicard/test_torchair_graph_mode.py`. """ import os +from typing import Dict import pytest @@ -28,6 +29,55 @@ os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256" +def _deepseek_torchair_test_fixture( + additional_config: Dict, + *, + tensor_parallel_size=4, +): + example_prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # torchair is only work without chunked-prefill now + kwargs = { + "ascend_scheduler_config": { + "enabled": True, + }, + "refresh": True, + } + additional_config.update(**kwargs) + + with VllmRunner( + "vllm-ascend/DeepSeek-V3-Pruning", + dtype="half", + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend="mp", + enforce_eager=False, + additional_config=additional_config, + ) as vllm_model: + # use greedy sampler to make sure the generated results are fix + vllm_output = vllm_model.generate_greedy(example_prompts, 5) + + # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of + # DeepSeek-V3 with 2 hidden layers, thus the golden results seems + # inaccurate. This will only change if accuracy improves with the + # official weights of DeepSeek-V3. 
+ golden_results = [ + 'Hello, my name is下载早点向前很有่อง', + 'The president of the United States isSender)## physiological Albany', + 'The capital of France is Rocky转角 hospitalizedinterval sparked', + 'The future of AI is её asegο BIOS一扫', + ] + + assert len(golden_results) == len(vllm_output) + for i in range(len(vllm_output)): + assert golden_results[i] == vllm_output[i][1] + print(f"Generated text: {vllm_output[i][1]!r}") + + @pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", reason="torchair graph is not supported on v0") @pytest.mark.parametrize("VLLM_ASCEND_ENABLE_DBO", ["0", "1"]) @@ -38,46 +88,25 @@ def test_e2e_deepseekv3_with_torchair(monkeypatch: pytest.MonkeyPatch, m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") m.setenv("VLLM_ASCEND_ENABLE_DBO", VLLM_ASCEND_ENABLE_DBO) - example_prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - dtype = "half" - max_tokens = 5 - # torchair is only work without chunked-prefill now - with VllmRunner( - "vllm-ascend/DeepSeek-V3-Pruning", - dtype=dtype, - tensor_parallel_size=4, - distributed_executor_backend="mp", - additional_config={ - "torchair_graph_config": { - "enabled": True, - }, - "ascend_scheduler_config": { - "enabled": True, - }, - "refresh": True, - }, - enforce_eager=False, - ) as vllm_model: - # use greedy sampler to make sure the generated results are fix - vllm_output = vllm_model.generate_greedy(example_prompts, - max_tokens) - # NOTE: vllm-ascend/DeepSeek-V3-Pruning is a random weight of - # DeepSeek-V3 with 2 hidden layers, thus the golden results seems - # inaccurate. This will only change if accuracy improves with the - # official weights of DeepSeek-V3. - golden_results = [ - 'Hello, my name is下载早点向前很有่อง', - 'The president of the United States isSender)## physiological Albany', - 'The capital of France is Rocky转角 hospitalizedinterval sparked', - 'The future of AI is её asegο BIOS一扫', - ] - - assert len(golden_results) == len(vllm_output) - for i in range(len(vllm_output)): - assert golden_results[i] == vllm_output[i][1] - print(f"Generated text: {vllm_output[i][1]!r}") + additional_config = { + "torchair_graph_config": { + "enabled": True, + }, + } + _deepseek_torchair_test_fixture(additional_config) + + +@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0", + reason="torchair graph is not supported on v0") +def test_e2e_deepseekv3_with_torchair_ms_mla(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + m.setenv("VLLM_USE_MODELSCOPE", "True") + m.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + + additional_config = { + "torchair_graph_config": { + "enabled": True, + "enable_multistream_mla": True, + }, + } + _deepseek_torchair_test_fixture(additional_config) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 98f0a3389c..b9a09f0c6a 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -588,8 +588,6 @@ def __init__( ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz - self.enable_multistream_mla = \ - ascend_config.torchair_graph_config.enable_multistream_mla # Adapt torch air graph mode with spec decoding. 
speculative_config = get_current_vllm_config().speculative_config @@ -883,6 +881,7 @@ def exec_kv( sin: torch.Tensor, kv_cache: Tuple, slots: torch.Tensor, + enable_multistream_mla: bool = False, ): B = hidden_states.shape[0] @@ -894,7 +893,7 @@ def exec_kv( cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" with npu_stream_switch("mla_secondary", 0, - enabled=self.enable_multistream_mla): + enabled=enable_multistream_mla): k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( kv, self.kv_a_layernorm.weight, @@ -1066,6 +1065,7 @@ def forward( kv_cache: Tuple[torch.Tensor], attn_metadata: M, output: Optional[torch.Tensor] = None, + enable_multistream_mla: bool = False, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." if attn_metadata is None: @@ -1127,22 +1127,22 @@ def forward( # KvRmsNormRopeCache and SingleRope. npu_wait_tensor(decode_hs_or_q_c, cos, - enabled=self.enable_multistream_mla) + enabled=enable_multistream_mla) npu_wait_tensor(decode_hs_or_q_c, sin, - enabled=self.enable_multistream_mla) + enabled=enable_multistream_mla) decode_ql_nope, decode_q_pe = \ self._q_proj_and_k_up_proj(decode_hs_or_q_c) if self.running_in_graph: decode_k_pe, decode_k_nope = self.exec_kv( hidden_states_or_kv_c_normed, cos, sin, kv_cache, - attn_metadata.slot_mapping) + attn_metadata.slot_mapping, enable_multistream_mla) with npu_stream_switch("mla_secondary", 0, - enabled=self.enable_multistream_mla): + enabled=enable_multistream_mla): npu_wait_tensor(decode_q_pe, decode_k_pe, - enabled=self.enable_multistream_mla) + enabled=enable_multistream_mla) decode_q_pe = self.rope_single(decode_q_pe, cos, sin) else: decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 36a761c4c2..222ddc966b 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -470,15 +470,15 @@ def forward( hidden_states: torch.Tensor, kv_cache: Optional[torch.Tensor] = None, attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + enable_multistream_mla = (self.enable_multistream_mla + and not get_forward_context().with_prefill) + forward_kwargs = {"enable_multistream_mla": enable_multistream_mla} if self.q_lora_rank is not None: ckq = self.q_a_proj(hidden_states)[0] - use_multistream_mla = (self.enable_multistream_mla - and attn_metadata is not None - and attn_metadata.num_decodes > 0) - npu_wait_tensor(hidden_states, ckq, enabled=use_multistream_mla) + npu_wait_tensor(hidden_states, ckq, enabled=enable_multistream_mla) with npu_stream_switch("mla_secondary", 0, - enabled=use_multistream_mla): + enabled=enable_multistream_mla): hidden_states_or_q_c = self.q_a_layernorm(ckq) else: hidden_states_or_q_c = hidden_states From 066ea109a65b53e60bd3ba9c8dee0cf60d32ecd8 Mon Sep 17 00:00:00 2001 From: sdmyzlp Date: Mon, 7 Jul 2025 10:30:48 +0800 Subject: [PATCH 2/3] Fix enable_multistream_moe for unquantized scenario Signed-off-by: sdmyzlp --- vllm_ascend/ops/fused_moe.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index fe1164fd4d..4f211b3dbb 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1144,9 +1144,11 @@ def forward(self, if self.enable_multistream_moe: assert gate is not None router_logits, _ = gate(hidden_states) - if isinstance(self.quant_method.quant_method, - AscendW8A8DynamicFusedMoEMethod - ) and fused_moe_state == FusedMoEState.MC2: + if not 
isinstance(self.quant_method, + AscendUnquantizedFusedMoEMethod) and isinstance( + self.quant_method.quant_method, + AscendW8A8DynamicFusedMoEMethod + ) and fused_moe_state == FusedMoEState.MC2: with npu_stream_switch("moe_secondary", 0): quantized_x_for_share, dynamic_scale_for_share = torch_npu.npu_dynamic_quant( hidden_states) From 1bd5be977ff261b2e75161a2ac7e6eed315af452 Mon Sep 17 00:00:00 2001 From: sharonyunyun <106064496+sharonyunyun@users.noreply.github.com> Date: Wed, 25 Jun 2025 19:56:49 +0800 Subject: [PATCH 3/3] adjusting the communication method in graph mode (#1194) Signed-off-by: sharonyunyun Signed-off-by: sdmyzlp --- vllm_ascend/attention/mla_v1.py | 8 +- vllm_ascend/models/deepseek_v2.py | 172 ++++++++++++++++++++++++++---- vllm_ascend/ops/fused_moe.py | 9 +- 3 files changed, 161 insertions(+), 28 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index b9a09f0c6a..e1c09c4597 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -9,6 +9,7 @@ MLAAttentionImpl) from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.utils import cdiv, round_down @@ -584,6 +585,7 @@ def __init__( self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.tp_size = get_tensor_model_parallel_world_size() ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled @@ -602,7 +604,7 @@ def _v_up_proj_and_o_proj(self, x): x = torch.bmm(x, self.W_UV) # Convert from (N, B, V) to (B, N * V) x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - return self.o_proj(x)[0] + return self.o_proj(x, is_prefill=False)[0] # Return `ql_nope`, `q_pe` def _q_proj_and_k_up_proj(self, x): @@ -867,12 +869,12 @@ def _forward_prefill( current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is None: - return self.o_proj(attn_output)[0] + return self.o_proj(attn_output, is_prefill=True)[0] else: current_ms_metadata.before_comm_event.record() with torch.npu.stream(current_ms_metadata.comm_stream): current_ms_metadata.before_comm_event.wait() - return self.o_proj(attn_output)[0] + return self.o_proj(attn_output, is_prefill=True)[0] def exec_kv( self, diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 222ddc966b..69dc59bf59 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -35,9 +35,16 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import (CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config) +# Temporarily disable yapf since it conflicts with isort. 
+# yapf: disable from vllm.distributed import (get_dp_group, get_pp_group, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tp_group) + get_tp_group, split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) +# yapf: enable from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -132,6 +139,80 @@ def weight_loader(self, param: torch.nn.Parameter, shard.copy_(loaded_weight) +class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear): + + def forward( + self, + input_, + is_prefill=True + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output_parallel = self.quant_method.apply(self, + input_parallel, + bias=bias_) + if self.reduce_results and self.tp_size > 1: + if not is_prefill and output_parallel.shape[0] % self.tp_size == 0: + output = tensor_model_parallel_reduce_scatter(output_parallel, + dim=0) + else: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias + + +class CustomDeepseekV2RowParallelLinear(RowParallelLinear): + + def forward( + self, + input_, + is_prefill=True + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. 
+ assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output_parallel = self.quant_method.apply(self, + input_parallel, + bias=bias_) + if self.reduce_results and self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias + + class CustomDeepseekV2MLP(nn.Module): def __init__( @@ -291,10 +372,10 @@ def __init__( self.params_dtype = torch.get_default_dtype() - def forward( - self, - hidden_states: torch.Tensor, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + def forward(self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None, + replace_allreduce: bool = False) -> torch.Tensor: forward_context = get_forward_context() if attn_metadata is None: attn_metadata = forward_context.attn_metadata @@ -323,7 +404,7 @@ def forward( enable_force_load_balance=enable_force_load_balance, shared_experts=self.shared_experts, gate=self.gate if self.enable_multistream_moe else None, - ) + replace_allreduce=replace_allreduce) hidden_states = ( experts_hidden_states[0] * self.routed_scaling_factor + @@ -370,6 +451,14 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_multistream_mla = \ + ascend_config.torchair_graph_config.enable_multistream_mla + if self.q_lora_rank is not None: self.q_a_proj = ReplicatedLinear(self.hidden_size, self.q_lora_rank, @@ -406,11 +495,23 @@ def __init__( bias=False, quant_config=quant_config, prefix=f"{prefix}.kv_b_proj") - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") + if (config.n_routed_experts is not None + and self.debug_layer_idx >= config.first_k_dense_replace + and self.debug_layer_idx % config.moe_layer_freq == 0 and + ascend_config.torchair_graph_config.enable_multistream_moe): + self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + else: + self.o_proj = CustomDeepseekV2RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") if rope_scaling: rope_scaling["rope_type"] = 'deepseek_yarn' @@ -456,14 +557,6 @@ def __init__( o_proj=self.o_proj, ) - self.prefix = prefix - self.debug_layer_idx = int(self.prefix.split(".")[-2]) - - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - self.enable_multistream_mla = \ - ascend_config.torchair_graph_config.enable_multistream_mla - def forward( self, positions: torch.Tensor, @@ -530,6 +623,10 @@ def __init__( # with the layer's index. 
layer_idx = int(prefix.split(sep='.')[-1]) self.layer_idx = layer_idx + self.layers = config.num_hidden_layers + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tp_group().rank_in_group + ascend_config = get_ascend_config() # TODO: enable mla in vllm-ascend if model_config.use_mla: attn_cls = CustomDeepseekV2MLAAttention @@ -561,6 +658,8 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp", ) + self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \ + and model_config.use_mla and envs.VLLM_USE_V1 and self.tp_size > 1 else: self.mlp = CustomDeepseekV2MLP( hidden_size=config.hidden_size, @@ -569,11 +668,13 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.mlp", ) + self.mla_moe_communication = False self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.routed_scaling_factor = config.routed_scaling_factor + self.first_k_dense_replace = config.first_k_dense_replace def forward( self, @@ -582,8 +683,13 @@ def forward( residual: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor] = None, attn_metadata: Optional[AttentionMetadata] = None, + replace_allreduce: bool = False, ) -> torch.Tensor: # Self Attention + if attn_metadata is not None and attn_metadata.num_decodes > 0: + mla_moe_communication = self.mla_moe_communication and replace_allreduce + else: + mla_moe_communication = False if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -595,6 +701,9 @@ def forward( # to save npu memory because they're no longer used. dispose_tensor(previous_hidden_states) dispose_tensor(previous_residual) + if mla_moe_communication and self.layer_idx > self.first_k_dense_replace: + hidden_states = tensor_model_parallel_all_gather(hidden_states, + dim=0) hidden_states = self.self_attn( positions=positions, @@ -603,6 +712,13 @@ def forward( attn_metadata=attn_metadata, ) + if mla_moe_communication and residual.shape[0] != hidden_states.shape[ + 0]: + chunk_hidden_states = torch.tensor_split(residual, + self.tp_size, + dim=0) + residual = chunk_hidden_states[self.tp_rank] + if hidden_states.dtype == torch.float16: # Fix FP16 overflow # We scale both hidden_states and residual before @@ -618,7 +734,9 @@ def forward( hidden_states, residual) if isinstance(self.mlp, CustomDeepseekV2MoE): - hidden_states = self.mlp(hidden_states, attn_metadata) + hidden_states = self.mlp(hidden_states, + attn_metadata, + replace_allreduce=mla_moe_communication) else: hidden_states = self.mlp(hidden_states) @@ -631,6 +749,10 @@ def forward( # The scaling of DeepseekV2MOE output would be done in the forward # of DeepseekV2MOE hidden_states *= 1. 
/ self.routed_scaling_factor + if mla_moe_communication and self.layer_idx == self.layers - 1: + hidden_states = tensor_model_parallel_all_gather(hidden_states, + dim=0) + residual = tensor_model_parallel_all_gather(residual, dim=0) return hidden_states, residual @@ -649,6 +771,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + self.tp_size = get_tensor_model_parallel_world_size() if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( @@ -701,13 +824,18 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + replace_allreduce = hidden_states.shape[0] % self.tp_size == 0 + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( - positions, hidden_states, residual, + positions, + hidden_states, + residual, kv_caches[i - self.start_layer] if kv_caches is not None else None, - attn_metadata) + attn_metadata, + replace_allreduce=replace_allreduce) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 4f211b3dbb..15fa55cc6f 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -1126,7 +1126,8 @@ def forward(self, enable_force_load_balance: bool = False, top_k: Optional[int] = None, shared_experts: Optional[Any] = None, - gate: Optional[Any] = None): + gate: Optional[Any] = None, + replace_allreduce: bool = False): assert self.quant_method is not None if top_k: @@ -1158,7 +1159,8 @@ def forward(self, shared_hidden_states = shared_experts(hidden_states) tp_size = get_tensor_model_parallel_world_size() - if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: + if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather + and not replace_allreduce): if num_tokens < tp_size: hidden_states = nn.functional.pad( hidden_states, (0, 0, 0, tp_size - num_tokens)) @@ -1217,7 +1219,8 @@ def forward(self, if isinstance(e_hidden_states, tuple): e_hidden_states, shared_hidden_states = e_hidden_states - if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: + if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather + and not replace_allreduce): dist.all_gather(list(chunk_hidden_states), e_hidden_states, self.tp_group) final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
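
---
Reviewer note (not part of the patches above): all three changes are driven through `additional_config`, so they compose at the engine-configuration level. Below is a minimal Python sketch of the combined settings; the keys are copied from the fixture in tests/multicard/test_torchair_graph_mode.py, and enabling multistream MLA and multistream MoE together is an illustrative assumption, not something this series asserts.

# Sketch only: the torchair options touched by this series, combined.
# Keys mirror tests/multicard/test_torchair_graph_mode.py; turning on both
# multistream features at once is an assumption made for illustration.
additional_config = {
    "torchair_graph_config": {
        "enabled": True,
        # Patch 1/3: skipped per step whenever the forward context reports
        # with_prefill, so mixed prefill/decode batches stay safe.
        "enable_multistream_mla": True,
        # Patches 2/3 and 3/3: multistream MoE, now guarded against the
        # unquantized FusedMoE method and paired with the alternative
        # o_proj communication path in graph mode.
        "enable_multistream_moe": True,
    },
    # Torchair currently only works without chunked prefill, hence the
    # ascend scheduler is enabled alongside it (as in the test fixture).
    "ascend_scheduler_config": {
        "enabled": True,
    },
    "refresh": True,
}

In the tests this dict is passed as the additional_config argument of VllmRunner. The decode-side reduce-scatter path added in patch 3/3 only activates when tensor parallelism is greater than one and the token count divides evenly by the TP size; otherwise the o_proj output falls back to the usual all-reduce.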