Commit 62c12ca

feat: optimize forward metadata collection across dp ranks
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent 0e43813

File tree

2 files changed: +40 -21 lines changed

vllm_ascend/worker/model_runner_v1.py
vllm_ascend/worker/worker_v1.py

vllm_ascend/worker/model_runner_v1.py

Lines changed: 21 additions & 15 deletions
@@ -614,17 +614,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         if batch_changed:
             self.input_batch.refresh_sampling_metadata()
 
-    def _get_forward_metadata_across_dp(
-            self, total_num_scheduled_tokens: int,
-            with_prefill: bool) -> tuple[int, bool]:
-        forward_metadata = torch.tensor(
-            [total_num_scheduled_tokens, with_prefill],
-            device="cpu",
-            dtype=torch.int32)
-        dist.all_reduce(forward_metadata,
-                        op=ReduceOp.MAX,
-                        group=get_dp_group().cpu_group)
-        return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
+    def _get_forward_metadata_across_dp(self, num_tokens: int,
+                                        with_prefill: bool) -> tuple[int, bool]:
+        local_forward_metadata = torch.tensor([num_tokens, with_prefill],
+                                              device="npu", dtype=torch.int32)
+        global_forward_metadata = get_dp_group().all_gather(
+            local_forward_metadata)
+        num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
+        with_prefill = bool(global_forward_metadata[:, 1].any())
+        return num_tokens_across_dp, with_prefill
 
     def get_eagle_atten_dict(
             self,
@@ -1093,9 +1091,12 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
 
+        num_tokens_across_dp = None
         if self.dp_size > 1:
-            max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
-                total_num_scheduled_tokens, with_prefill)
+            num_tokens_across_dp, with_prefill = \
+                self._get_forward_metadata_across_dp(num_input_tokens,
+                                                     with_prefill)
+            max_num_tokens = int(num_tokens_across_dp.max().item())
             extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
             extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
 
@@ -1104,6 +1105,8 @@ def _process_reqs(
             if self.dp_size > 1:
                 padded_batch_size = self.select_torchair_padded_batch_size(
                     max_num_tokens)
+                num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                                  padded_batch_size)
             else:
                 padded_batch_size = self.select_torchair_padded_batch_size(
                     total_num_scheduled_tokens)
@@ -1182,7 +1185,8 @@ def _process_reqs(
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
-                                 num_tokens=num_input_tokens):
+                                 num_tokens=num_input_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             with ProfileExecuteDuration().capture_async("forward"):
                 model_kwargs = {}
                 if self.torchair_graph_enabled:
@@ -1775,6 +1779,7 @@ def _dummy_run(
         is_compile: bool = False,
         with_prefill: bool = True,
         skip_attn: bool = True,
+        num_tokens_across_dp: Optional[int] = None,
     ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -1829,7 +1834,8 @@ def _dummy_run(
 
         with set_forward_context(None,
                                  self.vllm_config,
-                                 num_tokens=num_tokens):
+                                 num_tokens=num_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
            if self.torchair_graph_enabled and not with_prefill:
                attn_metadata = self.attn_metadata_builder.build_dummy(
                    num_reqs=num_tokens, num_actual_tokens=1)

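For reference, a minimal standalone sketch of the new metadata exchange (not the vllm-ascend code itself): every DP rank packs [num_tokens, with_prefill] into one small tensor, a single all_gather replaces the previous all_reduce(ReduceOp.MAX), and each rank recovers both the per-rank token counts and a global prefill flag. The plain torch.distributed calls and the helper name below are illustrative assumptions; the commit itself gathers on the NPU via get_dp_group().all_gather.

from typing import Optional

import torch
import torch.distributed as dist


def get_forward_metadata_across_dp(
        num_tokens: int,
        with_prefill: bool,
        group: Optional[dist.ProcessGroup] = None) -> tuple[torch.Tensor, bool]:
    # Pack the local metadata into one small tensor so a single collective
    # moves both values at once.
    local = torch.tensor([num_tokens, int(with_prefill)], dtype=torch.int32)
    gathered = [torch.empty_like(local)
                for _ in range(dist.get_world_size(group))]
    dist.all_gather(gathered, local, group=group)
    stacked = torch.stack(gathered)                # shape: (dp_size, 2)
    num_tokens_across_dp = stacked[:, 0]           # per-rank token counts
    any_prefill = bool(stacked[:, 1].any())        # True if any rank prefills
    return num_tokens_across_dp, any_prefill

Compared with the old all_reduce(ReduceOp.MAX), the gather keeps every rank's token count instead of only the maximum, which is what the new num_tokens_across_dp argument to set_forward_context and the torchair padding logic consume.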
vllm_ascend/worker/worker_v1.py

Lines changed: 19 additions & 6 deletions
@@ -247,16 +247,29 @@ def pin_lora(self, lora_id: int) -> bool:
 
     def execute_dummy_batch(self) -> None:
         runner = self.model_runner
-        max_num_tokens = 1
+
+        # If torchair graph is enabled, notify the other DP ranks that this is a
+        # dummy run by using '-1' as a flag for num_tokens. This will be
+        # replaced with the final determined graph size before the forward pass.
+        num_tokens = (-1 if runner.torchair_graph_enabled and not with_prefill
+                      else 1)
+        num_tokens_across_dp = None
         with_prefill = False
+
         if runner.dp_size > 1:
-            max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
-                max_num_tokens, with_prefill)
+            num_tokens_across_dp, with_prefill = \
+                runner._get_forward_metadata_across_dp(num_tokens, with_prefill)
+            num_tokens = int(num_tokens_across_dp.max().item())
+
         if runner.torchair_graph_enabled and not with_prefill:
-            max_num_tokens = runner.select_torchair_padded_batch_size(
-                max_num_tokens)
-        runner._dummy_run(max_num_tokens,
+            num_tokens = runner.select_torchair_padded_batch_size(num_tokens)
+        if num_tokens_across_dp is not None:
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                              num_tokens)
+
+        runner._dummy_run(num_tokens,
                           is_compile=False,
+                          num_tokens_across_dp=num_tokens_across_dp,
                           with_prefill=with_prefill)
 
     def _init_worker_distributed_environment(self) -> None:

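The '-1' placeholder only works if it is resolved before the gathered tensor is consumed by the forward pass. Below is a minimal sketch of that step, assuming the per-rank counts have already been gathered; the helper name and example values are illustrative, not part of the commit.

import torch


def resolve_dummy_token_counts(num_tokens_across_dp: torch.Tensor,
                               padded_batch_size: int) -> torch.Tensor:
    # Ranks that only ran a dummy batch reported -1; overwrite those entries
    # in place with the padded graph batch size, as the commit does via
    # masked_fill_ in both _process_reqs and execute_dummy_batch.
    num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
                                      padded_batch_size)
    return num_tokens_across_dp


# Ranks 0 and 2 had real work; ranks 1 and 3 signalled a dummy run with -1.
counts = torch.tensor([8, -1, 16, -1], dtype=torch.int32)
print(resolve_dummy_token_counts(counts, padded_batch_size=16))
# tensor([ 8, 16, 16, 16], dtype=torch.int32)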