     CompletionResponseStreamChoice,
     CompletionStreamResponse,
     ErrorResponse,
+    PromptTokenUsageInfo,
     RequestResponseMetadata,
     UsageInfo)
 # yapf: enable
@@ -52,12 +53,14 @@ def __init__(
         *,
         request_logger: Optional[RequestLogger],
         return_tokens_as_token_ids: bool = False,
+        enable_prompt_tokens_details: bool = False,
     ):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
                          models=models,
                          request_logger=request_logger,
                          return_tokens_as_token_ids=return_tokens_as_token_ids)
+        self.enable_prompt_tokens_details = enable_prompt_tokens_details
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
         if self.default_sampling_params:
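
As a minimal sketch (not taken from this diff, and assuming the handler class is OpenAIServingCompletion), the new keyword would be threaded through at construction time roughly like this; every other argument mirrors the constructor signature above:

# Hypothetical wiring, for illustration only: all arguments except
# enable_prompt_tokens_details mirror the constructor shown above.
serving_completion = OpenAIServingCompletion(
    engine_client=engine_client,
    model_config=model_config,
    models=models,
    request_logger=request_logger,
    return_tokens_as_token_ids=False,
    enable_prompt_tokens_details=True,  # opt in to cached-token reporting in usage
)
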
@@ -297,6 +300,7 @@ async def completion_stream_generator(
         previous_num_tokens = [0] * num_choices * num_prompts
         has_echoed = [False] * num_choices * num_prompts
         num_prompt_tokens = [0] * num_prompts
+        num_cached_tokens = [0] * num_prompts

         stream_options = request.stream_options
         if stream_options:
@@ -311,11 +315,15 @@ async def completion_stream_generator(
                 prompt_token_ids = res.prompt_token_ids
                 prompt_logprobs = res.prompt_logprobs
                 prompt_text = res.prompt
+                cached_tokens = res.num_cached_tokens

                 # Prompt details are excluded from later streamed outputs
                 if prompt_token_ids is not None:
                     num_prompt_tokens[prompt_idx] = len(prompt_token_ids)

+                if cached_tokens is not None:
+                    num_cached_tokens[prompt_idx] = cached_tokens
+
                 delta_token_ids: GenericSequence[int]
                 out_logprobs: Optional[GenericSequence[Optional[dict[
                     int, Logprob]]]]
@@ -402,10 +410,15 @@ async def completion_stream_generator(

             total_prompt_tokens = sum(num_prompt_tokens)
             total_completion_tokens = sum(previous_num_tokens)
+            total_cached_tokens = sum(num_cached_tokens)
             final_usage_info = UsageInfo(
                 prompt_tokens=total_prompt_tokens,
                 completion_tokens=total_completion_tokens,
                 total_tokens=total_prompt_tokens + total_completion_tokens)
+            if self.enable_prompt_tokens_details and total_cached_tokens:
+                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
+                    cached_tokens=total_cached_tokens
+                )

             if include_usage:
                 final_usage_chunk = CompletionStreamResponse(
@@ -510,6 +523,9 @@ def request_output_to_completion_response(
             completion_tokens=num_generated_tokens,
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )
+        if self.enable_prompt_tokens_details and final_res_batch[0].num_cached_tokens:
+            usage.prompt_tokens_details = PromptTokenUsageInfo(
+                cached_tokens=final_res_batch[0].num_cached_tokens)

         request_metadata.final_usage_info = usage
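
Once the flag is enabled and part of the prompt is served from the prefix cache, the final usage block reported to the client would carry the cached-token count. A rough sketch of the parsed payload, with purely illustrative numbers and assuming the usual OpenAI-style serialization of UsageInfo and PromptTokenUsageInfo:

# Illustrative values only; the shape follows UsageInfo / PromptTokenUsageInfo above.
usage = {
    "prompt_tokens": 512,
    "completion_tokens": 64,
    "total_tokens": 576,
    "prompt_tokens_details": {
        "cached_tokens": 256,  # prompt tokens served from the prefix cache
    },
}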