Commit 5ff3222

larryliu0820 authored and malfet committed
Add time to first token and fix no print out issue for max_new_tokens < 4 (pytorch#859)
1 parent c6b72fc commit 5ff3222
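The "no print out" half of the title concerns the streaming-print callback: it buffers decoded text and only flushes it periodically, so a run with a very small max_new_tokens could finish before the buffer was ever printed. The diff below fixes this by telling the callback when generation is done. A minimal sketch of that idea, with hypothetical names (make_callback, flush_every are not from generate.py) and the flush rule assumed rather than copied from the hunks shown here:

# Sketch only: a buffered print callback that would drop output for short
# generations unless it also flushes when done_generating is True.
def make_callback(tokenizer, period_id, flush_every=4):
    buffer = []

    def callback(token, *, done_generating=False):
        # Decode with a leading period so the tokenizer keeps the word boundary,
        # then strip it (same trick _callback in generate.py uses).
        buffer.append(tokenizer.decode([period_id] + token.tolist())[1:])
        # Without the done_generating check, a run shorter than flush_every
        # tokens never reaches the flush threshold and prints nothing.
        if len(buffer) == flush_every or done_generating:
            print("".join(buffer), end="", flush=True)
            buffer.clear()

    return callback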

File tree

1 file changed: +10 −12 lines


generate.py

Lines changed: 10 additions & 12 deletions
@@ -227,7 +227,7 @@ def decode_n_tokens(
             )
             input_pos += 1
             new_tokens.append(next_token.clone())
-            callback(new_tokens[-1])
+            callback(new_tokens[-1], done_generating=_i==num_new_tokens-2)
             if need_probs:
                 new_probs.append(next_prob.clone())
             cur_token = next_token.view(1, -1)
@@ -367,6 +367,7 @@ def generate(
     seq = empty
     input_pos = torch.arange(start_pos, T + start_pos, device=device, dtype=torch.int)

+    prefill_t0 = time.perf_counter()
     next_token = prefill(
         model,
         prompt.view(1, -1),
@@ -382,9 +383,10 @@ def generate(
         sequential_prefill=sequential_prefill,
         **sampling_kwargs,
     )
-    # print(f"sizes: {T} {seq[T].shape} {seq.shape} {next_token.shape}")
+    time_to_first_token = time.perf_counter() - prefill_t0
     seq[T] = next_token
-    callback(next_token.clone().view(-1))
+    # max_new_tokens <= 2 means we are effectively not calling decode_n_tokens().
+    callback(next_token.clone().view(-1), done_generating=max_new_tokens<=2)

     num_tokens_generated = 0
     input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
@@ -425,7 +427,7 @@ def generate(
         : T + 1 + len(generated_tokens)
     ]  # If we dont generate all the way to max_new_tokens slice off the extra space we allocated.

-    generate_stats = {"accept_counts": accept_counts}
+    generate_stats = {"accept_counts": accept_counts, "time_to_first_token": time_to_first_token}
     return seq, generate_stats


@@ -460,9 +462,7 @@ def get_device_info(name: str) -> str:
     return ""


-def _callback(x, buffer, period_id, done_generating, tokenizer, is_llama3_model):
-    if done_generating:
-        return
+def _callback(x, *, buffer, period_id, done_generating, tokenizer, is_llama3_model):
     buffer.append(
         tokenizer.decode([period_id] + x.tolist())[1:]
     )  # I think this results in the first output token being dropped from the display which is wrong.
@@ -669,9 +669,8 @@ def _main(

         buffer = []
         period_id = tokenizer.encode(".")[0]
-        done_generating = False

-        def callback(x):
+        def callback(x, *, done_generating=False):
             return _callback(
                 x,
                 buffer=buffer,
@@ -685,9 +684,8 @@ def callback(x):
         assert not generator_args.chat_mode
         buffer = [generator_args.prompt]
         period_id = tokenizer.encode(".")[0]
-        done_generating = False

-        def callback(x):
+        def callback(x, *, done_generating=False):
             return _callback(
                 x,
                 buffer=buffer,
@@ -753,7 +751,7 @@ def callback(x):
        # continue

        logging.info(
-            f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated} tokens, {tokens_sec:.02f} tokens/sec, {1000 / tokens_sec:.02f} ms/token"
+            f"Time for inference {i + 1}: {t:.02f} sec total, time to first token {metrics['time_to_first_token']:.02f}, {tokens_generated} tokens, {tokens_sec:.02f} tokens/sec, {1000 / tokens_sec:.02f} ms/token"
        )
        logging.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
        if i == 0:
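For the timing half of the change, the pattern in generate() above is simply: start a timer before prefill, record time_to_first_token once the first token exists, and carry it through the returned stats so the log line in _main can report it. A condensed sketch of that pattern, independent of generate.py (run_prefill and run_decode are placeholder callables, not functions from the file):

import time

# Sketch of the prefill-vs-decode timing split used by this commit.
def timed_generate(run_prefill, run_decode, max_new_tokens):
    t0 = time.perf_counter()
    first_token = run_prefill()                      # prefill produces the first token
    time_to_first_token = time.perf_counter() - t0   # prefill latency only

    tokens = [first_token] + run_decode(first_token, max_new_tokens - 1)
    total_time = time.perf_counter() - t0

    stats = {"time_to_first_token": time_to_first_token, "total_time": total_time}
    return tokens, stats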
