
Commit 21a7023

fix graph mode pd stuck

Signed-off-by: ningbenzhe1 <ningbenzhe@huawei.com>
1 parent a93bed4

7 files changed: +68 -8 lines

vllm_ascend/attention/attention_v1.py

Lines changed: 3 additions & 0 deletions

@@ -129,6 +129,9 @@ class AscendMetadata:
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
     attn_mask: Optional[torch.Tensor] = None

+    # For logging.
+    num_input_tokens: int = 0  # Number of tokens including padding.
+

 class AscendAttentionMetadataBuilder:

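
The new num_input_tokens field only carries the padded batch length for logging; it does not change the attention computation. Below is a minimal sketch of how a runner might fill it in; the pad_and_record helper and the padding multiple are illustrative assumptions, not part of this commit.

import torch
from dataclasses import dataclass


@dataclass
class MetadataSketch:
    # Mirrors the new AscendMetadata field: token count including padding, for logging.
    num_input_tokens: int = 0


def pad_and_record(input_ids: torch.Tensor, meta: MetadataSketch,
                   multiple: int = 16) -> torch.Tensor:
    """Pad a flattened token batch to a multiple and record the padded length."""
    pad = (-input_ids.numel()) % multiple
    if pad:
        input_ids = torch.cat([input_ids, input_ids.new_zeros(pad)])
    meta.num_input_tokens = input_ids.numel()
    return input_ids
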
vllm_ascend/distributed/parallel_state.py

Lines changed: 6 additions & 0 deletions

@@ -21,12 +21,18 @@ def get_etp_group() -> GroupCoordinator:
     return _ETP


+def model_parallel_initialized():
+    return (_ETP is not None and _EP is not None)
+
+
 def init_ascend_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
     expert_tensor_parallel_size: int = 1,
     backend: Optional[str] = None,
 ):
+    if model_parallel_initialized():
+        return
     assert torch.distributed.is_initialized()
     world_size: int = torch.distributed.get_world_size()
     backend = backend or torch.distributed.get_backend(
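
The early return makes init_ascend_model_parallel idempotent: a repeated call (for example from a dummy-run path) returns immediately instead of trying to re-create the EP/ETP groups. A minimal sketch of the same guard pattern, with the real GroupCoordinator objects replaced by placeholders:

from typing import Optional

_EP: Optional[object] = None   # stand-in for the expert-parallel group
_ETP: Optional[object] = None  # stand-in for the expert-tensor-parallel group


def model_parallel_initialized() -> bool:
    # Initialization counts as done only when both groups exist.
    return _EP is not None and _ETP is not None


def init_model_parallel_sketch() -> None:
    global _EP, _ETP
    if model_parallel_initialized():
        return  # repeated calls are harmless no-ops
    _EP, _ETP = object(), object()  # the real code builds process groups here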

vllm_ascend/ops/fused_moe.py

Lines changed: 1 addition & 2 deletions

@@ -66,8 +66,7 @@ def fused_experts_with_mc2(
     local_rank = torch.distributed.get_rank(group=ep_group)
     all_to_all_group_size = torch.distributed.get_world_size(ep_group)

-    world_szie = torch.distributed.get_world_size()
-    tp_size = world_szie // all_to_all_group_size
+    tp_size = get_etp_group().world_size
     tp_rank = rank % tp_size

     stage1_kwargs = {
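
The replacement reads the tensor-parallel size directly from the expert-tensor-parallel group instead of deriving it from the global world size divided by the all-to-all group size, which can disagree with the actual ETP layout when the world contains ranks outside those groups. A rough illustration with hypothetical group sizes (the numbers and layout are assumptions, not taken from the commit):

# Hypothetical layout: 2 DP replicas, each with 4 EP ranks and 2 ETP ranks -> 16 processes.
world_size = 16
all_to_all_group_size = 4   # size of the EP communication group
etp_group_world_size = 2    # what get_etp_group().world_size would report

old_tp_size = world_size // all_to_all_group_size  # 4: inflated by ranks outside the EP/ETP groups
new_tp_size = etp_group_world_size                 # 2: matches the actual ETP layout
assert (old_tp_size, new_tp_size) == (4, 2)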

vllm_ascend/patch/__init__.py

Lines changed: 4 additions & 3 deletions

@@ -70,13 +70,14 @@
 #    on multi-node dp inference implementation
 # 4. `ParallelConfig.stateless_init_dp_group`
 #    Why:
-#    vLLM use gloo backend by default to initialize stateless dp process gourp, but we want to use hccl here to
-#    get better performance
+#    vLLM use gloo backend by default to initialize stateless dp process group, but we want to use hccl here to
+#    get better performance. Initialize the global variable of dp_group to prefill dummy_run.
 #    How:
-#    adopt nccl backend to init process group
+#    adopt nccl backend to init process group and add the global variable of dp_group.
 #    Related PR (if no, explain why): no related PR, we want add this ability into vllm
 #    Future Plan:
 #    Remove those patch when vllm merged them
+#    Add the global variable of dp_group in platform when vllm merged them.
 #
 #
 # * Worker Patch:

vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 10 additions & 2 deletions

@@ -20,13 +20,16 @@
 import torch
 import vllm
 import vllm.distributed
+import vllm.envs as envs
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                 _get_default_timeout,
                                                 is_nccl_available)
 from torch.distributed.rendezvous import rendezvous
 from vllm.config import ParallelConfig

+_DP_GROUP = None
+

 def ascend_destroy_model_parallel():
     """Set the groups to none and destroy them."""
@@ -164,10 +167,9 @@ def parallel_config_get_dp_port(self) -> int:
     """
     answer = self.data_parallel_master_port
     self.data_parallel_master_port += 1
-    import os

     # NOTE: Get port from envs directly when using torchrun
-    port = int(os.environ.get("MASTER_PORT", answer))  # type: ignore
+    port = envs.VLLM_DP_MASTER_PORT if envs.VLLM_DP_MASTER_PORT else answer
     return port


@@ -183,10 +185,16 @@ def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
         self.data_parallel_rank,
         self.data_parallel_size,
         backend="gloo")
+    global _DP_GROUP
+    _DP_GROUP = dp_group

     return dp_group


+def get_dp_group():
+    return _DP_GROUP
+
+
 vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
 ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
 ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group
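
The patch stashes the DP process group in the module-level _DP_GROUP when stateless_init_dp_group runs, so code outside this module (here the model runner) can fetch it later through get_dp_group(). A stripped-down sketch of that holder pattern, with the real ProcessGroup creation replaced by a placeholder:

from typing import Optional

_DP_GROUP: Optional[object] = None  # set once, when the DP group is created


def init_dp_group_sketch() -> object:
    """Stand-in for the patched stateless_init_dp_group: create, stash, return."""
    global _DP_GROUP
    dp_group = object()  # the real code builds a torch.distributed ProcessGroup here
    _DP_GROUP = dp_group
    return dp_group


def get_dp_group() -> Optional[object]:
    # Returns None until the DP group has been initialized.
    return _DP_GROUP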

vllm_ascend/worker/model_runner_v1.py

Lines changed: 43 additions & 0 deletions

@@ -29,6 +29,7 @@
 import numpy.typing as npt
 import torch
 import torch.nn as nn
+from torch.distributed import ReduceOp
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
@@ -59,6 +60,8 @@

 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.patch.platform.patch_common.patch_distributed import \
+    get_dp_group
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler

@@ -318,6 +321,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             False) and self.vllm_config.model_config.use_mla
         self.use_cached_npu_graph = additional_config.get(
             "use_cached_npu_graph", False)
+        self.has_prefilled = False
+        self.dp_group = get_dp_group()

     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -624,6 +629,9 @@ def _process_reqs(
             input_ids = torch.cat([input_ids, padding])
             positions = torch.cat([positions, padding])

+        if self.enable_torchair_graph_mode:
+            self.sync_prefill_when_enable_graph(attn_metadata)
+
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
@@ -685,6 +693,41 @@ def _process_reqs(
         return (attn_metadata, hidden_states, spec_decode_metadata, positions,
                 total_num_scheduled_tokens, sample_indices)

+    def sync_prefill_when_enable_graph(self, attn_metadata):
+        """
+        NOTE: This method serves as a temporary solution to the deadlock issue under the p and d in graph mode.
+        It will be removed along with its related calls once the official solution is implemented.
+        """
+
+        def has_prefilled_all_rank(has_prefilled: bool) -> bool:
+            status = torch.tensor([has_prefilled],
+                                  dtype=torch.int32,
+                                  device="cpu")
+            if self.dp_group:
+                torch.distributed.all_reduce(status,
+                                             op=ReduceOp.MIN,
+                                             group=self.dp_group)
+            aggregated_has_prefilled = bool(status.item())
+            return aggregated_has_prefilled
+
+        if self.has_prefilled and not attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            self.has_prefilled = False
+
+        if not self.has_prefilled:
+            self.has_prefilled = has_prefilled_all_rank(
+                attn_metadata.attn_state == AscendAttentionState.DecodeOnly)
+
+        if self.dp_group:
+            while not self.has_prefilled and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                self._dummy_run(1)
+                tensor = torch.tensor([1], dtype=torch.int32, device="cpu")
+                torch.distributed.all_reduce(tensor,
+                                             op=ReduceOp.MAX,
+                                             group=self.dp_group)
+                self.has_prefilled = has_prefilled_all_rank(
+                    attn_metadata.attn_state ==
+                    AscendAttentionState.DecodeOnly)
+
     def _calc_spec_decode_metadata(
         self,
         num_draft_tokens: np.ndarray,
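
The core of the fix is has_prefilled_all_rank: a MIN all-reduce over the DP group yields True only when every rank reports the decode-only state, and a rank that is already decode-only keeps issuing dummy runs until that point, so graph-mode collectives never block on a rank that is still prefilling. A self-contained sketch of the agreement step, runnable as a single-process gloo group (the process-group setup here is an assumption for illustration only):

import os
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp


def all_ranks_decode_only(local_decode_only: bool, group=None) -> bool:
    """MIN all-reduce: True only if every rank in the group is decode-only."""
    status = torch.tensor([int(local_decode_only)], dtype=torch.int32, device="cpu")
    if dist.is_initialized():
        dist.all_reduce(status, op=ReduceOp.MIN, group=group)
    return bool(status.item())


if __name__ == "__main__":
    # Single-process gloo group, just to exercise the collective.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29507")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(all_ranks_decode_only(True))   # True: the only rank is decode-only
    print(all_ranks_decode_only(False))  # False: this rank is still prefilling
    dist.destroy_process_group()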

vllm_ascend/worker/worker_v1.py

Lines changed: 1 addition & 1 deletion

@@ -173,7 +173,7 @@ def execute_model(
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
         output = self.model_runner.execute_model(scheduler_output)
-        return output if self.rank == 0 else None
+        return output if self.is_driver_worker else None

     def load_model(self) -> None:
         self.model_runner.load_model()
