Commit 3aee657

[V1] Aggregate chunked prompt logprobs in model runner (vllm-project#14875)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 9cc6451

7 files changed (+68, −44 lines)

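The core of this change is that the model runner now pre-allocates CPU tensors sized for the whole prompt and copies each prefill chunk's prompt logprobs into the matching slice, instead of streaming partial tensors out through the engine. A standalone sketch of that pattern (hypothetical sizes, random data standing in for real logits; not vLLM code):

```python
import torch

num_prompt_tokens = 8      # hypothetical prompt length
num_prompt_logprobs = 4    # hypothetical top-k prompt logprobs

# Pre-allocate one row per prompt position after the first token,
# mirroring LogprobsTensors.empty_cpu(num_prompt_tokens - 1, k + 1) below.
token_ids = torch.empty((num_prompt_tokens - 1, num_prompt_logprobs + 1),
                        dtype=torch.int32, device="cpu")
logprobs = torch.empty_like(token_ids, dtype=torch.float32)
ranks = torch.empty(num_prompt_tokens - 1, dtype=torch.int32, device="cpu")

# Simulate chunked prefill: two chunks covering positions 0..3 and 4..6.
for start, length in ((0, 4), (4, 3)):
    chunk = slice(start, start + length)
    # In the model runner these copies come from GPU tensors with
    # non_blocking=True; random data stands in here.
    token_ids[chunk].copy_(
        torch.randint(0, 100, (length, num_prompt_logprobs + 1)))
    logprobs[chunk].copy_(torch.rand(length, num_prompt_logprobs + 1))
    ranks[chunk].copy_(torch.randint(1, 100, (length,)))

# After the final chunk the full tensors are returned in one piece.
assert token_ids.shape == (7, 5)
```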

vllm/v1/core/sched/scheduler.py

Lines changed: 4 additions & 2 deletions
@@ -627,8 +627,7 @@ def update_from_output(

             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
-            # Transmit partial if chunked prefill & prompt logprobs is enabled
-            if new_token_ids or prompt_logprobs_tensors is not None:
+            if new_token_ids:
                 # Add EngineCoreOutput for this Request.
                 outputs.append(
                     EngineCoreOutput(
@@ -639,6 +638,9 @@ def update_from_output(
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                         stop_reason=request.stop_reason,
                         events=request.take_events()))
+            else:
+                # Invariant: EngineCore returns no partial prefill outputs.
+                assert not prompt_logprobs_tensors

             self.scheduled_req_ids.remove(request.request_id)
             if not stopped:
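The effect of the scheduler change is that an EngineCoreOutput is emitted only when the step produced new token ids; a partial-prefill step no longer ships its prompt logprobs. A toy sketch of that invariant (simplified stand-in, not the scheduler code):

```python
def should_emit_output(new_token_ids: list[int],
                       prompt_logprobs_tensors) -> bool:
    """Toy stand-in for the emit decision in update_from_output."""
    if new_token_ids:
        # Any prompt logprobs accumulated so far ride along with this output.
        return True
    # Invariant restored by this commit: a step with no new tokens is a
    # partial prefill, and the model runner keeps its logprobs buffered.
    assert not prompt_logprobs_tensors
    return False

assert should_emit_output([42], prompt_logprobs_tensors=None) is True
assert should_emit_output([], prompt_logprobs_tensors=None) is False
```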

vllm/v1/engine/logprobs.py

Lines changed: 0 additions & 1 deletion
@@ -115,7 +115,6 @@ def _update_prompt_logprobs(
         num_prompt_tokens, num_logprobs = logprobs.shape

         # Pythonize the torch tensors.
-        # TODO(rob): experiment with doing this in EngineCore?
         prompt_token_ranks = ranks.tolist()
         prompt_logprobs = logprobs.tolist()
         token_ids = token_ids.tolist()

vllm/v1/engine/output_processor.py

Lines changed: 3 additions & 18 deletions
@@ -105,9 +105,7 @@ def make_request_output(
         finished = finish_reason is not None
         final_only = self.output_kind == RequestOutputKind.FINAL_ONLY

-        # In follow up, we will switch to invariant where EngineCore
-        # does not stream partial prefills.
-        if not finished and (self.is_prefilling or final_only):
+        if not finished and final_only:
             # Only the final output is required in FINAL_ONLY mode.
             return None

@@ -285,19 +283,7 @@ def process_outputs(
             finish_reason = engine_core_output.finish_reason
             stop_reason = engine_core_output.stop_reason

-            # TODO(andy): prompt logprobs + chunked prefill can
-            # result in engine core returning an output for a
-            # partial prefill (in order to send back partial
-            # prompt logprobs.) This breaks the invariant that
-            # process_outputs is only operating on engine core
-            # outputs associated with non-partial completions.
-            # Currently this is handled by having `is_prefilling`
-            # check for new decoded tokens, indicating that
-            # the completion is not partial.
-            #
-            # Follow up will aggregate partial prompt logprobs
-            # in the EngineCore.
-            req_state.is_prefilling = not new_token_ids
+            req_state.is_prefilling = False

             # 2) Detokenize the token ids into text and perform stop checks.
             stop_string = req_state.detokenizer.update(
@@ -306,8 +292,7 @@ def process_outputs(
                 finish_reason = FinishReason.STOP
                 stop_reason = stop_string

-            # 3) Compute sample and prompt logprobs for request,
-            # if required.
+            # 3) Compute sample and prompt logprobs for request, if required.
             req_state.logprobs_processor.update_from_output(engine_core_output)

             # 4) Create and handle RequestOutput objects.
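With partial-prefill outputs gone, the output processor only suppresses an update when the request uses FINAL_ONLY output mode and has not yet finished. A minimal sketch of that gate (hypothetical helper, not the vLLM method):

```python
def should_return_output(finished: bool, final_only: bool) -> bool:
    # FINAL_ONLY requests only surface their last output; everything else
    # (and every finished request) is returned immediately.
    return finished or not final_only

assert should_return_output(finished=False, final_only=True) is False
assert should_return_output(finished=False, final_only=False) is True
assert should_return_output(finished=True, final_only=True) is True
```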

vllm/v1/metrics/stats.py

Lines changed: 4 additions & 15 deletions
@@ -100,15 +100,8 @@ def update_from_output(self, output: "EngineCoreOutput",
         num_new_generation_tokens = len(output.new_token_ids)

         self.num_generation_tokens += num_new_generation_tokens
-        if is_prefilling and num_new_generation_tokens > 0:
-            # TODO(andy): we used to assert that num_new_generation_tokens
-            # > 0 with an invariant that EngineCore does not stream outputs
-            # for partially completed prefills (scheduler.update_from_output
-            # makes EngineCoreOutput iff num_computed_tokens == num_tokens).
-            # When prompt logprobs are enabled, we currently stream out the
-            # partially completed prompt.
-            # This will be reverted in a follow up PR and we should re-enable
-            # this assertion / invariant.
+        if is_prefilling:
+            assert num_new_generation_tokens > 0
             self.num_prompt_tokens += prompt_len

             first_token_latency = self._time_since(req_stats.arrival_time)
@@ -123,16 +116,12 @@ def update_from_output(self, output: "EngineCoreOutput",

         # Process the batch-level "new tokens" engine core event
         if is_prefilling:
-            # TODO: re-enable no-output-for-partial-prefills invariant as above
-            if num_new_generation_tokens > 0:
-                req_stats.first_token_ts = engine_core_timestamp
+            req_stats.first_token_ts = engine_core_timestamp
         else:
             tpot = engine_core_timestamp - req_stats.last_token_ts
             self.time_per_output_tokens_iter.append(tpot)

-        # TODO: re-enable no-output-for-partial-prefills invariant as above
-        if num_new_generation_tokens > 0:
-            req_stats.last_token_ts = engine_core_timestamp
+        req_stats.last_token_ts = engine_core_timestamp

     def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
                            is_prefilling: bool, req_stats: RequestStateStats,
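With the no-partial-prefill invariant back, every output carries generated tokens, so the timestamp bookkeeping can run unconditionally: the first output for a request sets the first-token timestamp (TTFT), and later outputs contribute time-per-output-token samples. A toy version of that logic (names mirror the diff, but the class here is a stand-in, not vLLM's RequestStateStats):

```python
from dataclasses import dataclass

@dataclass
class ToyRequestStats:
    first_token_ts: float = 0.0
    last_token_ts: float = 0.0

def update_timing(req_stats: ToyRequestStats, is_prefilling: bool,
                  engine_core_timestamp: float,
                  tpot_samples: list[float]) -> None:
    if is_prefilling:
        # First tokens for the request: record the TTFT timestamp.
        req_stats.first_token_ts = engine_core_timestamp
    else:
        # Decode step: time since the previous output is one TPOT sample.
        tpot_samples.append(engine_core_timestamp - req_stats.last_token_ts)
    req_stats.last_token_ts = engine_core_timestamp

stats, samples = ToyRequestStats(), []
update_timing(stats, is_prefilling=True, engine_core_timestamp=1.0,
              tpot_samples=samples)
update_timing(stats, is_prefilling=False, engine_core_timestamp=1.2,
              tpot_samples=samples)
assert samples == [1.2 - 1.0]
```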

vllm/v1/outputs.py

Lines changed: 19 additions & 0 deletions
@@ -39,6 +39,25 @@ def tolists(self):
             self.selected_token_ranks.tolist(),
         )

+    @staticmethod
+    def empty_cpu(num_positions: int,
+                  num_tokens_per_position: int) -> "LogprobsTensors":
+        """Create empty LogprobsTensors on CPU."""
+
+        logprob_token_ids = torch.empty(
+            (num_positions, num_tokens_per_position),
+            dtype=torch.int32,
+            device="cpu")
+        logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
+        selected_token_ranks = torch.empty(num_positions,
+                                           dtype=torch.int32,
+                                           device="cpu")
+        return LogprobsTensors(
+            logprob_token_ids=logprob_token_ids,
+            logprobs=logprobs,
+            selected_token_ranks=selected_token_ranks,
+        )
+

 @dataclass
 class SamplerOutput:
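A usage sketch for the new `empty_cpu` constructor (assumes vLLM with this commit is importable; the sizes are hypothetical). The extra column per position leaves room for the prompt token's own logprob alongside the top-k candidates, matching the `num_prompt_logprobs + 1` sizing used in the model runner:

```python
import torch
from vllm.v1.outputs import LogprobsTensors

# 9-token prompt with top-5 prompt logprobs: 8 positions (no logprob for the
# first token), 5 + 1 entries per position.
buffers = LogprobsTensors.empty_cpu(num_positions=8, num_tokens_per_position=6)
assert buffers.logprob_token_ids.shape == (8, 6)
assert buffers.logprobs.dtype == torch.float32
assert buffers.selected_token_ranks.shape == (8,)
```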

vllm/v1/worker/gpu_input_batch.py

Lines changed: 5 additions & 0 deletions
@@ -11,6 +11,7 @@
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import swap_dict_values
+from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.utils import copy_slice
 from vllm.v1.worker.block_table import BlockTable
@@ -197,6 +198,9 @@ def __init__(
         # that are currently in the prefill phase.
         self.num_prompt_logprobs: dict[str, int] = {}

+        # To accumulate prompt logprobs tensor chunks across prefill steps.
+        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
+
         self.logit_bias: list[Optional[dict[int,
                                             float]]] = [None] * max_num_reqs
         self.has_allowed_token_ids: set[str] = set()
@@ -362,6 +366,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.num_prompt_logprobs.pop(req_id, None)
+        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)

         # LoRA
         lora_id = self.request_lora_mapping[req_index]

vllm/v1/worker/gpu_model_runner.py

Lines changed: 33 additions & 8 deletions
@@ -1191,6 +1191,7 @@ def _get_prompt_logprobs_dict(
         if not num_prompt_logprobs_dict:
             return {}

+        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
         prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}

         # Since prompt logprobs are a rare feature, prioritize simple,
@@ -1206,16 +1207,36 @@ def _get_prompt_logprobs_dict(
             prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                 self.device, non_blocking=True)

+            # Set up target LogprobsTensors object.
+            logprobs_tensors = in_progress_dict.get(req_id)
+            if not logprobs_tensors:
+                # Create empty logprobs CPU tensors for the entire prompt.
+                # If chunked, we'll copy in slice by slice.
+                logprobs_tensors = LogprobsTensors.empty_cpu(
+                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
+                in_progress_dict[req_id] = logprobs_tensors
+
             # Determine number of logits to retrieve.
-            start_tok = request.num_computed_tokens + 1
+            start_idx = request.num_computed_tokens
+            start_tok = start_idx + 1
             num_remaining_tokens = num_prompt_tokens - start_tok
-            if num_tokens < num_remaining_tokens:
+            if num_tokens <= num_remaining_tokens:
                 # This is a chunk, more tokens remain.
+                # In the == case, there are no more prompt logprobs to produce
+                # but we want to defer returning them to the next step where we
+                # have new generated tokens to return.
                 num_logits = num_tokens
             else:
                 # This is the last chunk of prompt tokens to return.
                 num_logits = num_remaining_tokens
                 completed_prefill_reqs.append(req_id)
+                prompt_logprobs_dict[req_id] = logprobs_tensors
+
+            if num_logits <= 0:
+                # This can happen for the final chunk if we prefilled exactly
+                # (num_prompt_tokens - 1) tokens for this request in the prior
+                # step. There are no more prompt logprobs to produce.
+                continue

             # Get the logits corresponding to this req's prompt tokens.
             # If this is a partial request (i.e. chunked prefill),
@@ -1236,19 +1257,23 @@ def _get_prompt_logprobs_dict(
                 logprobs, num_prompt_logprobs, tgt_token_ids)

             # Transfer GPU->CPU async.
-            prompt_logprobs_dict[req_id] = LogprobsTensors(
-                token_ids.to("cpu", non_blocking=True),
-                logprobs.to("cpu", non_blocking=True),
-                ranks.to("cpu", non_blocking=True),
-            )
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
+                token_ids, non_blocking=True)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
+                                                         non_blocking=True)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
+                ranks, non_blocking=True)

         # Remove requests that have completed prefill from the batch
         # num_prompt_logprobs_dict.
         for req_id in completed_prefill_reqs:
             del num_prompt_logprobs_dict[req_id]
+            del in_progress_dict[req_id]

         # Must synchronize the non-blocking GPU->CPU transfers.
-        torch.cuda.synchronize()
+        if prompt_logprobs_dict:
+            torch.cuda.synchronize()

         return prompt_logprobs_dict
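The chunk accounting above is the subtle part: `num_logits` is how many prompt positions this step contributes, the `<=` case defers returning a fully computed prompt until a step that also carries sampled tokens, and `num_logits <= 0` covers a final chunk with nothing left to emit. A worked example with hypothetical numbers (`num_tokens` from the diff is renamed `num_scheduled_tokens` here; this is a standalone reimplementation of just the arithmetic, not the model runner code):

```python
NUM_PROMPT_TOKENS = 8  # hypothetical prompt length; 7 logprob rows in total

def chunk_logits(num_computed_tokens: int,
                 num_scheduled_tokens: int) -> tuple[int, bool]:
    """Return (num_logits, prefill_done) for one scheduling step."""
    start_idx = num_computed_tokens
    start_tok = start_idx + 1
    num_remaining_tokens = NUM_PROMPT_TOKENS - start_tok
    if num_scheduled_tokens <= num_remaining_tokens:
        # A chunk; in the == case the logprobs are complete but are held back
        # until the next step, which also carries newly sampled tokens.
        return num_scheduled_tokens, False
    # Last chunk: emit the remaining rows (possibly zero) and finish.
    return num_remaining_tokens, True

# Prompt processed in two chunks of 4 tokens each.
assert chunk_logits(0, 4) == (4, False)  # rows for positions 1..4
assert chunk_logits(4, 4) == (3, True)   # rows for positions 5..7, done
# If a prior step already covered positions 1..7, the final step has nothing
# left to copy, which is the num_logits <= 0 early-continue case.
assert chunk_logits(7, 1) == (0, True)
```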
