
Commit 9df5c0f

fix graph prefill
Signed-off-by: ningbenzhe1 <ningbenzhe@huawei.com>
1 parent 00e0243 commit 9df5c0f

File tree

5 files changed, +82 -9 lines changed


vllm_ascend/attention/attention_v1.py

Lines changed: 3 additions & 0 deletions

@@ -129,6 +129,9 @@ class AscendMetadata:
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
     attn_mask: Optional[torch.Tensor] = None
 
+    # For logging.
+    num_input_tokens: int = 0  # Number of tokens including padding.
+
 
 class AscendAttentionMetadataBuilder:
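The new num_input_tokens field records the padded token count next to the actual token count, so logging can distinguish real work from graph-mode padding. A hedged sketch of a builder filling it; MetadataSketch and build_metadata are illustrative names, only num_input_tokens and num_actual_tokens come from the diff:

from dataclasses import dataclass


@dataclass
class MetadataSketch:
    # Tokens actually scheduled vs. tokens after padding the batch up to a
    # captured graph size (the quantity the new field is meant to log).
    num_actual_tokens: int = 0
    num_input_tokens: int = 0  # Number of tokens including padding.


def build_metadata(num_actual_tokens: int, padded_size: int) -> MetadataSketch:
    # Illustrative only: the padded size is never smaller than the real count.
    return MetadataSketch(num_actual_tokens=num_actual_tokens,
                          num_input_tokens=max(num_actual_tokens, padded_size))


print(build_metadata(num_actual_tokens=7, padded_size=16))
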
vllm_ascend/distributed/parallel_state.py

Lines changed: 6 additions & 0 deletions

@@ -21,12 +21,18 @@ def get_etp_group() -> GroupCoordinator:
     return _ETP
 
 
+def model_parallel_initialized():
+    return (_ETP is not None and _EP is not None)
+
+
 def init_ascend_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
     expert_tensor_parallel_size: int = 1,
     backend: Optional[str] = None,
 ):
+    if model_parallel_initialized():
+        return
     assert torch.distributed.is_initialized()
     world_size: int = torch.distributed.get_world_size()
     backend = backend or torch.distributed.get_backend(
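
The early return through model_parallel_initialized() makes init_ascend_model_parallel idempotent: if initialization is attempted twice, the second call is a no-op instead of rebuilding the groups. A minimal, self-contained sketch of the same pattern, with illustrative names rather than the real vllm_ascend globals:

from typing import Optional

_GROUP: Optional[object] = None


def group_initialized() -> bool:
    return _GROUP is not None


def init_group() -> None:
    global _GROUP
    if group_initialized():
        # Second call is a no-op instead of asserting or re-creating groups.
        return
    _GROUP = object()  # placeholder for the real GroupCoordinator


init_group()
init_group()  # safe: returns immediately because the group already exists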

vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 10 additions & 2 deletions

@@ -20,13 +20,16 @@
 import torch
 import vllm
 import vllm.distributed
+import vllm.envs as envs
 from torch.distributed import ProcessGroup
 from torch.distributed.distributed_c10d import (Backend, PrefixStore,
                                                 _get_default_timeout,
                                                 is_nccl_available)
 from torch.distributed.rendezvous import rendezvous
 from vllm.config import ParallelConfig
 
+_DP_GROUP = None
+
 
 def ascend_destroy_model_parallel():
     """Set the groups to none and destroy them."""

@@ -164,10 +167,9 @@ def parallel_config_get_dp_port(self) -> int:
     """
     answer = self.data_parallel_master_port
     self.data_parallel_master_port += 1
-    import os
 
     # NOTE: Get port from envs directly when using torchrun
-    port = int(os.environ.get("MASTER_PORT", answer))  # type: ignore
+    port = envs.VLLM_DP_MASTER_PORT if envs.VLLM_DP_MASTER_PORT else answer
     return port

@@ -183,10 +185,16 @@ def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
         self.data_parallel_rank,
         self.data_parallel_size,
         backend="gloo")
+    global _DP_GROUP
+    _DP_GROUP = dp_group
 
     return dp_group
 
 
+def get_dp_group():
+    return _DP_GROUP
+
+
 vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
 ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
 ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group
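
The patch stores the data-parallel group in the module-level _DP_GROUP handle when ascend_stateless_init_dp_group builds it, and exposes it through get_dp_group() so the model runner can look it up later. A small sketch of this store-on-init, fetch-via-accessor pattern; create_group and get_cached_group are stand-in names, not the real API:

from typing import Optional

_CACHED_GROUP: Optional[object] = None


def create_group() -> object:
    # Stand-in for building the real torch.distributed process group.
    global _CACHED_GROUP
    group = object()
    _CACHED_GROUP = group  # remember the handle at creation time
    return group


def get_cached_group() -> Optional[object]:
    # Consumers can fetch the group later without threading it through every
    # constructor; None means the group was never initialized.
    return _CACHED_GROUP


create_group()
print(get_cached_group() is not None)  # True once the group exists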

vllm_ascend/worker/model_runner_v1.py

Lines changed: 62 additions & 6 deletions

@@ -29,6 +29,8 @@
 import numpy.typing as npt
 import torch
 import torch.nn as nn
+from torch.distributed import ReduceOp
+
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig

@@ -53,6 +55,7 @@
 
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.patch.platform.patch_common.patch_distributed import get_dp_group
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import vllm_version_is
 
@@ -296,6 +299,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             False) and self.vllm_config.model_config.use_mla
         self.use_cached_npu_graph = additional_config.get(
             "use_cached_npu_graph", False)
+        self.has_prefilled = False
+        self.dp_group = get_dp_group()
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -595,6 +600,22 @@ def _process_reqs(
                                   device=input_ids.device)
             input_ids = torch.cat([input_ids, padding])
             positions = torch.cat([positions, padding])
+        if self.has_prefilled and not attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            self.has_prefilled = False
+        if not self.has_prefilled and self.enable_torchair_graph_mode:
+            self.has_prefilled = self.has_prefilled_all_rank(
+                attn_metadata.attn_state == AscendAttentionState.DecodeOnly)
+
+        if self.dp_group:
+            while not self.has_prefilled and self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                self._dummy_run(1)
+                tensor = torch.tensor([1], dtype=torch.int32, device="cpu")
+                torch.distributed.all_reduce(tensor,
+                                             op=ReduceOp.MAX,
+                                             group=self.dp_group)
+                self.has_prefilled = self.has_prefilled_all_rank(
+                    attn_metadata.attn_state ==
+                    AscendAttentionState.DecodeOnly)
 
         # Run forward pass
         with set_forward_context(attn_metadata,
@@ -604,7 +625,7 @@ def _process_reqs(
             if self.enable_torchair_graph_mode:
                 model_kwargs["kv_caches"] = self.kv_caches
                 model_kwargs["attn_metadata"] = attn_metadata
-            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly and self.has_prefilled:
                 torch._dynamo.mark_static(input_ids)
                 torch._dynamo.mark_static(positions)
                 torch._dynamo.mark_static(attn_metadata.decode.block_table)
@@ -633,6 +654,15 @@ def _process_reqs(
 
         return hidden_states[sample_indices]
 
+    def has_prefilled_all_rank(self, has_prefilled: bool) -> bool:
+        tensor = torch.tensor([has_prefilled], dtype=torch.int32, device="cpu")
+        if self.dp_group:
+            torch.distributed.all_reduce(tensor,
+                                         op=ReduceOp.MIN,
+                                         group=self.dp_group)
+        aggregated_has_prefilled = bool(tensor.item())
+        return aggregated_has_prefilled
+
     def apply_grammar_bitmask(
         self,
         scheduler_output: "SchedulerOutput",
@@ -832,7 +862,11 @@ def _profile_multimodal(self) -> None:
         self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
     @torch.inference_mode()
-    def _dummy_run(self, num_tokens: int) -> torch.Tensor:
+    def _dummy_run(
+        self,
+        num_tokens: int,
+        attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
+    ) -> torch.Tensor:
         model = self.model
         if self.is_multimodal_model:
             input_ids = None
@@ -861,10 +895,32 @@ def _dummy_run(self, num_tokens: int) -> torch.Tensor:
         })
 
         with set_forward_context(None, self.vllm_config):
-            hidden_states = model(input_ids=input_ids,
-                                  positions=positions,
-                                  intermediate_tensors=intermediate_tensors,
-                                  inputs_embeds=inputs_embeds)
+            if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly:
+                attn_metadata = self.attn_metadata_builder.dummy_build(
+                    num_reqs=num_tokens, num_actual_tokens=1)
+                torch._dynamo.mark_static(input_ids)
+                torch._dynamo.mark_static(positions)
+                torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
+                torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                for kv in self.kv_caches:
+                    assert isinstance(kv, tuple), "kv_cache must be a tuple"
+                    torch._dynamo.mark_static(kv[0])
+                    torch._dynamo.mark_static(kv[1])
+                hidden_states = self.compile_model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=None,
+                    kv_caches=self.kv_caches,
+                    attn_metadata=attn_metadata,
+                )
+            else:
+                hidden_states = model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds)
         return hidden_states
 
     def profile_run(self) -> None:
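
The core of the fix is the has_prefilled_all_rank handshake: each data-parallel rank contributes a 0/1 flag and a ReduceOp.MIN all-reduce only returns 1 once every rank has finished its prefill, which is when the torchair decode-only graph path is allowed to run. A standalone sketch of the MIN-reduction check, assuming a single-process gloo group purely for illustration (the address and port values are placeholders):

import os

import torch
import torch.distributed as dist
from torch.distributed import ReduceOp

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29555")
dist.init_process_group(backend="gloo", rank=0, world_size=1)


def has_prefilled_all_rank(local_has_prefilled: bool, group) -> bool:
    # MIN over the group: the result is 1 only if *every* rank reports 1.
    flag = torch.tensor([int(local_has_prefilled)], dtype=torch.int32)
    dist.all_reduce(flag, op=ReduceOp.MIN, group=group)
    return bool(flag.item())


group = dist.group.WORLD
print(has_prefilled_all_rank(True, group))  # True only when all ranks agree
dist.destroy_process_group()

In the real runner, decode-only ranks that have not yet seen the group prefill keep looping over _dummy_run(1) plus an extra ReduceOp.MAX all-reduce, so their collective calls stay in lockstep with ranks that are still prefilling.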

vllm_ascend/worker/worker_v1.py

Lines changed: 1 addition & 1 deletion

@@ -172,7 +172,7 @@ def execute_model(
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
         output = self.model_runner.execute_model(scheduler_output)
-        return output if self.rank == 0 else None
+        return output if self.is_driver_worker else None
 
     def load_model(self) -> None:
         self.model_runner.load_model()
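
The one-line change returns output from the worker flagged as the driver of its engine rather than from global rank 0, so each engine reports its own result even when data parallelism means rank 0 is not the driver everywhere. A toy sketch of the convention; WorkerSketch is illustrative, not the vLLM worker API:

class WorkerSketch:
    def __init__(self, rank: int, is_driver_worker: bool):
        self.rank = rank
        # With multiple parallel groups, global rank 0 is not necessarily the
        # driver of this engine, hence the is_driver_worker check.
        self.is_driver_worker = is_driver_worker

    def execute_model(self, scheduler_output):
        output = f"hidden states for {scheduler_output}"
        return output if self.is_driver_worker else None


print(WorkerSketch(rank=1, is_driver_worker=True).execute_model("batch-0"))
print(WorkerSketch(rank=0, is_driver_worker=False).execute_model("batch-0"))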
