@@ -60,20 +60,25 @@ def __init__(
        enable_prompt_tokens_details: bool = False,
        enable_force_include_usage: bool = False,
    ):
-        super().__init__(engine_client=engine_client,
-                         model_config=model_config,
-                         models=models,
-                         request_logger=request_logger,
-                         return_tokens_as_token_ids=return_tokens_as_token_ids,
-                         enable_force_include_usage=enable_force_include_usage)
+        super().__init__(
+            engine_client=engine_client,
+            model_config=model_config,
+            models=models,
+            request_logger=request_logger,
+            return_tokens_as_token_ids=return_tokens_as_token_ids,
+            enable_force_include_usage=enable_force_include_usage,
+        )
        self.enable_prompt_tokens_details = enable_prompt_tokens_details
        self.default_sampling_params = (
            self.model_config.get_diff_sampling_param())
        if self.default_sampling_params:
            source = self.model_config.generation_config
            source = "model" if source == "auto" else source
-            logger.info("Using default completion sampling params from %s: %s",
-                        source, self.default_sampling_params)
+            logger.info(
+                "Using default completion sampling params from %s: %s",
+                source,
+                self.default_sampling_params,
+            )

    async def create_completion(
        self,
@@ -172,23 +177,28 @@ async def create_completion(
                    max_model_len=self.max_model_len,
                    request=request,
                    input_length=input_length,
-                    default_sampling_params=self.default_sampling_params)
+                    default_sampling_params=self.default_sampling_params,
+                )

                if request.use_beam_search:
                    sampling_params = request.to_beam_search_params(
                        max_tokens, self.default_sampling_params)
                else:
                    sampling_params = request.to_sampling_params(
-                        max_tokens, self.model_config.logits_processor_pattern,
-                        self.default_sampling_params)
+                        max_tokens,
+                        self.model_config.logits_processor_pattern,
+                        self.default_sampling_params,
+                    )

                request_id_item = f"{request_id}-{i}"

-                self._log_inputs(request_id_item,
-                                 request_prompts[i],
-                                 params=sampling_params,
-                                 lora_request=lora_request,
-                                 prompt_adapter_request=prompt_adapter_request)
+                self._log_inputs(
+                    request_id_item,
+                    request_prompts[i],
+                    params=sampling_params,
+                    lora_request=lora_request,
+                    prompt_adapter_request=prompt_adapter_request,
+                )

                trace_headers = (None if raw_request is None else await
                                 self._get_trace_headers(raw_request.headers))
@@ -245,7 +255,8 @@ async def create_completion(
                num_prompts=num_prompts,
                tokenizer=tokenizer,
                request_metadata=request_metadata,
-                enable_force_include_usage=self.enable_force_include_usage)
+                enable_force_include_usage=self.enable_force_include_usage,
+            )

        # Non-streaming response
        final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
@@ -321,10 +332,10 @@ async def completion_stream_generator(

        stream_options = request.stream_options
        if stream_options:
-            include_usage = stream_options.include_usage or \
-                enable_force_include_usage
-            include_continuous_usage = include_usage and \
-                stream_options.continuous_usage_stats
+            include_usage = (stream_options.include_usage
+                             or enable_force_include_usage)
+            include_continuous_usage = (include_usage and
+                                        stream_options.continuous_usage_stats)
        else:
            include_usage, include_continuous_usage = False, False

@@ -370,7 +381,8 @@ async def completion_stream_generator(
                            # echo the prompt and first token
                            delta_text = prompt_text + output.text
                            delta_token_ids = [
-                                *prompt_token_ids, *output.token_ids
+                                *prompt_token_ids,
+                                *output.token_ids,
                            ]
                            out_logprobs = [
                                *(prompt_logprobs or []),
@@ -383,8 +395,8 @@ async def completion_stream_generator(
                        delta_token_ids = output.token_ids
                        out_logprobs = output.logprobs

-                    if not delta_text and not delta_token_ids \
-                        and not previous_num_tokens[i]:
+                    if (not delta_text and not delta_token_ids
+                            and not previous_num_tokens[i]):
                        # Chunked prefill case, don't return empty chunks
                        continue

@@ -420,7 +432,8 @@ async def completion_stream_generator(
                                finish_reason=finish_reason,
                                stop_reason=stop_reason,
                            )
-                        ])
+                        ],
+                    )
                    if include_continuous_usage:
                        prompt_tokens = num_prompt_tokens[prompt_idx]
                        completion_tokens = previous_num_tokens[i]
@@ -438,7 +451,8 @@ async def completion_stream_generator(
            final_usage_info = UsageInfo(
                prompt_tokens=total_prompt_tokens,
                completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens)
+                total_tokens=total_prompt_tokens + total_completion_tokens,
+            )

            if self.enable_prompt_tokens_details and num_cached_tokens:
                final_usage_info.prompt_tokens_details = PromptTokenUsageInfo(
@@ -452,8 +466,8 @@ async def completion_stream_generator(
                    choices=[],
                    usage=final_usage_info,
                )
-                final_usage_data = (final_usage_chunk.model_dump_json(
-                    exclude_unset=False, exclude_none=True))
+                final_usage_data = final_usage_chunk.model_dump_json(
+                    exclude_unset=False, exclude_none=True)
                yield f"data: {final_usage_data}\n\n"

            # report to FastAPI middleware aggregate usage across all choices
@@ -478,8 +492,10 @@ def request_output_to_completion_response(
        choices: list[CompletionResponseChoice] = []
        num_prompt_tokens = 0
        num_generated_tokens = 0
-
+        kv_transfer_params = None
+        last_final_res = None
        for final_res in final_res_batch:
+            last_final_res = final_res
            prompt_token_ids = final_res.prompt_token_ids
            assert prompt_token_ids is not None
            prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs)
@@ -548,19 +564,22 @@ def request_output_to_completion_response(
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )

-        if self.enable_prompt_tokens_details and final_res.num_cached_tokens:
+        if (self.enable_prompt_tokens_details and last_final_res
+                and last_final_res.num_cached_tokens):
            usage.prompt_tokens_details = PromptTokenUsageInfo(
-                cached_tokens=final_res.num_cached_tokens)
+                cached_tokens=last_final_res.num_cached_tokens)

        request_metadata.final_usage_info = usage
-
+        if final_res_batch:
+            kv_transfer_params = final_res_batch[0].kv_transfer_params
        return CompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
-            kv_transfer_params=final_res_batch[0].kv_transfer_params)
+            kv_transfer_params=kv_transfer_params,
+        )

    def _create_completion_logprobs(
        self,
@@ -579,8 +598,9 @@ def _create_completion_logprobs(

        last_token_len = 0

-        should_return_as_token_id = return_as_token_id if \
-            return_as_token_id is not None else self.return_tokens_as_token_ids
+        should_return_as_token_id = (return_as_token_id
+                                     if return_as_token_id is not None else
+                                     self.return_tokens_as_token_ids)
        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is None:
@@ -612,10 +632,12 @@ def _create_completion_logprobs(
                out_top_logprobs.append({
                    # Convert float("-inf") to the
                    # JSON-serializable float that OpenAI uses
-                    self._get_decoded_token(top_lp[1],
-                                            top_lp[0],
-                                            tokenizer,
-                                            return_as_token_id=should_return_as_token_id):
+                    self._get_decoded_token(
+                        top_lp[1],
+                        top_lp[0],
+                        tokenizer,
+                        return_as_token_id=should_return_as_token_id,
+                    ):
                    max(top_lp[1].logprob, -9999.0)
                    for i, top_lp in enumerate(step_top_logprobs.items())
                    if num_output_top_logprobs >= i
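Note: aside from reflowing call sites (trailing commas, one argument per line), the behavioral part of this diff is that request_output_to_completion_response no longer dereferences the loop variable or final_res_batch[0] unconditionally, so an empty batch cannot raise. A minimal standalone sketch of that guard pattern follows; the FakeResult class and summarize function are illustrative stand-ins, not the vLLM types.

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeResult:
    # Illustrative stand-in for a per-request result object.
    num_cached_tokens: int = 0
    kv_transfer_params: Optional[dict] = None


def summarize(batch: list) -> dict:
    kv_transfer_params = None
    last_res = None
    for res in batch:  # an empty batch simply skips the loop
        last_res = res
    if batch:  # only index [0] when the batch is non-empty
        kv_transfer_params = batch[0].kv_transfer_params
    cached = last_res.num_cached_tokens if last_res else 0
    return {"kv_transfer_params": kv_transfer_params, "cached_tokens": cached}


print(summarize([]))  # {'kv_transfer_params': None, 'cached_tokens': 0}
print(summarize([FakeResult(num_cached_tokens=3, kv_transfer_params={"k": 1})]))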