Skip to content

Commit 5d21f95

Browse files
author
yangcheng (AJ)
committed
Fix MC2 bug: compute global_bs as global_batch_size * EP group size instead of hard-coding 0
Signed-off-by: yangcheng (AJ) <y00806874@china.huawei.com>
1 parent 103bb69 commit 5d21f95

File tree

2 files changed

+22
-30
lines changed

2 files changed

+22
-30
lines changed

vllm_ascend/models/qwen3_moe.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,23 @@
1616
# Adapted from vllm/model_executor/models/qwen3_moe.py
1717
# This file is a part of the vllm-ascend project.
1818

19-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19+
from typing import Optional
2020

2121
import torch
22-
import torch.distributed as dist
23-
import torch_npu
24-
import vllm
25-
import vllm.envs as envs
2622
from torch import nn
2723
from transformers import PretrainedConfig
24+
from vllm_ascend.ascend_config import get_ascend_config
25+
from vllm_ascend.distributed.parallel_state import get_ep_group
26+
from vllm_ascend.ops.fused_moe import AscendFusedMoE
27+
28+
import vllm
2829
from vllm.attention import AttentionMetadata
29-
from vllm.distributed import (get_tensor_model_parallel_world_size,
30-
get_tp_group)
30+
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
3131
from vllm.distributed.parallel_state import get_dp_group
3232
from vllm.forward_context import get_forward_context
3333
from vllm.model_executor.layers.linear import ReplicatedLinear
34-
3534
from vllm.model_executor.layers.quantization import QuantizationConfig
36-
37-
from vllm_ascend.ascend_config import get_ascend_config
38-
from vllm_ascend.distributed.parallel_state import get_ep_group
39-
from vllm_ascend.ops.fused_moe import AscendFusedMoE
40-
4135
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
42-
from transformers import PretrainedConfig
43-
from vllm.model_executor.layers.quantization import QuantizationConfig
4436

4537

4638
class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
@@ -55,19 +47,18 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
5547
"up_proj",
5648
],
5749
"experts":
58-
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
50+
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
5951
}
6052

6153

6254
class AscendQwen3MoeSparseMoeBlock(nn.Module):
63-
6455
top_k: int
6556

6657
def __init__(
67-
self,
68-
config: PretrainedConfig,
69-
quant_config: Optional[QuantizationConfig] = None,
70-
prefix: str = "",
58+
self,
59+
config: PretrainedConfig,
60+
quant_config: Optional[QuantizationConfig] = None,
61+
prefix: str = "",
7162
):
7263
super().__init__()
7364
self.tp_size = get_tensor_model_parallel_world_size()
@@ -97,7 +88,6 @@ def __init__(
9788
quant_config=quant_config,
9889
prefix=f"{prefix}.experts")
9990

100-
10191
self.top_k = config.num_experts_per_tok
10292

10393
self.dp_size = get_dp_group().world_size
@@ -122,7 +112,7 @@ def forward(
122112
is_prefill = True
123113
enable_force_load_balance = True
124114
else:
125-
# is_prefill = attn_metadata.num_prefills > 0 is_prefill or
115+
# is_prefill = attn_metadata.num_prefills > 0
126116
enable_force_load_balance = False
127117
if hasattr(attn_metadata, 'with_prefill_across_dp'):
128118
is_prefill = attn_metadata.with_prefill_across_dp

vllm_ascend/ops/fused_moe.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,13 @@ def fused_experts_with_mc2(
118118
top_k: int,
119119
expert_map: torch.Tensor = None,
120120
moe_all_to_all_group_name: Optional[str] = None,
121-
shared_experts: Optional[Any] = None
121+
shared_experts: Optional[Any] = None,
122+
global_batch_size: int = 256,
122123
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
123-
global_bs = 0
124+
125+
ep_group = get_ep_group().device_group
126+
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
127+
global_bs = global_batch_size * all_to_all_group_size
124128
moe_expert_num = len(expert_map)
125129
kwargs_mc2 = {
126130
"x": hidden_states,
@@ -132,11 +136,8 @@ def fused_experts_with_mc2(
132136
}
133137

134138
rank = torch.distributed.get_rank()
135-
136139
quant_mode = 0
137-
ep_group = get_ep_group().device_group
138140
local_rank = torch.distributed.get_rank(group=ep_group)
139-
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
140141

141142
tp_size = get_etp_group().world_size
142143
tp_rank = rank % tp_size
@@ -204,7 +205,7 @@ def fused_experts_with_mc2(
204205
"expert_shard_type": 0,
205206
"shared_expert_rank_num": 0,
206207
"moe_expert_num": moe_expert_num,
207-
"global_bs": 0,
208+
"global_bs": global_bs,
208209
}
209210
tp_recv_counts = output[5]
210211
stage3_kwargs = {
@@ -1037,7 +1038,8 @@ def apply(
10371038
top_k=top_k,
10381039
expert_map=expert_map,
10391040
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
1040-
shared_experts=shared_experts)
1041+
shared_experts=shared_experts,
1042+
global_batch_size=self.global_batch_size)
10411043
elif fused_moe_state == FusedMoEState.AllGather:
10421044
return fused_experts(hidden_states=x,
10431045
w1=layer.w13_weight,

0 commit comments

Comments
 (0)