Skip to content

Commit a6a6e61

Browse files
yanbing-jJack-Khuu
andauthored
Ignore tokens per sec from jit_compile iteration (pytorch#1378)
* Remove tokens per sec in aggregate_metrics when jit_compile * Add warning to user * Update --------- Co-authored-by: Jack-Khuu <jack.khuu.7@gmail.com>
1 parent 5da240a commit a6a6e61

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

torchchat/generate.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1149,9 +1149,11 @@ def callback(x, *, done_generating=False):
11491149
print(
11501150
f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds"
11511151
)
1152-
aggregate_metrics["tokens_per_sec"].append(tokens_sec)
1153-
aggregate_metrics["first_token_per_sec"].append(first_token_sec)
1154-
aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec)
1152+
else:
1153+
# aggregate_metrics will not append when is jit_compile, which will affect the average numbers.
1154+
aggregate_metrics["tokens_per_sec"].append(tokens_sec)
1155+
aggregate_metrics["first_token_per_sec"].append(first_token_sec)
1156+
aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec)
11551157

11561158
logging.info(
11571159
f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\
@@ -1205,7 +1207,8 @@ def callback(x, *, done_generating=False):
12051207
or torch.isnan(torch.tensor(avg_next_tokens_sec))
12061208
):
12071209
print(
1208-
f"\n Average tokens/sec (total): {avg_tokens_sec:.2f} \
1210+
f"\nWarning: Excluding compile in calculations \
1211+
\n Average tokens/sec (total): {avg_tokens_sec:.2f} \
12091212
\nAverage tokens/sec (first token): {avg_first_token_sec:.2f} \
12101213
\nAverage tokens/sec (next tokens): {avg_next_tokens_sec:.2f} \n\
12111214
"

0 commit comments

Comments
 (0)