@@ -227,7 +227,7 @@ def decode_n_tokens(
             )
             input_pos += 1
             new_tokens.append(next_token.clone())
-            callback(new_tokens[-1])
+            callback(new_tokens[-1], done_generating=_i == num_new_tokens - 2)
             if need_probs:
                 new_probs.append(next_prob.clone())
             cur_token = next_token.view(1, -1)
@@ -367,6 +367,7 @@ def generate(
     seq = empty
     input_pos = torch.arange(start_pos, T + start_pos, device=device, dtype=torch.int)
 
+    prefill_t0 = time.perf_counter()
     next_token = prefill(
         model,
         prompt.view(1, -1),
@@ -382,9 +383,10 @@ def generate(
         sequential_prefill=sequential_prefill,
         **sampling_kwargs,
     )
-    # print(f"sizes: {T} {seq[T].shape} {seq.shape} {next_token.shape}")
+    time_to_first_token = time.perf_counter() - prefill_t0
     seq[T] = next_token
-    callback(next_token.clone().view(-1))
+    # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
+    callback(next_token.clone().view(-1), done_generating=max_new_tokens <= 2)
 
     num_tokens_generated = 0
     input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
@@ -425,7 +427,7 @@ def generate(
         : T + 1 + len(generated_tokens)
     ]  # If we dont generate all the way to max_new_tokens slice off the extra space we allocated.
 
-    generate_stats = {"accept_counts": accept_counts}
+    generate_stats = {"accept_counts": accept_counts, "time_to_first_token": time_to_first_token}
     return seq, generate_stats
 
 
@@ -460,9 +462,7 @@ def get_device_info(name: str) -> str:
     return ""
 
 
-def _callback(x, buffer, period_id, done_generating, tokenizer, is_llama3_model):
-    if done_generating:
-        return
+def _callback(x, *, buffer, period_id, done_generating, tokenizer, is_llama3_model):
     buffer.append(
         tokenizer.decode([period_id] + x.tolist())[1:]
     )  # I think this results in the first output token being dropped from the display which is wrong.
@@ -669,9 +669,8 @@ def _main(
 
         buffer = []
         period_id = tokenizer.encode(".")[0]
-        done_generating = False
 
-        def callback(x):
+        def callback(x, *, done_generating=False):
             return _callback(
                 x,
                 buffer=buffer,
@@ -685,9 +684,8 @@ def callback(x):
         assert not generator_args.chat_mode
         buffer = [generator_args.prompt]
         period_id = tokenizer.encode(".")[0]
-        done_generating = False
 
-        def callback(x):
+        def callback(x, *, done_generating=False):
             return _callback(
                 x,
                 buffer=buffer,
@@ -753,7 +751,7 @@ def callback(x):
             # continue
 
         logging.info(
-            f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated} tokens, {tokens_sec:.02f} tokens/sec, {1000 / tokens_sec:.02f} ms/token"
+            f"Time for inference {i + 1}: {t:.02f} sec total, time to first token {metrics['time_to_first_token']:.02f}, {tokens_generated} tokens, {tokens_sec:.02f} tokens/sec, {1000 / tokens_sec:.02f} ms/token"
         )
        logging.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
         if i == 0:
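
The timing change above follows a common pattern: record a timestamp immediately before prefill, take the difference right after the first token is produced, and carry the value out alongside the other generation statistics. A minimal, self-contained sketch of that pattern follows; `run_prefill` and `run_decode` are hypothetical stand-ins, not the functions in this diff.

```python
import time


def generate_with_ttft(run_prefill, run_decode, prompt):
    """Illustrative only: measure time-to-first-token around the prefill step."""
    prefill_t0 = time.perf_counter()
    first_token = run_prefill(prompt)  # prefill produces the first token
    time_to_first_token = time.perf_counter() - prefill_t0

    rest = run_decode(first_token)  # remaining tokens via the decode loop
    stats = {"time_to_first_token": time_to_first_token}
    return [first_token, *rest], stats
```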
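The callback refactor replaces a `done_generating` variable captured by the closure with a keyword-only parameter, so each call site states explicitly whether the current token is the last one. A hedged sketch of how that shape composes (names like `_print_token` and `make_callback` are illustrative, not from the diff):

```python
def _print_token(x, *, buffer, done_generating):
    """Stand-in for _callback: keyword-only args make call sites explicit."""
    buffer.append(x)
    if done_generating:
        print("".join(str(t) for t in buffer))
        buffer.clear()


def make_callback(buffer):
    # done_generating defaults to False so mid-stream call sites can omit it
    def callback(x, *, done_generating=False):
        return _print_token(x, buffer=buffer, done_generating=done_generating)

    return callback


cb = make_callback([])
cb("hello")                      # mid-stream token, buffered
cb("!", done_generating=True)    # final token flushes the buffer
```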