Commit 5f43916

[PromptLogprobs][V1] Support prompt logprobs to fix ceval accuracy in V1 (#1483)
### What this PR does / why we need it?
Support prompt logprobs in V1. This also enables lm_eval to test accuracy on V1.

### Does this PR introduce _any_ user-facing change?
Supports prompt logprobs output.

### How was this patch tested?
CI passed with the accuracy test. Accuracy was verified with lm_eval, which uses prompt logprobs as output:

```bash
VLLM_USE_V1=1 lm_eval \
  --model vllm \
  --model_args pretrained=Qwen/Qwen2.5-7B-Instruct,max_model_len=4096,block_size=4 \
  --tasks ceval-valid_computer_network \
  --batch_size 8
```

After this PR, the accuracy test results of `Qwen/Qwen2.5-7B-Instruct` on V1 are:

```bash
|            Tasks            |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-----------------------------|------:|------|-----:|--------|---|-----:|---|-----:|
|ceval-valid_computer_network |      2|none  |     0|acc     |↑  |0.7368|±  |0.1038|
|                             |       |none  |     0|acc_norm|↑  |0.7368|±  |0.1038|
```

Closes: #1043

Signed-off-by: MengqingCao <cmq0113@163.com>
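For context, below is a minimal offline-inference sketch of how prompt logprobs are requested from vLLM once this support is in place. The model, prompt, and `prompt_logprobs=5` value are illustrative choices, not taken from this PR:

```python
# Minimal sketch (illustrative, not from this PR): request prompt logprobs on V1.
import os

os.environ["VLLM_USE_V1"] = "1"  # run on the V1 engine, as in the lm_eval command above

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", max_model_len=4096)
# prompt_logprobs=5 asks for the top-5 logprobs of every prompt token.
params = SamplingParams(max_tokens=8, prompt_logprobs=5)

outputs = llm.generate(["The capital of France is"], params)
for out in outputs:
    # out.prompt_logprobs holds one entry per prompt token
    # (None for the first token, which has no preceding context).
    print(out.prompt_logprobs)
```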
1 parent 99e6855 commit 5f43916

1 file changed: vllm_ascend/worker/model_runner_v1.py (105 additions, 3 deletions)
```diff
@@ -53,7 +53,8 @@
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
+                             ModelRunnerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -1506,6 +1507,12 @@ def execute_model(
         logprobs_lists = logprobs_tensors.tolists() \
             if logprobs_tensors is not None else None
 
+        # Compute prompt logprobs if needed.
+        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
+            hidden_states[:num_scheduled_tokens],
+            scheduler_output,
+        )
+
         # Get the valid generated tokens.
         sampled_token_ids = sampler_output.sampled_token_ids
         max_gen_len = sampled_token_ids.shape[-1]
@@ -1540,7 +1547,7 @@
                 sampled_token_ids=valid_sampled_token_ids,
                 spec_token_ids=spec_token_ids,
                 logprobs=logprobs_lists,
-                prompt_logprobs_dict={},
+                prompt_logprobs_dict=prompt_logprobs_dict,
             )
         else:
             model_runner_output = ModelRunnerOutput(
@@ -1549,7 +1556,7 @@
                 sampled_token_ids=valid_sampled_token_ids,
                 spec_token_ids=spec_token_ids,
                 logprobs=logprobs_lists,
-                prompt_logprobs_dict={},
+                prompt_logprobs_dict=prompt_logprobs_dict,
                 pooler_output=[],
             )
 
@@ -2149,6 +2156,101 @@ def _generate_mtp_token_ids(
         spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids
 
+    def _get_prompt_logprobs_dict(
+        self,
+        hidden_states: torch.Tensor,
+        scheduler_output: "SchedulerOutput",
+    ) -> dict[str, Optional[LogprobsTensors]]:
+        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
+        if not num_prompt_logprobs_dict:
+            return {}
+
+        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
+        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
+
+        # Since prompt logprobs are a rare feature, prioritize simple,
+        # maintainable loop over optimal performance.
+        completed_prefill_reqs = []
+        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
+
+            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+
+            # Get metadata for this request.
+            request = self.requests[req_id]
+            num_prompt_tokens = len(request.prompt_token_ids)
+            prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
+                self.device, non_blocking=True)
+
+            # Set up target LogprobsTensors object.
+            logprobs_tensors = in_progress_dict.get(req_id)
+            if not logprobs_tensors:
+                # Create empty logprobs CPU tensors for the entire prompt.
+                # If chunked, we'll copy in slice by slice.
+                logprobs_tensors = LogprobsTensors.empty_cpu(
+                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
+                in_progress_dict[req_id] = logprobs_tensors
+
+            # Determine number of logits to retrieve.
+            start_idx = request.num_computed_tokens
+            start_tok = start_idx + 1
+            num_remaining_tokens = num_prompt_tokens - start_tok
+            if num_tokens <= num_remaining_tokens:
+                # This is a chunk, more tokens remain.
+                # In the == case, there are no more prompt logprobs to produce
+                # but we want to defer returning them to the next step where we
+                # have new generated tokens to return.
+                num_logits = num_tokens
+            else:
+                # This is the last chunk of prompt tokens to return.
+                num_logits = num_remaining_tokens
+                completed_prefill_reqs.append(req_id)
+                prompt_logprobs_dict[req_id] = logprobs_tensors
+
+            if num_logits <= 0:
+                # This can happen for the final chunk if we prefilled exactly
+                # (num_prompt_tokens - 1) tokens for this request in the prior
+                # step. There are no more prompt logprobs to produce.
+                continue
+
+            # Get the logits corresponding to this req's prompt tokens.
+            # If this is a partial request (i.e. chunked prefill),
+            # then there is prompt logprob generated for each index.
+            req_idx = self.input_batch.req_id_to_index[req_id]
+            offset = self.query_start_loc_np[req_idx].item()
+            prompt_hidden_states = hidden_states[offset:offset + num_logits]
+            logits = self.model.compute_logits(prompt_hidden_states, None)
+
+            # Get the "target" tokens for each index. For prompt at index i,
+            # the token at prompt index i+1 is the "sampled" token we want
+            # to gather the logprob for.
+            tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]
+
+            # Compute prompt logprobs.
+            logprobs = self.sampler.compute_logprobs(logits)
+            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
+                logprobs, num_prompt_logprobs, tgt_token_ids)
+
+            # Transfer NPU->CPU async.
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
+                token_ids, non_blocking=True)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
                                                         non_blocking=True)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
+                ranks, non_blocking=True)
+
+        # Remove requests that have completed prefill from the batch
+        # num_prompt_logprobs_dict.
+        for req_id in completed_prefill_reqs:
+            del num_prompt_logprobs_dict[req_id]
+            del in_progress_dict[req_id]
+
+        # Must synchronize the non-blocking NPU->CPU transfers.
+        if prompt_logprobs_dict:
+            torch.npu.synchronize()
+
+        return prompt_logprobs_dict
+
     def init_torchair_graph_batch_sizes(self):
         start_graph_batch_size = 4
         tp_size = get_tensor_model_parallel_world_size()
```
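For readers unfamiliar with the chunked-prefill bookkeeping above, the sketch below restates the `num_logits` accounting from `_get_prompt_logprobs_dict` as a standalone helper. It is an illustrative simplification, not code from this commit; the function name and the example numbers are invented:

```python
# Illustrative sketch only (not part of this commit): the per-step accounting
# used by _get_prompt_logprobs_dict. A prompt of P tokens yields P - 1 prompt
# logprobs (the first token has no logprob), so each step can emit at most the
# number of remaining target positions.
def num_prompt_logits_this_step(num_prompt_tokens: int,
                                num_computed_tokens: int,
                                num_scheduled_tokens: int) -> int:
    start_tok = num_computed_tokens + 1               # index of the first target token
    num_remaining = num_prompt_tokens - start_tok
    return min(num_scheduled_tokens, num_remaining)   # 0 is possible on the last chunk


# Example: a 10-token prompt prefilled with 4, 4, and 2 scheduled tokens
# produces 4 + 4 + 1 = 9 prompt logprobs in total.
assert num_prompt_logits_this_step(10, 0, 4) == 4
assert num_prompt_logits_this_step(10, 4, 4) == 4
assert num_prompt_logits_this_step(10, 8, 2) == 1
```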
