4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ import time
7
8
from abc import ABC , abstractmethod
8
9
from typing import List , Optional
9
10
@@ -97,6 +98,7 @@ def generate( # noqa: C901
97
98
pos_base : int = 0 ,
98
99
) -> List [int ]:
99
100
# Prefill
101
+ prefill_start = time .time ()
100
102
logits = self .forward (
101
103
tokens = torch .tensor ([prompt_tokens ], dtype = torch .long , device = self .device ),
102
104
input_pos = (
@@ -105,11 +107,13 @@ def generate( # noqa: C901
105
107
else None
106
108
),
107
109
)
110
+ prefill_time = time .time () - prefill_start
108
111
109
112
current_token = next_token (logits , temperature , top_p )
110
113
print (f"{ self .tokenizer .decode_token (current_token )} " , end = "" , flush = True )
111
114
tokens = prompt_tokens + [current_token ]
112
115
116
+ generate_start = time .time ()
113
117
while len (tokens ) < max_seq_len :
114
118
if self .use_kv_cache :
115
119
logits = self .forward (
@@ -140,6 +144,10 @@ def generate( # noqa: C901
140
144
print (f"{ self .tokenizer .decode_token (current_token )} " , end = "" , flush = True )
141
145
print ("\n " )
142
146
147
+ generate_time = time .time () - generate_start
148
+ print (f"Prefill time: { prefill_time } " )
149
+ print (f"Generation tok/s: { len (tokens ) / generate_time } " )
150
+
143
151
return tokens if echo else tokens [len (prompt_tokens ) :]
144
152
145
153
def text_completion (
0 commit comments