Commit d8e3159 (parent: e4e9ea0)

feat: optimize forward metadata collection across dp ranks

Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>

File tree

2 files changed: +40 -21 lines

vllm_ascend/worker/model_runner_v1.py
vllm_ascend/worker/worker_v1.py

vllm_ascend/worker/model_runner_v1.py (21 additions, 15 deletions)
@@ -628,17 +628,16 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         if batch_changed:
             self.input_batch.refresh_sampling_metadata()
 
-    def _get_forward_metadata_across_dp(
-            self, total_num_scheduled_tokens: int,
-            with_prefill: bool) -> tuple[int, bool]:
-        forward_metadata = torch.tensor(
-            [total_num_scheduled_tokens, with_prefill],
-            device="cpu",
-            dtype=torch.int32)
-        dist.all_reduce(forward_metadata,
-                        op=ReduceOp.MAX,
-                        group=get_dp_group().cpu_group)
-        return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
+    def _get_forward_metadata_across_dp(
+            self, num_tokens: int,
+            with_prefill: bool) -> tuple[torch.Tensor, bool]:
+        local_forward_metadata = torch.tensor([num_tokens, with_prefill],
+                                              device="npu", dtype=torch.int32)
+        global_forward_metadata = get_dp_group().all_gather(
+            local_forward_metadata)
+        num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
+        with_prefill = bool(global_forward_metadata[:, 1].any())
+        return num_tokens_across_dp, with_prefill
 
     def get_eagle_atten_dict(
         self,
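
For intuition, here is a minimal standalone sketch of what the new all_gather-based collection yields compared with the old all_reduce(MAX) path. It uses plain CPU tensors instead of a process group, the per-rank rows are made-up values, and the (dp_size, 2) layout is assumed from the [:, 0] / [:, 1] indexing above:

    import torch

    # Stand-in for get_dp_group().all_gather(local_forward_metadata):
    # one [num_tokens, with_prefill] row per DP rank (illustrative values).
    global_forward_metadata = torch.tensor(
        [[128, 0],   # rank 0: decode-only
         [512, 1],   # rank 1: has a prefill
         [256, 0],   # rank 2: decode-only
         [-1, 0]],   # rank 3: dummy batch, flagged with -1 (see worker_v1.py)
        dtype=torch.int32)

    num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
    with_prefill = bool(global_forward_metadata[:, 1].any())

    # The old all_reduce(ReduceOp.MAX) version could only recover this scalar:
    max_num_tokens = int(num_tokens_across_dp.max().item())

    print(num_tokens_across_dp.tolist())  # [128, 512, 256, -1]
    print(with_prefill, max_num_tokens)   # True 512

Keeping the full per-rank vector, rather than reducing it away, is what enables the -1 back-filling and per-rank padding later in the diff.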
@@ -1107,9 +1105,12 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
 
+        num_tokens_across_dp = None
         if self.dp_size > 1:
-            max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
-                total_num_scheduled_tokens, with_prefill)
+            num_tokens_across_dp, with_prefill = \
+                self._get_forward_metadata_across_dp(num_input_tokens,
+                                                     with_prefill)
+            max_num_tokens = int(num_tokens_across_dp.max().item())
             extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
             extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
 
@@ -1118,6 +1119,8 @@ def _process_reqs(
         if self.dp_size > 1:
             padded_batch_size = self.select_torchair_padded_batch_size(
                 max_num_tokens)
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                              padded_batch_size)
         else:
             padded_batch_size = self.select_torchair_padded_batch_size(
                 total_num_scheduled_tokens)
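
A toy illustration of the masked_fill_ step above: once the padded batch size has been selected from the global max, the -1 placeholders advertised by dummy-run ranks are overwritten so every entry holds a real token count before entering the forward context (the values and padded size here are assumed, not taken from a real run):

    import torch

    num_tokens_across_dp = torch.tensor([128, -1, 256, -1], dtype=torch.int32)
    padded_batch_size = 512  # assumed result of select_torchair_padded_batch_size

    # Replace the -1 dummy-run placeholders in place.
    num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
                                      padded_batch_size)
    print(num_tokens_across_dp.tolist())  # [128, 512, 256, 512]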
@@ -1196,7 +1199,8 @@ def _process_reqs(
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
-                                 num_tokens=num_input_tokens):
+                                 num_tokens=num_input_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             with ProfileExecuteDuration().capture_async("forward"):
                 model_kwargs = {}
                 if self.torchair_graph_enabled:
@@ -1819,6 +1823,7 @@ def _dummy_run(
         is_compile: bool = False,
         with_prefill: bool = True,
         skip_attn: bool = True,
+        num_tokens_across_dp: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -1873,7 +1878,8 @@ def _dummy_run(
 
         with set_forward_context(None,
                                  self.vllm_config,
-                                 num_tokens=num_tokens):
+                                 num_tokens=num_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             if self.torchair_graph_enabled and not with_prefill:
                 attn_metadata = self.attn_metadata_builder.build_dummy(
                     num_reqs=num_tokens, num_actual_tokens=1)

vllm_ascend/worker/worker_v1.py (19 additions, 6 deletions)
@@ -277,16 +277,29 @@ def pin_lora(self, lora_id: int) -> bool:
 
     def execute_dummy_batch(self) -> None:
         runner = self.model_runner
-        max_num_tokens = 1
         with_prefill = False
+
+        # If torchair graph is enabled, notify the other DP ranks that this is a
+        # dummy run by using '-1' as a flag for num_tokens. This will be
+        # replaced with the final determined graph size before the forward pass.
+        num_tokens = (-1 if runner.torchair_graph_enabled and not with_prefill
+                      else 1)
+        num_tokens_across_dp = None
+
         if runner.dp_size > 1:
-            max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
-                max_num_tokens, with_prefill)
+            num_tokens_across_dp, with_prefill = \
+                runner._get_forward_metadata_across_dp(num_tokens, with_prefill)
+            num_tokens = int(num_tokens_across_dp.max().item())
+
         if runner.torchair_graph_enabled and not with_prefill:
-            max_num_tokens = runner.select_torchair_padded_batch_size(
-                max_num_tokens)
-        runner._dummy_run(max_num_tokens,
+            num_tokens = runner.select_torchair_padded_batch_size(num_tokens)
+            if num_tokens_across_dp is not None:
+                num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                                  num_tokens)
+
+        runner._dummy_run(num_tokens,
                           is_compile=False,
+                          num_tokens_across_dp=num_tokens_across_dp,
                           with_prefill=with_prefill)
 
     def _init_worker_distributed_environment(self) -> None:
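
Putting the two files together, the handshake for a dummy batch works roughly as follows. The sketch below simulates both ranks' view in a single process, with pad() as an assumed stand-in for select_torchair_padded_batch_size and made-up token counts:

    import torch

    def pad(n: int) -> int:
        # Assumed stand-in for select_torchair_padded_batch_size.
        return 512

    # What each rank reports: a real rank its token count, the dummy-batch
    # rank the -1 flag (illustrative values).
    num_tokens_across_dp = torch.tensor([300, -1], dtype=torch.int32)

    # After all_gather every rank sees the same vector, so all ranks agree:
    max_num_tokens = int(num_tokens_across_dp.max().item())  # 300; -1 never wins
    padded = pad(max_num_tokens)                             # 512 everywhere
    num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, padded)
    print(max_num_tokens, padded, num_tokens_across_dp.tolist())  # 300 512 [300, 512]

Because the dummy rank advertises -1 rather than a positive count, it can never inflate the global max, yet it still ends up running the same padded graph size as every other rank.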
