Commit bfec114 (1 parent: 6100e0d)
vllm_ascend/worker/model_runner_v1.py
@@ -625,9 +625,9 @@ def _get_forward_metadata_across_dp(
             with_prefill: bool) -> tuple[torch.Tensor, bool]:
         local_forward_metadata = torch.tensor([num_tokens, with_prefill],
                                               device="npu",
-                                              dtype=torch.int32)
+                                              dtype=torch.int32).unsqueeze(0)
         global_forward_metadata = get_dp_group().all_gather(
-            local_forward_metadata)
+            local_forward_metadata, dim=0)
         num_tokens_across_dp = global_forward_metadata[:, 0].cpu()
         with_prefill = bool(global_forward_metadata[:, 1].any())
         return num_tokens_across_dp, with_prefill
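The change makes each data-parallel rank contribute a [1, 2] row (via .unsqueeze(0)) and gathers explicitly along dim=0, so the gathered result is a [dp_size, 2] matrix whose first column holds per-rank token counts and whose second column holds the prefill flags. The sketch below (not the vllm-ascend code itself) illustrates that shape logic: the collective is simulated with torch.cat on CPU, the per-rank inputs are hypothetical, and in the real code the gather is get_dp_group().all_gather on NPU tensors.

import torch

def simulated_all_gather(local_rows: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
    # Stand-in for the DP all_gather collective: every rank ends up with the
    # concatenation of all ranks' local tensors along `dim`.
    return torch.cat(local_rows, dim=dim)

# Hypothetical per-rank inputs: (num_tokens, with_prefill) for 3 DP ranks.
per_rank = [(128, False), (64, True), (256, False)]

local_rows = [
    torch.tensor([num_tokens, with_prefill], dtype=torch.int32).unsqueeze(0)  # shape [1, 2]
    for num_tokens, with_prefill in per_rank
]

global_forward_metadata = simulated_all_gather(local_rows, dim=0)  # shape [3, 2]

num_tokens_across_dp = global_forward_metadata[:, 0].cpu()  # tensor([128,  64, 256], dtype=torch.int32)
with_prefill = bool(global_forward_metadata[:, 1].any())    # True, since rank 1 has prefill

print(num_tokens_across_dp, with_prefill)

Without the unsqueeze, each rank would send a 1-D [2] tensor and the gathered result would be a flat [dp_size * 2] vector, so the [:, 0] and [:, 1] column slices used afterwards would not work; gathering explicit rows along dim=0 keeps the per-rank structure intact.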