Commit d6fd3a3

[Misc] reuse num_tokens_across_dp of get_dp_padding to avoid unnecessary dp all reduce in set_forward_context (#18935)
Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
1 parent: 432ec99

File tree: 2 files changed, 47 additions and 18 deletions

  vllm/forward_context.py
  vllm/v1/worker/gpu_model_runner.py

vllm/forward_context.py

Lines changed: 19 additions & 8 deletions
@@ -47,8 +47,12 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int,
         return num_tokens_tensor
 
     @staticmethod
-    def make(parallel_config: ParallelConfig, attn_metadata: Any,
-             num_tokens: int) -> "DPMetadata":
+    def make(
+        parallel_config: ParallelConfig,
+        attn_metadata: Any,
+        num_tokens: int,
+        num_tokens_across_dp: Optional[torch.Tensor] = None
+    ) -> "DPMetadata":
 
         assert parallel_config.data_parallel_size > 1
         dp_size = parallel_config.data_parallel_size
@@ -62,10 +66,15 @@ def make(parallel_config: ParallelConfig, attn_metadata: Any,
             # for v1 attention backends or no attn_metadata
             batchsize = num_tokens
 
-        num_tokens_tensor = DPMetadata.num_tokens_across_dp(
-            batchsize, dp_size, dp_rank)
-        max_tokens_across_dp_cpu = torch.max(num_tokens_tensor)
-        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
+        # If num_tokens_across_dp is None, it will be computed by all_reduce
+        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
+        assert (num_tokens_across_dp is None
+                or num_tokens_across_dp[dp_rank] == batchsize)
+        if num_tokens_across_dp is None:
+            num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
+                batchsize, dp_size, dp_rank)
+        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
+        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
         return DPMetadata(max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu)
 
 
@@ -101,7 +110,8 @@ def get_forward_context() -> ForwardContext:
 def set_forward_context(attn_metadata: Any,
                         vllm_config: VllmConfig,
                         virtual_engine: int = 0,
-                        num_tokens: Optional[int] = None):
+                        num_tokens: Optional[int] = None,
+                        num_tokens_across_dp: Optional[torch.Tensor] = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
@@ -114,7 +124,8 @@ def set_forward_context(attn_metadata: Any,
     if vllm_config.parallel_config.data_parallel_size > 1 and (
             attn_metadata is not None or num_tokens is not None):
         dp_metadata = DPMetadata.make(vllm_config.parallel_config,
-                                      attn_metadata, num_tokens or 0)
+                                      attn_metadata, num_tokens or 0,
+                                      num_tokens_across_dp)
 
     global _forward_context
     prev_context = _forward_context
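
With this change, DPMetadata.make accepts a precomputed num_tokens_across_dp tensor and only falls back to the all-reduce-based gather when the caller does not supply one. Below is a minimal, single-process sketch of that fallback pattern; the gather helper is stubbed out and all names are illustrative rather than the actual vLLM implementation.

    from typing import Optional

    import torch


    def gather_num_tokens_across_dp(num_tokens: int, dp_size: int,
                                    dp_rank: int) -> torch.Tensor:
        # Stand-in for the real helper, which fills one slot per rank and
        # all-reduces across the DP group. Here every rank is assumed to
        # report the same count so the sketch runs in a single process.
        counts = torch.zeros(dp_size, dtype=torch.int32, device="cpu")
        counts[:] = num_tokens
        return counts


    def make_dp_metadata(num_tokens: int,
                         dp_size: int,
                         dp_rank: int,
                         num_tokens_across_dp: Optional[torch.Tensor] = None):
        # Reuse the caller-provided per-rank counts if present; otherwise
        # fall back to the (stubbed) collective, mirroring the diff above.
        assert (num_tokens_across_dp is None
                or num_tokens_across_dp[dp_rank] == num_tokens)
        if num_tokens_across_dp is None:
            num_tokens_across_dp = gather_num_tokens_across_dp(
                num_tokens, dp_size, dp_rank)
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp)
        cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_across_dp, dim=0)
        return max_tokens_across_dp_cpu, cu_tokens_across_dp_cpu


    # Passing the precomputed tensor skips the gather entirely.
    precomputed = torch.tensor([8, 8], dtype=torch.int32)
    print(make_dp_metadata(8, dp_size=2, dp_rank=0,
                           num_tokens_across_dp=precomputed))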

vllm/v1/worker/gpu_model_runner.py

Lines changed: 28 additions & 10 deletions
@@ -1111,17 +1111,30 @@ def sync_and_slice_intermediate_tensors(
             for k, v in self.intermediate_tensors.items()
         })
 
-    def get_dp_padding(self, num_tokens: int):
+    def get_dp_padding(self,
+                       num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         dp_rank = self.vllm_config.parallel_config.data_parallel_rank
-        if dp_size == 1:
+
+        # For DP: Don't pad when setting enforce_eager.
+        # This lets us set enforce_eager on the prefiller in a P/D setup and
+        # still use CUDA graphs (enabled by this padding) on the decoder.
+        #
+        # TODO(tms) : There are many cases where padding is enabled for
+        # prefills, causing unnecessary and excessive padding of activations.
+
+        if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
             # Early exit.
-            return 0
+            return 0, None
 
         num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
             num_tokens, dp_size, dp_rank)
         max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
-        return max_tokens_across_dp_cpu - num_tokens
+        num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
+                                                dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+        return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
 
     @torch.inference_mode()
     def execute_model(
@@ -1161,7 +1174,8 @@ def execute_model(
             num_input_tokens = num_scheduled_tokens
 
         # Padding for DP
-        num_input_tokens += self.get_dp_padding(num_input_tokens)
+        num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
+        num_input_tokens += num_pad
 
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
@@ -1208,7 +1222,8 @@ def execute_model(
         # Use persistent buffers for CUDA graphs.
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
-                                 num_tokens=num_input_tokens):
+                                 num_tokens=num_input_tokens,
+                                 num_tokens_across_dp=num_tokens_across_dp):
             self.maybe_setup_kv_connector(scheduler_output)
 
             model_output = self.model(
@@ -1681,7 +1696,8 @@ def _dummy_run(
     ) -> torch.Tensor:
 
         # Padding for DP
-        num_tokens += self.get_dp_padding(num_tokens)
+        num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
+        num_tokens += num_pad
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -1747,9 +1763,11 @@ def _dummy_run(
                 intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                     num_tokens, None, False)
 
-            with set_forward_context(attn_metadata,
-                                     self.vllm_config,
-                                     num_tokens=num_tokens):
+            with set_forward_context(
+                    attn_metadata,
+                    self.vllm_config,
+                    num_tokens=num_tokens,
+                    num_tokens_across_dp=num_tokens_across_dp):
                 outputs = model(
                     input_ids=input_ids,
                     positions=positions,
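
On the runner side, get_dp_padding now returns both the padding amount and the per-rank token counts after padding, and execute_model and _dummy_run forward that tensor to set_forward_context so DPMetadata.make can skip its own all-reduce. Below is a minimal, self-contained sketch of that contract under simplified assumptions (single process, the DP gather replaced by a constant); the function name is illustrative, not the actual vLLM runner method.

    from typing import Optional

    import torch


    def get_dp_padding_sketch(
            num_tokens: int,
            dp_size: int,
            enforce_eager: bool = False) -> tuple[int, Optional[torch.Tensor]]:
        # No DP, or eager mode (no CUDA-graph padding needed): nothing to reuse.
        if dp_size == 1 or enforce_eager:
            return 0, None
        # The real code gathers per-rank token counts with an all-reduce; here
        # every rank is assumed to report the same count.
        num_tokens_across_dp = torch.full((dp_size, ), num_tokens,
                                          dtype=torch.int32, device="cpu")
        max_tokens = int(torch.max(num_tokens_across_dp).item())
        # After padding, every rank runs with max_tokens tokens, so the
        # per-rank counts are known without another collective.
        num_tokens_after_padding = torch.full((dp_size, ), max_tokens,
                                              dtype=torch.int32, device="cpu")
        return max_tokens - num_tokens, num_tokens_after_padding


    # The runner adds the padding to its token count and hands the tensor to
    # set_forward_context(..., num_tokens_across_dp=...), so DPMetadata.make
    # reuses it instead of issuing a second all-reduce.
    num_input_tokens = 8
    num_pad, num_tokens_across_dp = get_dp_padding_sketch(num_input_tokens,
                                                          dp_size=2)
    num_input_tokens += num_pad
    print(num_pad, num_tokens_across_dp)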
