
[Feature] Optimize forward metadata collection across dp ranks #1593


Open · wants to merge 11 commits into base: main
64 changes: 39 additions & 25 deletions vllm_ascend/worker/model_runner_v1.py
@@ -30,9 +30,7 @@
import numpy.typing as npt
import torch
import torch._dynamo.cache_size
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ReduceOp
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig
@@ -629,16 +627,16 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
self.input_batch.refresh_sampling_metadata()

def _get_forward_metadata_across_dp(
self, total_num_scheduled_tokens: int,
with_prefill: bool) -> tuple[int, bool]:
forward_metadata = torch.tensor(
[total_num_scheduled_tokens, with_prefill],
device="cpu",
dtype=torch.int32)
dist.all_reduce(forward_metadata,
op=ReduceOp.MAX,
group=get_dp_group().cpu_group)
return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
self, num_tokens: int,
with_prefill: bool) -> tuple[torch.Tensor, bool]:
local_forward_metadata = torch.tensor([num_tokens, with_prefill],
device="npu",
dtype=torch.int32).unsqueeze(0)
global_forward_metadata = get_dp_group().all_gather(
local_forward_metadata, dim=0)
num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
with_prefill = bool(global_forward_metadata[:, 1].any())
return num_tokens_across_dp, with_prefill
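
Note on the new exchange (editor's illustration, not part of the diff): switching from dist.all_reduce with ReduceOp.MAX to get_dp_group().all_gather means every rank now sees each rank's [num_tokens, with_prefill] pair instead of only the element-wise maximum. The sketch below emulates the gather with plain torch tensors on CPU; the per-rank values are made up.

import torch

# Pretend these are the [num_tokens, with_prefill] rows each DP rank contributes
# to the all_gather (values are illustrative only).
per_rank_metadata = [
    torch.tensor([[512, 1]], dtype=torch.int32),  # rank 0: has a prefill
    torch.tensor([[64, 0]], dtype=torch.int32),   # rank 1: decode only
    torch.tensor([[-1, 0]], dtype=torch.int32),   # rank 2: idle, dummy run
]

# all_gather along dim=0 produces the same stacked tensor on every rank;
# torch.cat stands in for the collective here.
global_forward_metadata = torch.cat(per_rank_metadata, dim=0)

num_tokens_across_dp = global_forward_metadata[:, 0]      # tensor([512, 64, -1])
with_prefill = bool(global_forward_metadata[:, 1].any())  # True

print(num_tokens_across_dp.tolist(), with_prefill)
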

def get_eagle_atten_dict(
self,
@@ -1107,23 +1105,35 @@ def _process_reqs(
AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
]

num_tokens_across_dp = None
if self.dp_size > 1:
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
total_num_scheduled_tokens, with_prefill)
num_tokens_across_dp, with_prefill = \
self._get_forward_metadata_across_dp(num_input_tokens,
with_prefill)
max_num_tokens = int(num_tokens_across_dp.max().item())
extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill

# Add graph_pad_size here
if self.torchair_graph_enabled and not with_prefill:
if self.dp_size > 1:
padded_batch_size = self.select_torchair_padded_batch_size(
max_num_tokens)
else:
padded_batch_size = self.select_torchair_padded_batch_size(
total_num_scheduled_tokens)
max_num_tokens = (max_num_tokens
if self.dp_size > 1 else num_input_tokens)
padded_batch_size = self.select_torchair_padded_batch_size(
max_num_tokens)
graph_pad_size = padded_batch_size - total_num_scheduled_tokens

extra_builder_kwargs['graph_pad_size'] = graph_pad_size
# If torchair graph is enabled and in decode mode, the dummy run
# batch size is set to the selected graph size.
dummy_num_tokens = padded_batch_size
else:
# If torchair graph is not enabled, or if with_prefill is True, the
# dummy run batch size is set to 1.
dummy_num_tokens = 1

if self.dp_size > 1:
assert num_tokens_across_dp is not None
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
dummy_num_tokens)
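
The -1 entries in num_tokens_across_dp are the sentinel broadcast by ranks that only execute a dummy batch this step (see execute_dummy_batch below); masked_fill_ overwrites them in place with the batch size the dummy run will actually use. A small illustration (editor's sketch with made-up values, not part of the diff):

import torch

# Gathered counts; rank 1 reported -1 because it only runs a dummy batch.
num_tokens_across_dp = torch.tensor([512, -1, 64], dtype=torch.int32)

dummy_num_tokens = 16  # e.g. a selected torchair graph batch size (assumed value)
num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, dummy_num_tokens)

print(num_tokens_across_dp.tolist())  # [512, 16, 64]
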

if self.vllm_config.model_config.use_mla:
attn_metadata = self.attn_metadata_builder.build( # type: ignore
Expand Down Expand Up @@ -1196,7 +1206,8 @@ def _process_reqs(
# Run forward pass
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp):
with ProfileExecuteDuration().capture_async("forward"):
model_kwargs = {}
if self.torchair_graph_enabled:
@@ -1819,6 +1830,7 @@ def _dummy_run(
is_compile: bool = False,
with_prefill: bool = True,
skip_attn: bool = True,
num_tokens_across_dp: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
@@ -1871,9 +1883,11 @@ def _dummy_run(
for k, v in self.intermediate_tensors.items()
})

with set_forward_context(None,
self.vllm_config,
num_tokens=num_tokens):
with set_forward_context(
None,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
if self.torchair_graph_enabled and not with_prefill:
attn_metadata = self.attn_metadata_builder.build_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
28 changes: 21 additions & 7 deletions vllm_ascend/worker/worker_v1.py
@@ -277,16 +277,30 @@ def pin_lora(self, lora_id: int) -> bool:

def execute_dummy_batch(self) -> None:
runner = self.model_runner
max_num_tokens = 1
with_prefill = False
if runner.dp_size > 1:
max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
max_num_tokens, with_prefill)
if runner.dp_size <= 1:
raise ValueError(
"Dummy batch execution should only be "
"performed with data parallelism enabled, but got "
f"dp_size={runner.dp_size}.")

# If torchair graph is enabled, notify the other DP ranks that this is a
# dummy run by using '-1' as a flag for num_tokens. This will be
# replaced with the final determined graph size before the forward pass.
num_tokens_across_dp, with_prefill = \
runner._get_forward_metadata_across_dp(-1, False)

if runner.torchair_graph_enabled and not with_prefill:
max_num_tokens = runner.select_torchair_padded_batch_size(
max_num_tokens = int(num_tokens_across_dp.max().item())
num_tokens = runner.select_torchair_padded_batch_size(
max_num_tokens)
runner._dummy_run(max_num_tokens,
else:
num_tokens = 1
Contributor commented:

why is it 1?

Collaborator (Author) replied:

If graph mode is off, the dummy run only needs to be executed at all; its computational size is not a factor, so a batch size of 1 is enough.


num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1,
num_tokens)
runner._dummy_run(num_tokens,
is_compile=False,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill)
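
Taken together, an idle rank advertises -1, learns from the gather whether any peer is prefilling and how large the biggest real batch is, and only then picks its own dummy batch size. The sketch below is an editor's illustration of that decision, not part of the diff: fake_padded_batch_size is a stand-in for select_torchair_padded_batch_size, and the gathered values are made up.

import torch

def fake_padded_batch_size(n: int) -> int:
    # Stand-in for select_torchair_padded_batch_size: round up to a fixed bucket.
    for bucket in (4, 8, 16, 32, 64, 128):
        if n <= bucket:
            return bucket
    return n

# What the idle rank would see after _get_forward_metadata_across_dp(-1, False),
# assuming two busy decode ranks and this idle rank (values illustrative).
num_tokens_across_dp = torch.tensor([48, 96, -1], dtype=torch.int32)
with_prefill = False
torchair_graph_enabled = True

if torchair_graph_enabled and not with_prefill:
    # Match the largest real batch so every rank selects the same graph size.
    num_tokens = fake_padded_batch_size(int(num_tokens_across_dp.max().item()))
else:
    # Graph mode off (or a prefill somewhere): a one-token dummy run is enough.
    num_tokens = 1

num_tokens_across_dp.masked_fill_(num_tokens_across_dp == -1, num_tokens)
print(num_tokens, num_tokens_across_dp.tolist())  # -> 128 [48, 96, 128]
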

def _init_worker_distributed_environment(self) -> None: