Skip to content

Commit 40e80e4

Browse files
authored
Basic performance logging for Llama python runner (#8862)
Add basic performance metrics to native llama runner
1 parent 1ce7ed7 commit 40e80e4

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

examples/models/llama/runner/generation.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import time
78
from abc import ABC, abstractmethod
89
from typing import List, Optional
910

@@ -97,6 +98,7 @@ def generate( # noqa: C901
9798
pos_base: int = 0,
9899
) -> List[int]:
99100
# Prefill
101+
prefill_start = time.time()
100102
logits = self.forward(
101103
tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
102104
input_pos=(
@@ -105,11 +107,13 @@ def generate( # noqa: C901
105107
else None
106108
),
107109
)
110+
prefill_time = time.time() - prefill_start
108111

109112
current_token = next_token(logits, temperature, top_p)
110113
print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
111114
tokens = prompt_tokens + [current_token]
112115

116+
generate_start = time.time()
113117
while len(tokens) < max_seq_len:
114118
if self.use_kv_cache:
115119
logits = self.forward(
@@ -140,6 +144,10 @@ def generate( # noqa: C901
140144
print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
141145
print("\n")
142146

147+
generate_time = time.time() - generate_start
148+
print(f"Prefill time: {prefill_time}")
149+
print(f"Generation tok/s: {len(tokens) / generate_time}")
150+
143151
return tokens if echo else tokens[len(prompt_tokens) :]
144152

145153
def text_completion(

0 commit comments

Comments
 (0)