
Commit 33dbe57

[0.9.1][bugfix] fix mc2 op GroupCoordinator bug (#1711)
### What this PR does / why we need it?

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: zzzzwwjj <1183291235@qq.com>
1 parent 3b99491 commit 33dbe57

File tree

6 files changed: +83 -4 lines changed

- vllm_ascend/distributed/parallel_state.py (new)
- vllm_ascend/ops/fused_moe.py
- vllm_ascend/patch/platform/patch_common/patch_distributed.py
- vllm_ascend/quantization/w8a8_dynamic.py
- vllm_ascend/worker/worker.py
- vllm_ascend/worker/worker_v1.py

vllm_ascend/distributed/parallel_state.py

Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+from typing import Optional
+
+import torch
+from vllm.distributed.parallel_state import (GroupCoordinator, get_world_group,
+                                             init_model_parallel_group)
+
+# Currently, mc2 op need their own group coordinator.
+_MC2: Optional[GroupCoordinator] = None
+
+
+def get_mc2_group() -> GroupCoordinator:
+    assert _MC2 is not None, ("mc2 group is not initialized")
+    return _MC2
+
+
+def model_parallel_initialized():
+    return (_MC2 is not None)
+
+
+def init_ascend_model_parallel(
+    expert_parallel_size: int = 1,
+    world_size: Optional[int] = None,
+    backend: Optional[str] = None,
+):
+    if model_parallel_initialized():
+        return
+    assert torch.distributed.is_initialized()
+    world_size = world_size or torch.distributed.get_world_size()
+    backend = backend or torch.distributed.get_backend(
+        get_world_group().device_group)
+    num_expert_parallel_groups = world_size // expert_parallel_size
+
+    global _MC2
+    group_ranks = []
+    for i in range(num_expert_parallel_groups):
+        ranks = list(range(i, world_size, num_expert_parallel_groups))
+        group_ranks.append(ranks)
+
+    _MC2 = init_model_parallel_group(group_ranks,
+                                     get_world_group().local_rank,
+                                     backend,
+                                     group_name="mc2")
+
+
+def destroy_ascend_model_parallel():
+    global _MC2
+    if _MC2:
+        _MC2.destroy()
+    _MC2 = None
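For orientation, here is a small standalone sketch of how `init_ascend_model_parallel` partitions ranks into mc2 groups; the concrete sizes (`world_size=8`, `expert_parallel_size=4`) are hypothetical and chosen only for illustration.

```python
# Standalone sketch of the group_ranks computation above; no torch or NPU
# needed, and the sizes are made up for illustration.
world_size = 8
expert_parallel_size = 4
num_expert_parallel_groups = world_size // expert_parallel_size  # -> 2 groups

group_ranks = [
    list(range(i, world_size, num_expert_parallel_groups))
    for i in range(num_expert_parallel_groups)
]
print(group_ranks)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
```

Each group therefore ends up with `expert_parallel_size` ranks, strided across the world size.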

vllm_ascend/ops/fused_moe.py

Lines changed: 3 additions & 2 deletions

@@ -39,6 +39,7 @@
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
+from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
                                get_ascend_soc_version, npu_stream_switch,
@@ -125,7 +126,7 @@ def fused_experts_with_mc2(
     mc2_mask: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     quant_mode = 0
-    ep_group = get_ep_group()
+    ep_group = get_mc2_group()
     ep_rank_id = ep_group.rank_in_group
     ep_world_size = ep_group.world_size
     tp_world_size = get_tp_group().world_size
@@ -878,7 +879,7 @@ def __init__(self, moe: MoEConfig = None):
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

         try:
-            device_group = get_ep_group().device_group
+            device_group = get_mc2_group().device_group
             # TODO: Try local_rank = ep_group.rank_in_group
             local_rank = torch.distributed.get_rank(group=device_group)
             backend = device_group._get_backend(torch.device("npu"))
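The practical effect of the switch is that the mc2 dispatch/combine path derives its rank layout from the dedicated group rather than the generic EP group. A minimal sketch of that consumption pattern follows; the helper name `mc2_parallel_layout` is made up for illustration, and it assumes the distributed groups have already been initialized inside a worker.

```python
from vllm_ascend.distributed.parallel_state import get_mc2_group


def mc2_parallel_layout():
    # Hypothetical helper: mirrors how fused_experts_with_mc2 now reads its
    # expert-parallel layout from the dedicated mc2 GroupCoordinator.
    mc2_group = get_mc2_group()
    return {
        "ep_rank_id": mc2_group.rank_in_group,
        "ep_world_size": mc2_group.world_size,
        "device_group": mc2_group.device_group,
    }
```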

vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 22 additions & 0 deletions

@@ -17,13 +17,34 @@
 # Adapted from vllm/model_executor/models/qwen2_vl.py
 # This file is a part of the vllm-ascend project.

+import vllm
 import vllm.envs as envs
 from torch.distributed import ProcessGroup
 from vllm.config import ParallelConfig
 from vllm.distributed.utils import \
     stateless_init_torch_distributed_process_group


+def ascend_destroy_model_parallel():
+    """Set the groups to none and destroy them."""
+    from vllm.distributed.parallel_state import _DP, _EP, _PP, _TP
+    if _TP:
+        _TP.destroy()
+    _TP = None
+    if _PP:
+        _PP.destroy()
+    _PP = None
+    if _DP:
+        _DP.destroy()
+    _DP = None
+    if _EP:
+        _EP.destroy()
+    _EP = None
+    from vllm_ascend.distributed.parallel_state import \
+        destroy_ascend_model_parallel
+    destroy_ascend_model_parallel()
+
+
 def parallel_config_get_dp_port(self) -> int:
     """
     We might need to initialize process groups in multiple
@@ -57,5 +78,6 @@ def stateless_init_dp_group(self) -> "ProcessGroup":
     return dp_group


+vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
 ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
 ParallelConfig.stateless_init_dp_group = stateless_init_dp_group
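The last hunk swaps vllm's module-level `destroy_model_parallel` for the Ascend-aware version when the patch module is imported. A toy sketch of that module-attribute patch pattern is below; the module and function bodies are stand-ins, not vllm's real internals.

```python
import types

# Stand-in for vllm.distributed.parallel_state.
parallel_state = types.ModuleType("parallel_state")


def default_destroy_model_parallel():
    print("default teardown: TP/PP/DP/EP groups")


parallel_state.destroy_model_parallel = default_destroy_model_parallel


def ascend_destroy_model_parallel():
    print("ascend teardown: TP/PP/DP/EP groups plus the mc2 group")


# One assignment at patch-import time; every later call that resolves the
# function through the module attribute now runs the Ascend-aware teardown.
parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
parallel_state.destroy_model_parallel()  # -> "ascend teardown: ..."
```

Note that this kind of patch only affects call sites that look the function up through the module attribute; code that bound the original function earlier via `from ... import destroy_model_parallel` keeps its old reference.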

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 3 additions & 2 deletions

@@ -26,6 +26,7 @@
 import vllm_ascend.envs as ascend_envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import FusedMoEState
+from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.fused_moe import select_experts
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version,
@@ -223,7 +224,7 @@ def fused_experts_with_mc2(
     if log2phy:
         topk_ids = log2phy[topk_ids]
     quant_mode = 2
-    ep_group = get_ep_group()
+    ep_group = get_mc2_group()
    ep_rank_id = ep_group.rank_in_group
    ep_world_size = ep_group.world_size

@@ -763,7 +764,7 @@ def __init__(self):
         self.enable_weight_nz_layout = ascend_config.enable_weight_nz_layout

         try:
-            device_group = self.ep_group.device_group
+            device_group = get_mc2_group().device_group
             # TODO: Try local_rank = ep_group.rank_in_group
             local_rank = torch.distributed.get_rank(group=device_group)
             backend = device_group._get_backend(torch.device("npu"))

vllm_ascend/worker/worker.py

Lines changed: 3 additions & 0 deletions

@@ -49,6 +49,7 @@

 from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.device_allocator.camem import CaMemAllocator
+from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import init_ascend_soc_version, try_register_lib
 from vllm_ascend.worker.model_runner import NPUModelRunner
@@ -545,6 +546,8 @@ def _init_worker_distributed_environment(
     ensure_model_parallel_initialized(
         parallel_config.tensor_parallel_size,
         parallel_config.pipeline_parallel_size)
+    init_ascend_model_parallel(parallel_config.expert_parallel_size,
+                               parallel_config.world_size_across_dp)
     ensure_kv_transfer_initialized(vllm_config)


vllm_ascend/worker/worker_v1.py

Lines changed: 3 additions & 0 deletions

@@ -40,6 +40,7 @@

 from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.device_allocator.camem import CaMemAllocator
+from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import init_ascend_soc_version, try_register_lib
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -260,6 +261,8 @@ def _init_worker_distributed_environment(self) -> None:
         ensure_model_parallel_initialized(
             self.parallel_config.tensor_parallel_size,
             self.parallel_config.pipeline_parallel_size)
+        init_ascend_model_parallel(self.parallel_config.expert_parallel_size,
+                                   self.parallel_config.world_size_across_dp)
         ensure_kv_transfer_initialized(self.vllm_config)

     def _init_profiler(self):
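Both worker implementations now build the mc2 group right after the standard model-parallel groups. A condensed sketch of that ordering follows; `init_parallel_groups` is a hypothetical wrapper, the real logic lives in `_init_worker_distributed_environment`, and the `parallel_config` fields are the ones used in the hunks above.

```python
from vllm.distributed.parallel_state import ensure_model_parallel_initialized

from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel


def init_parallel_groups(parallel_config):
    # 1) Standard vllm model-parallel groups (TP/PP) first.
    ensure_model_parallel_initialized(
        parallel_config.tensor_parallel_size,
        parallel_config.pipeline_parallel_size)
    # 2) Then the Ascend-specific mc2 group, sized by expert parallelism over
    #    the DP-wide world size. This must run before any mc2 op calls
    #    get_mc2_group(), which asserts the group is initialized.
    init_ascend_model_parallel(parallel_config.expert_parallel_size,
                               parallel_config.world_size_across_dp)
```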
