 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                          KVCacheSpec)
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
+                             ModelRunnerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -1506,6 +1507,12 @@ def execute_model(
         logprobs_lists = logprobs_tensors.tolists() \
             if logprobs_tensors is not None else None
 
+        # Compute prompt logprobs if needed.
+        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
+            hidden_states[:num_scheduled_tokens],
+            scheduler_output,
+        )
+
         # Get the valid generated tokens.
         sampled_token_ids = sampler_output.sampled_token_ids
         max_gen_len = sampled_token_ids.shape[-1]
@@ -1540,7 +1547,7 @@ def execute_model(
                 sampled_token_ids=valid_sampled_token_ids,
                 spec_token_ids=spec_token_ids,
                 logprobs=logprobs_lists,
-                prompt_logprobs_dict={},
+                prompt_logprobs_dict=prompt_logprobs_dict,
             )
         else:
             model_runner_output = ModelRunnerOutput(
@@ -1549,7 +1556,7 @@ def execute_model(
                 sampled_token_ids=valid_sampled_token_ids,
                 spec_token_ids=spec_token_ids,
                 logprobs=logprobs_lists,
-                prompt_logprobs_dict={},
+                prompt_logprobs_dict=prompt_logprobs_dict,
                 pooler_output=[],
             )
@@ -2149,6 +2156,101 @@ def _generate_mtp_token_ids(
         spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids
 
+    def _get_prompt_logprobs_dict(
+        self,
+        hidden_states: torch.Tensor,
+        scheduler_output: "SchedulerOutput",
+    ) -> dict[str, Optional[LogprobsTensors]]:
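+        """Compute prompt logprobs for the requests that asked for them.
+
+        For chunked prefills, results are accumulated slice by slice into
+        per-request CPU buffers and only returned once the whole prompt has
+        been processed.
+        """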
+        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
+        if not num_prompt_logprobs_dict:
+            return {}
+
+        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
+        prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
+
+        # Since prompt logprobs are a rare feature, prioritize a simple,
+        # maintainable loop over optimal performance.
+        completed_prefill_reqs = []
+        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():
+
+            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+
+            # Get metadata for this request.
+            request = self.requests[req_id]
+            num_prompt_tokens = len(request.prompt_token_ids)
+            prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
+                self.device, non_blocking=True)
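+            # Keep the prompt token ids on the NPU: slices of this tensor are
+            # used below as the target ids passed to gather_logprobs().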
+
+            # Set up target LogprobsTensors object.
+            logprobs_tensors = in_progress_dict.get(req_id)
+            if not logprobs_tensors:
+                # Create empty logprobs CPU tensors for the entire prompt.
+                # If chunked, we'll copy in slice by slice.
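+                # Shape: (num_prompt_tokens - 1, num_prompt_logprobs + 1);
+                # the extra column holds the logprob of the prompt token itself.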
+                logprobs_tensors = LogprobsTensors.empty_cpu(
+                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
+                in_progress_dict[req_id] = logprobs_tensors
+
+            # Determine number of logits to retrieve.
+            start_idx = request.num_computed_tokens
+            start_tok = start_idx + 1
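+            # Prompt logprobs are defined from the second prompt token onward
+            # (each is conditioned on the tokens before it), hence the +1 shift.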
+            num_remaining_tokens = num_prompt_tokens - start_tok
+            if num_tokens <= num_remaining_tokens:
+                # This is a chunk, more tokens remain.
+                # In the == case, there are no more prompt logprobs to produce
+                # but we want to defer returning them to the next step where we
+                # have new generated tokens to return.
+                num_logits = num_tokens
+            else:
+                # This is the last chunk of prompt tokens to return.
+                num_logits = num_remaining_tokens
+                completed_prefill_reqs.append(req_id)
+                prompt_logprobs_dict[req_id] = logprobs_tensors
+
+            if num_logits <= 0:
+                # This can happen for the final chunk if we prefilled exactly
+                # (num_prompt_tokens - 1) tokens for this request in the prior
+                # step. There are no more prompt logprobs to produce.
+                continue
+
+            # Get the logits corresponding to this req's prompt tokens.
+            # If this is a partial request (i.e. chunked prefill),
+            # then a prompt logprob is generated for each index.
+            req_idx = self.input_batch.req_id_to_index[req_id]
+            offset = self.query_start_loc_np[req_idx].item()
+            prompt_hidden_states = hidden_states[offset:offset + num_logits]
+            logits = self.model.compute_logits(prompt_hidden_states, None)
+
+            # Get the "target" tokens for each index. For prompt at index i,
+            # the token at prompt index i+1 is the "sampled" token we want
+            # to gather the logprob for.
+            tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]
+
+            # Compute prompt logprobs.
+            logprobs = self.sampler.compute_logprobs(logits)
+            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
+                logprobs, num_prompt_logprobs, tgt_token_ids)
+
+            # Transfer NPU->CPU async.
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
+                token_ids, non_blocking=True)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
+                                                         non_blocking=True)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
+                ranks, non_blocking=True)
+
+        # Remove requests that have completed prefill from the batch
+        # num_prompt_logprobs_dict.
+        for req_id in completed_prefill_reqs:
+            del num_prompt_logprobs_dict[req_id]
+            del in_progress_dict[req_id]
+
+        # Must synchronize the non-blocking NPU->CPU transfers.
+        if prompt_logprobs_dict:
+            torch.npu.synchronize()
+
+        return prompt_logprobs_dict
+
     def init_torchair_graph_batch_sizes(self):
         start_graph_batch_size = 4
         tp_size = get_tensor_model_parallel_world_size()