Skip to content

Commit 293eefe

Browse files
committed
fix graph mode pd stuck
Signed-off-by: ningbenzhe1 <ningbenzhe@huawei.com>
1 parent 1f9fb86 commit 293eefe

File tree

6 files changed

+83
-12
lines changed

6 files changed

+83
-12
lines changed

vllm_ascend/attention/attention_v1.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ class AscendMetadata:
129129
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
130130
attn_mask: Optional[torch.Tensor] = None
131131

132+
# For logging.
133+
num_input_tokens: int = 0 # Number of tokens including padding.
134+
132135

133136
class AscendAttentionMetadataBuilder:
134137

vllm_ascend/distributed/parallel_state.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,18 @@ def get_etp_group() -> GroupCoordinator:
2121
return _ETP
2222

2323

24+
def model_parallel_initialized():
25+
return (_ETP is not None and _EP is not None)
26+
27+
2428
def init_ascend_model_parallel(
2529
tensor_model_parallel_size: int = 1,
2630
pipeline_model_parallel_size: int = 1,
2731
expert_tensor_parallel_size: int = 1,
2832
backend: Optional[str] = None,
2933
):
34+
if model_parallel_initialized():
35+
return
3036
assert torch.distributed.is_initialized()
3137
world_size: int = torch.distributed.get_world_size()
3238
backend = backend or torch.distributed.get_backend(

vllm_ascend/ops/fused_moe.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ def fused_experts_with_mc2(
7474
local_rank = torch.distributed.get_rank(group=ep_group)
7575
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
7676

77-
world_szie = torch.distributed.get_world_size()
78-
tp_size = world_szie // all_to_all_group_size
77+
tp_size = get_etp_group().world_size
7978
tp_rank = rank % tp_size
8079

8180
stage1_kwargs = {

vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,16 @@
2020
import torch
2121
import vllm
2222
import vllm.distributed
23+
import vllm.envs as envs
2324
from torch.distributed import ProcessGroup
2425
from torch.distributed.distributed_c10d import (Backend, PrefixStore,
2526
_get_default_timeout,
2627
is_nccl_available)
2728
from torch.distributed.rendezvous import rendezvous
2829
from vllm.config import ParallelConfig
2930

31+
_DP_GROUP = None
32+
3033

3134
def ascend_destroy_model_parallel():
3235
"""Set the groups to none and destroy them."""
@@ -164,10 +167,9 @@ def parallel_config_get_dp_port(self) -> int:
164167
"""
165168
answer = self.data_parallel_master_port
166169
self.data_parallel_master_port += 1
167-
import os
168170

169171
# NOTE: Get port from envs directly when using torchrun
170-
port = int(os.environ.get("MASTER_PORT", answer)) # type: ignore
172+
port = envs.VLLM_DP_MASTER_PORT if envs.VLLM_DP_MASTER_PORT else answer
171173
return port
172174

173175

@@ -183,10 +185,16 @@ def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
183185
self.data_parallel_rank,
184186
self.data_parallel_size,
185187
backend="gloo")
188+
global _DP_GROUP
189+
_DP_GROUP = dp_group
186190

187191
return dp_group
188192

189193

194+
def get_dp_group():
195+
return _DP_GROUP
196+
197+
190198
vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
191199
ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
192200
ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group

vllm_ascend/worker/model_runner_v1.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import numpy.typing as npt
3030
import torch
3131
import torch.nn as nn
32+
from torch.distributed import ReduceOp
3233
from vllm.attention import AttentionType, get_attn_backend
3334
from vllm.attention.layer import Attention
3435
from vllm.config import CompilationLevel, VllmConfig
@@ -59,6 +60,8 @@
5960

6061
from vllm_ascend.attention.attention import AttentionMaskBuilder
6162
from vllm_ascend.attention.attention_v1 import AscendAttentionState
63+
from vllm_ascend.patch.platform.patch_common.patch_distributed import \
64+
get_dp_group
6265
from vllm_ascend.platform import NPUPlatform
6366
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
6467
from vllm_ascend.utils import vllm_version_is
@@ -355,6 +358,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
355358
False) and self.vllm_config.model_config.use_mla
356359
self.use_cached_npu_graph = additional_config.get(
357360
"use_cached_npu_graph", False)
361+
self.has_prefilled = False
362+
self.dp_group = get_dp_group()
358363

359364
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
360365
"""Update the cached states and the persistent batch with the scheduler
@@ -659,6 +664,22 @@ def _process_reqs(
659664
device=input_ids.device)
660665
input_ids = torch.cat([input_ids, padding])
661666
positions = torch.cat([positions, padding])
667+
if self.has_prefilled and not attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
668+
self.has_prefilled = False
669+
if not self.has_prefilled and self.enable_torchair_graph_mode:
670+
self.has_prefilled = self.has_prefilled_all_rank(
671+
attn_metadata.attn_state == AscendAttentionState.DecodeOnly)
672+
673+
if self.dp_group:
674+
while not self.has_prefilled and self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
675+
self._dummy_run(1)
676+
tensor = torch.tensor([1], dtype=torch.int32, device="cpu")
677+
torch.distributed.all_reduce(tensor,
678+
op=ReduceOp.MAX,
679+
group=self.dp_group)
680+
self.has_prefilled = self.has_prefilled_all_rank(
681+
attn_metadata.attn_state ==
682+
AscendAttentionState.DecodeOnly)
662683

663684
# Run forward pass
664685
with set_forward_context(attn_metadata,
@@ -668,7 +689,7 @@ def _process_reqs(
668689
if self.enable_torchair_graph_mode:
669690
model_kwargs["kv_caches"] = self.kv_caches
670691
model_kwargs["attn_metadata"] = attn_metadata
671-
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
692+
if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly and self.has_prefilled:
672693
torch._dynamo.mark_static(input_ids)
673694
torch._dynamo.mark_static(positions)
674695
torch._dynamo.mark_static(attn_metadata.decode.block_table)
@@ -796,6 +817,15 @@ def _calc_spec_decode_metadata(
796817
)
797818
return metadata
798819

820+
def has_prefilled_all_rank(self, has_prefilled: bool) -> bool:
821+
tensor = torch.tensor([has_prefilled], dtype=torch.int32, device="cpu")
822+
if self.dp_group:
823+
torch.distributed.all_reduce(tensor,
824+
op=ReduceOp.MIN,
825+
group=self.dp_group)
826+
aggregated_has_prefilled = bool(tensor.item())
827+
return aggregated_has_prefilled
828+
799829
def apply_grammar_bitmask(
800830
self,
801831
scheduler_output: "SchedulerOutput",
@@ -1063,7 +1093,11 @@ def _profile_multimodal(self) -> None:
10631093
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
10641094

10651095
@torch.inference_mode()
1066-
def _dummy_run(self, num_tokens: int) -> torch.Tensor:
1096+
def _dummy_run(
1097+
self,
1098+
num_tokens: int,
1099+
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
1100+
) -> torch.Tensor:
10671101
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
10681102
# for dummy run with LoRA so that the num_reqs collectively
10691103
# has num_tokens in total.
@@ -1107,11 +1141,32 @@ def _dummy_run(self, num_tokens: int) -> torch.Tensor:
11071141
})
11081142

11091143
with set_forward_context(None, self.vllm_config):
1110-
hidden_states = model(
1111-
input_ids=input_ids,
1112-
positions=positions,
1113-
intermediate_tensors=intermediate_tensors,
1114-
inputs_embeds=inputs_embeds)
1144+
if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly:
1145+
attn_metadata = self.attn_metadata_builder.dummy_build(
1146+
num_reqs=num_tokens, num_actual_tokens=1)
1147+
torch._dynamo.mark_static(input_ids)
1148+
torch._dynamo.mark_static(positions)
1149+
torch._dynamo.mark_static(attn_metadata.decode.block_table)
1150+
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
1151+
torch._dynamo.mark_static(attn_metadata.slot_mapping)
1152+
for kv in self.kv_caches:
1153+
assert isinstance(kv, tuple), "kv_cache must be a tuple"
1154+
torch._dynamo.mark_static(kv[0])
1155+
torch._dynamo.mark_static(kv[1])
1156+
hidden_states = self.compile_model(
1157+
input_ids=input_ids,
1158+
positions=positions,
1159+
intermediate_tensors=intermediate_tensors,
1160+
inputs_embeds=None,
1161+
kv_caches=self.kv_caches,
1162+
attn_metadata=attn_metadata,
1163+
)
1164+
else:
1165+
hidden_states = model(
1166+
input_ids=input_ids,
1167+
positions=positions,
1168+
intermediate_tensors=intermediate_tensors,
1169+
inputs_embeds=inputs_embeds)
11151170
return hidden_states
11161171

11171172
def profile_run(self) -> None:

vllm_ascend/worker/worker_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def execute_model(
173173
scheduler_output: "SchedulerOutput",
174174
) -> Optional[ModelRunnerOutput]:
175175
output = self.model_runner.execute_model(scheduler_output)
176-
return output if self.rank == 0 else None
176+
return output if self.is_driver_worker else None
177177

178178
def load_model(self) -> None:
179179
self.model_runner.load_model()

0 commit comments

Comments
 (0)