Commit dada008

feat: optimize forward metadata collection across dp ranks
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
1 parent e511ddd commit dada008

2 files changed: 40 additions, 21 deletions

vllm_ascend/worker/model_runner_v1.py

Lines changed: 21 additions & 15 deletions
```diff
@@ -621,17 +621,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         if batch_changed:
             self.input_batch.refresh_sampling_metadata()
 
-    def _get_forward_metadata_across_dp(
-            self, total_num_scheduled_tokens: int,
-            with_prefill: bool) -> tuple[int, bool]:
-        forward_metadata = torch.tensor(
-            [total_num_scheduled_tokens, with_prefill],
-            device="cpu",
-            dtype=torch.int32)
-        dist.all_reduce(forward_metadata,
-                        op=ReduceOp.MAX,
-                        group=get_dp_group().cpu_group)
-        return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
+    def _get_forward_metadata_across_dp(self, num_tokens: int,
+                                        with_prefill: bool) -> tuple[int, bool]:
+        local_forward_metadata = torch.tensor([num_tokens, with_prefill],
+                                              device="npu", dtype=torch.int32)
+        global_forward_metadata = get_dp_group().all_gather(
+            local_forward_metadata)
+        num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
+        with_prefill = bool(global_forward_metadata[:, 1].any())
+        return num_tokens_across_dp, with_prefill
 
     def get_eagle_atten_dict(
             self,
```
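The behavioral change in this hunk: the old code reduced [num_tokens, with_prefill] with all_reduce(MAX), so every rank learned only the global maximum; the new code gathers one row per rank, preserving the individual token counts (which is what lets the -1 dummy-run flag from worker_v1.py below survive the exchange). A single-process sketch of the two semantics, emulating both collectives with torch.stack over invented per-rank values:

```python
import torch

# Invented per-rank metadata: [num_tokens, with_prefill] for 4 DP ranks.
local_metadata = [
    torch.tensor([128, 1], dtype=torch.int32),  # rank 0: prefill batch
    torch.tensor([17, 0], dtype=torch.int32),   # rank 1: decode batch
    torch.tensor([-1, 0], dtype=torch.int32),   # rank 2: idle, dummy-run flag
    torch.tensor([64, 0], dtype=torch.int32),   # rank 3: decode batch
]

# Old scheme: all_reduce(MAX) collapses everything to one maximum.
reduced = torch.stack(local_metadata).max(dim=0).values
print(int(reduced[0]), bool(reduced[1] > 0))
# 128 True -- the per-rank counts (and the -1 flag) are gone.

# New scheme: all_gather keeps one row per rank.
gathered = torch.stack(local_metadata)  # shape (dp_size, 2)
num_tokens_across_dp = gathered[:, 0]
with_prefill = bool(gathered[:, 1].any())
print(num_tokens_across_dp.tolist(), with_prefill)
# [128, 17, -1, 64] True -- rank 2's flag is still visible.
```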
```diff
@@ -1100,9 +1098,12 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
 
+        num_tokens_across_dp = None
         if self.dp_size > 1:
-            max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
-                total_num_scheduled_tokens, with_prefill)
+            num_tokens_across_dp, with_prefill = \
+                self._get_forward_metadata_across_dp(num_input_tokens,
+                                                     with_prefill)
+            max_num_tokens = int(num_tokens_across_dp.max().item())
             extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
             extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
 
```
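Note the call site also switches its argument from total_num_scheduled_tokens to num_input_tokens, presumably so the counts exchanged across ranks reflect what each rank will actually feed the model rather than the raw scheduled-token total.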
```diff
@@ -1111,6 +1112,8 @@ def _process_reqs(
         if self.dp_size > 1:
             padded_batch_size = self.select_torchair_padded_batch_size(
                 max_num_tokens)
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                              padded_batch_size)
         else:
             padded_batch_size = self.select_torchair_padded_batch_size(
                 total_num_scheduled_tokens)
```
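Because the gathered tensor is kept around, the idle ranks' -1 placeholders can be rewritten in place once the torchair graph size is settled. A standalone illustration of the masked_fill_ pattern, with invented values:

```python
import torch

num_tokens_across_dp = torch.tensor([128, 17, -1, 64], dtype=torch.int32)
padded_batch_size = 128  # whatever select_torchair_padded_batch_size chose

# Replace the dummy-run flag (-1) with the final graph size, in place.
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
                                  padded_batch_size)
print(num_tokens_across_dp.tolist())  # [128, 17, 128, 64]
```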
```diff
@@ -1189,7 +1192,8 @@ def _process_reqs(
         # Run forward pass
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
-                                 num_tokens=num_input_tokens):
+                                 num_tokens=num_input_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             with ProfileExecuteDuration().capture_async("forward"):
                 model_kwargs = {}
                 if self.torchair_graph_enabled:
```
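Threading num_tokens_across_dp into set_forward_context lets the forward context build its data-parallel metadata from the already-gathered per-rank counts instead of collecting them again; before this change it only saw this rank's num_tokens. That reuse appears to be the motivation for gathering rather than reducing in the first place.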
```diff
@@ -1806,6 +1810,7 @@ def _dummy_run(
         is_compile: bool = False,
         with_prefill: bool = True,
         skip_attn: bool = True,
+        num_tokens_across_dp: Optional[int] = None,
     ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
```
```diff
@@ -1860,7 +1865,8 @@ def _dummy_run(
 
         with set_forward_context(None,
                                  self.vllm_config,
-                                 num_tokens=num_tokens):
+                                 num_tokens=num_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             if self.torchair_graph_enabled and not with_prefill:
                 attn_metadata = self.attn_metadata_builder.build_dummy(
                     num_reqs=num_tokens, num_actual_tokens=1)
```

vllm_ascend/worker/worker_v1.py

Lines changed: 19 additions & 6 deletions
```diff
@@ -247,16 +247,29 @@ def pin_lora(self, lora_id: int) -> bool:
 
     def execute_dummy_batch(self) -> None:
         runner = self.model_runner
-        max_num_tokens = 1
+
+        # If torchair graph is enabled, notify the other DP ranks that this is a
+        # dummy run by using '-1' as a flag for num_tokens. This will be
+        # replaced with the final determined graph size before the forward pass.
+        num_tokens = (-1 if runner.torchair_graph_enabled and not with_prefill
+                      else 1)
+        num_tokens_across_dp = None
         with_prefill = False
+
         if runner.dp_size > 1:
-            max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
-                max_num_tokens, with_prefill)
+            num_tokens_across_dp, with_prefill = \
+                runner._get_forward_metadata_across_dp(num_tokens, with_prefill)
+            num_tokens = int(num_tokens_across_dp.max().item())
+
         if runner.torchair_graph_enabled and not with_prefill:
-            max_num_tokens = runner.select_torchair_padded_batch_size(
-                max_num_tokens)
-            runner._dummy_run(max_num_tokens,
+            num_tokens = runner.select_torchair_padded_batch_size(num_tokens)
+        if num_tokens_across_dp is not None:
+            num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
+                                              num_tokens)
+
+        runner._dummy_run(num_tokens,
                           is_compile=False,
+                          num_tokens_across_dp=num_tokens_across_dp,
                           with_prefill=with_prefill)
 
     def _init_worker_distributed_environment(self) -> None:
```
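End to end, the handshake works like this: an idle torchair rank advertises -1, every rank takes the max of the gathered counts, rounds it up to a graph-capturable size, and back-fills the -1 slots so all ranks enter the forward pass with identical metadata. A self-contained simulation for four hypothetical ranks; select_padded_batch_size is a next-power-of-two stand-in, not the runner's real select_torchair_padded_batch_size lookup:

```python
import torch

def select_padded_batch_size(n: int) -> int:
    # Stand-in for select_torchair_padded_batch_size: round up to the
    # next power of two (the real runner picks from a fixed size list).
    size = 1
    while size < n:
        size *= 2
    return size

# Gathered metadata: ranks 1 and 3 are idle and sent the -1 flag.
num_tokens_across_dp = torch.tensor([100, -1, 37, -1], dtype=torch.int32)

# Every rank computes the same max, so they all agree on the graph size.
num_tokens = int(num_tokens_across_dp.max().item())  # 100
num_tokens = select_padded_batch_size(num_tokens)    # 128

# Back-fill the flags so idle ranks also report the agreed-upon size.
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, num_tokens)
print(num_tokens, num_tokens_across_dp.tolist())  # 128 [100, 128, 37, 128]
```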
