Skip to content

Commit a0af061

Browse files
committed
fix graph mode pd stuck
Signed-off-by: ningbenzhe1 <ningbenzhe@huawei.com>
1 parent e2a0c19 commit a0af061

File tree

7 files changed

+89
-15
lines changed

7 files changed

+89
-15
lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ class AscendMetadata:
130130
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
131131
attn_mask: Optional[torch.Tensor] = None
132132

133+
# For logging.
134+
num_input_tokens: int = 0 # Number of tokens including padding.
135+
133136

134137
class AscendAttentionMetadataBuilder:
135138

vllm_ascend/distributed/parallel_state.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,18 @@ def get_etp_group() -> GroupCoordinator:
2121
return _ETP
2222

2323

24+
def model_parallel_initialized():
25+
return (_ETP is not None and _EP is not None)
26+
27+
2428
def init_ascend_model_parallel(
2529
tensor_model_parallel_size: int = 1,
2630
pipeline_model_parallel_size: int = 1,
2731
expert_tensor_parallel_size: int = 1,
2832
backend: Optional[str] = None,
2933
):
34+
if model_parallel_initialized():
35+
return
3036
assert torch.distributed.is_initialized()
3137
world_size: int = torch.distributed.get_world_size()
3238
backend = backend or torch.distributed.get_backend(

vllm_ascend/ops/fused_moe.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ def fused_experts_with_mc2(
7474
local_rank = torch.distributed.get_rank(group=ep_group)
7575
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
7676

77-
world_szie = torch.distributed.get_world_size()
78-
tp_size = world_szie // all_to_all_group_size
77+
tp_size = get_etp_group().world_size
7978
tp_rank = rank % tp_size
8079

8180
stage1_kwargs = {

vllm_ascend/patch/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,14 @@
7777
# on multi-node dp inference implementation
7878
# 4. `ParallelConfig.stateless_init_dp_group`
7979
# Why:
80-
# vLLM use gloo backend by default to initialize stateless dp process gourp, but we want to use hccl here to
81-
# get better performance
80+
# vLLM use gloo backend by default to initialize stateless dp process group, but we want to use hccl here to
81+
# get better performance. Initialize the global variable of dp_group to prefill dummy_run.
8282
# How:
83-
# adopt nccl backend to init process group
83+
# adopt nccl backend to init process group and add the global variable of dp_group.
8484
# Related PR (if no, explain why): no related PR, we want add this ability into vllm
8585
# Future Plan:
8686
# Remove those patch when vllm merged them
87+
# Add the global variable of dp_group in platform when vllm merged them.
8788
#
8889
#
8990
# * Worker Patch:

vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,16 @@
2020
import torch
2121
import vllm
2222
import vllm.distributed
23+
import vllm.envs as envs
2324
from torch.distributed import ProcessGroup
2425
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
2526
_get_default_timeout,
2627
is_nccl_available)
2728
from torch.distributed.rendezvous import rendezvous
2829
from vllm.config import ParallelConfig
2930

31+
_DP_GROUP = None
32+
3033

3134
def ascend_destroy_model_parallel():
3235
"""Set the groups to none and destroy them."""
@@ -164,10 +167,9 @@ def parallel_config_get_dp_port(self) -> int:
164167
"""
165168
answer = self.data_parallel_master_port
166169
self.data_parallel_master_port += 1
167-
import os
168170

169171
# NOTE: Get port from envs directly when using torchrun
170-
port = int(os.environ.get("MASTER_PORT", answer)) # type: ignore
172+
port = envs.VLLM_DP_MASTER_PORT if envs.VLLM_DP_MASTER_PORT else answer
171173
return port
172174

173175

@@ -183,10 +185,16 @@ def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
183185
self.data_parallel_rank,
184186
self.data_parallel_size,
185187
backend="gloo")
188+
global _DP_GROUP
189+
_DP_GROUP = dp_group
186190

187191
return dp_group
188192

189193

194+
def get_dp_group():
195+
return _DP_GROUP
196+
197+
190198
vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
191199
ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
192200
ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group

vllm_ascend/worker/model_runner_v1.py

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import numpy.typing as npt
3030
import torch
3131
import torch.nn as nn
32+
from torch.distributed import ReduceOp
3233
from vllm.attention import AttentionType, get_attn_backend
3334
from vllm.attention.layer import Attention
3435
from vllm.config import CompilationLevel, VllmConfig
@@ -59,6 +60,8 @@
5960

6061
from vllm_ascend.attention.attention import AttentionMaskBuilder
6162
from vllm_ascend.attention.attention_v1 import AscendAttentionState
63+
from vllm_ascend.patch.platform.patch_common.patch_distributed import \
64+
get_dp_group
6265
from vllm_ascend.platform import NPUPlatform
6366
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
6467
from vllm_ascend.utils import vllm_version_is
@@ -328,6 +331,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
328331
False) and self.vllm_config.model_config.use_mla
329332
self.use_cached_npu_graph = additional_config.get(
330333
"use_cached_npu_graph", False)
334+
self.has_prefilled = False
335+
self.dp_group = get_dp_group()
331336

332337
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
333338
"""Update the cached states and the persistent batch with the scheduler
@@ -635,6 +640,22 @@ def _process_reqs(
635640
device=input_ids.device)
636641
input_ids = torch.cat([input_ids, padding])
637642
positions = torch.cat([positions, padding])
643+
if self.has_prefilled and not attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
644+
self.has_prefilled = False
645+
if not self.has_prefilled and self.enable_torchair_graph_mode:
646+
self.has_prefilled = self.has_prefilled_all_rank(
647+
attn_metadata.attn_state == AscendAttentionState.DecodeOnly)
648+
649+
if self.dp_group:
650+
while not self.has_prefilled and self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
651+
self._dummy_run(1)
652+
tensor = torch.tensor([1], dtype=torch.int32, device="cpu")
653+
torch.distributed.all_reduce(tensor,
654+
op=ReduceOp.MAX,
655+
group=self.dp_group)
656+
self.has_prefilled = self.has_prefilled_all_rank(
657+
attn_metadata.attn_state ==
658+
AscendAttentionState.DecodeOnly)
638659

639660
# Run forward pass
640661
with set_forward_context(attn_metadata,
@@ -644,7 +665,7 @@ def _process_reqs(
644665
if self.enable_torchair_graph_mode:
645666
model_kwargs["kv_caches"] = self.kv_caches
646667
model_kwargs["attn_metadata"] = attn_metadata
647-
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
668+
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly and self.has_prefilled:
648669
torch._dynamo.mark_static(input_ids)
649670
torch._dynamo.mark_static(positions)
650671
torch._dynamo.mark_static(attn_metadata.decode.block_table)
@@ -772,6 +793,15 @@ def _calc_spec_decode_metadata(
772793
)
773794
return metadata
774795

796+
def has_prefilled_all_rank(self, has_prefilled: bool) -> bool:
797+
tensor = torch.tensor([has_prefilled], dtype=torch.int32, device="cpu")
798+
if self.dp_group:
799+
torch.distributed.all_reduce(tensor,
800+
op=ReduceOp.MIN,
801+
group=self.dp_group)
802+
aggregated_has_prefilled = bool(tensor.item())
803+
return aggregated_has_prefilled
804+
775805
def apply_grammar_bitmask(
776806
self,
777807
scheduler_output: "SchedulerOutput",
@@ -1039,7 +1069,11 @@ def _profile_multimodal(self) -> None:
10391069
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
10401070

10411071
@torch.inference_mode()
1042-
def _dummy_run(self, num_tokens: int) -> torch.Tensor:
1072+
def _dummy_run(
1073+
self,
1074+
num_tokens: int,
1075+
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
1076+
) -> torch.Tensor:
10431077
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
10441078
# for dummy run with LoRA so that the num_reqs collectively
10451079
# has num_tokens in total.
@@ -1083,11 +1117,34 @@ def _dummy_run(self, num_tokens: int) -> torch.Tensor:
10831117
})
10841118

10851119
with set_forward_context(None, self.vllm_config):
1086-
hidden_states = model(
1087-
input_ids=input_ids,
1088-
positions=positions,
1089-
intermediate_tensors=intermediate_tensors,
1090-
inputs_embeds=inputs_embeds)
1120+
if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly:
1121+
attn_metadata = self.attn_metadata_builder.dummy_build(
1122+
num_reqs=num_tokens, num_actual_tokens=1)
1123+
torch._dynamo.mark_static(input_ids)
1124+
torch._dynamo.mark_static(positions)
1125+
torch._dynamo.mark_static(attn_metadata.decode.block_table)
1126+
torch._dynamo.mark_static(
1127+
attn_metadata.decode.input_positions)
1128+
torch._dynamo.mark_static(attn_metadata.slot_mapping)
1129+
for kv in self.kv_caches:
1130+
assert isinstance(kv,
1131+
tuple), "kv_cache must be a tuple"
1132+
torch._dynamo.mark_static(kv[0])
1133+
torch._dynamo.mark_static(kv[1])
1134+
hidden_states = self.compile_model(
1135+
input_ids=input_ids,
1136+
positions=positions,
1137+
intermediate_tensors=intermediate_tensors,
1138+
inputs_embeds=None,
1139+
kv_caches=self.kv_caches,
1140+
attn_metadata=attn_metadata,
1141+
)
1142+
else:
1143+
hidden_states = model(
1144+
input_ids=input_ids,
1145+
positions=positions,
1146+
intermediate_tensors=intermediate_tensors,
1147+
inputs_embeds=inputs_embeds)
10911148
return hidden_states
10921149

10931150
def profile_run(self) -> None:

vllm_ascend/worker/worker_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def execute_model(
173173
scheduler_output: "SchedulerOutput",
174174
) -> Optional[ModelRunnerOutput]:
175175
output = self.model_runner.execute_model(scheduler_output)
176-
return output if self.rank == 0 else None
176+
return output if self.is_driver_worker else None
177177

178178
def load_model(self) -> None:
179179
self.model_runner.load_model()

0 commit comments

Comments
 (0)