
Commit e9113e2

Author: yangcheng (AJ)

fix mc2 bug

Signed-off-by: yangcheng (AJ) <y00806874@china.huawei.com>

1 parent 103bb69 · commit e9113e2
2 files changed: +18 −27 lines
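In short: fused_experts_with_mc2 previously hardcoded global_bs = 0 in the MC2 kwargs; this commit adds a global_batch_size parameter (default 256) and derives global_bs from it and the expert-parallel all-to-all group size. A minimal sketch of the new computation, assuming an initialized torch.distributed process group (compute_global_bs is an illustrative helper, not a function in the repo):

```python
import torch.distributed as dist


def compute_global_bs(global_batch_size: int, ep_group: dist.ProcessGroup) -> int:
    """Sketch of the fix: derive the MC2 global batch size from the
    expert-parallel group size instead of hardcoding 0."""
    all_to_all_group_size = dist.get_world_size(ep_group)
    return global_batch_size * all_to_all_group_size
```

With the default global_batch_size of 256 and, say, a 16-rank EP group, this yields global_bs = 4096, which the diff below passes as "global_bs" in place of the old constant 0.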

vllm_ascend/models/qwen3_moe.py

Lines changed: 9 additions & 20 deletions
@@ -16,32 +16,23 @@
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
 
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Optional
 
 import torch
-import torch.distributed as dist
-import torch_npu
 import vllm
-import vllm.envs as envs
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import AttentionMetadata
-from vllm.distributed import (get_tensor_model_parallel_world_size,
-                              get_tp_group)
+from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.linear import ReplicatedLinear
-
 from vllm.model_executor.layers.quantization import QuantizationConfig
-
+from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
 
-from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
-from transformers import PretrainedConfig
-from vllm.model_executor.layers.quantization import QuantizationConfig
-
 
 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
     packed_modules_mapping = {
@@ -55,19 +46,18 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
             "up_proj",
         ],
         "experts":
-            ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
+        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
     }
 
 
 class AscendQwen3MoeSparseMoeBlock(nn.Module):
-
     top_k: int
 
     def __init__(
-            self,
-            config: PretrainedConfig,
-            quant_config: Optional[QuantizationConfig] = None,
-            prefix: str = "",
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -97,7 +87,6 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.experts")
 
-
         self.top_k = config.num_experts_per_tok
 
         self.dp_size = get_dp_group().world_size
@@ -122,7 +111,7 @@ def forward(
             is_prefill = True
             enable_force_load_balance = True
         else:
-            # is_prefill = attn_metadata.num_prefills > 0 is_prefill or
+            # is_prefill = attn_metadata.num_prefills > 0
            enable_force_load_balance = False
            if hasattr(attn_metadata, 'with_prefill_across_dp'):
                is_prefill = attn_metadata.with_prefill_across_dp
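For context on the last hunk: the branch it touches decides, per forward pass, whether the MoE block treats the step as prefill and whether forced expert load balancing applies. A sketch of that control flow as it reads after the change, reconstructed only from the diff context above (the standalone function and the False fallback are illustrative, not the actual method body):

```python
from typing import Any, Optional, Tuple


def resolve_prefill_flags(attn_metadata: Optional[Any]) -> Tuple[bool, bool]:
    """Illustrative reconstruction of the is_prefill / force-load-balance logic."""
    if attn_metadata is None:
        # No attention metadata available: treat the step as prefill and
        # force-balance the experts.
        is_prefill = True
        enable_force_load_balance = True
    else:
        # is_prefill = attn_metadata.num_prefills > 0  (old heuristic, now only a comment)
        is_prefill = False  # fallback chosen here only to keep the sketch self-contained
        enable_force_load_balance = False
        if hasattr(attn_metadata, 'with_prefill_across_dp'):
            is_prefill = attn_metadata.with_prefill_across_dp
    return is_prefill, enable_force_load_balance
```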

vllm_ascend/ops/fused_moe.py

Lines changed: 9 additions & 7 deletions
@@ -118,9 +118,13 @@ def fused_experts_with_mc2(
     top_k: int,
     expert_map: torch.Tensor = None,
     moe_all_to_all_group_name: Optional[str] = None,
-    shared_experts: Optional[Any] = None
+    shared_experts: Optional[Any] = None,
+    global_batch_size: int = 256,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    global_bs = 0
+
+    ep_group = get_ep_group().device_group
+    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
+    global_bs = global_batch_size * all_to_all_group_size
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
         "x": hidden_states,
@@ -132,11 +136,8 @@ def fused_experts_with_mc2(
     }
 
     rank = torch.distributed.get_rank()
-
     quant_mode = 0
-    ep_group = get_ep_group().device_group
     local_rank = torch.distributed.get_rank(group=ep_group)
-    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
 
     tp_size = get_etp_group().world_size
     tp_rank = rank % tp_size
@@ -204,7 +205,7 @@ def fused_experts_with_mc2(
         "expert_shard_type": 0,
         "shared_expert_rank_num": 0,
         "moe_expert_num": moe_expert_num,
-        "global_bs": 0,
+        "global_bs": global_bs,
     }
     tp_recv_counts = output[5]
     stage3_kwargs = {
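The same computed value now replaces the hardcoded 0 in the kwargs above. As a rough sketch, assuming only these four fields matter for the point (other fields and the actual torch_npu call are omitted, and the numbers are placeholders):

```python
moe_expert_num = 64   # placeholder; in the code this is len(expert_map)
global_bs = 256 * 8   # global_batch_size * all_to_all_group_size, as computed earlier

stage_kwargs = {
    "expert_shard_type": 0,
    "shared_expert_rank_num": 0,
    "moe_expert_num": moe_expert_num,
    "global_bs": global_bs,  # was the literal 0 before this commit
}
```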
@@ -1037,7 +1038,8 @@ def apply(
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                shared_experts=shared_experts)
+                shared_experts=shared_experts,
+                global_batch_size=self.global_batch_size)
         elif fused_moe_state == FusedMoEState.AllGather:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
