@@ -3,38 +3,38 @@
"""
Patch-Perplexity (P3L)

- This is a script that produces a realistic PPL measurement
- for the quantized KV cache system by processing a sequence of
- non-overlapping patches of the reference text. Generation of the
+ This is a script that produces a realistic PPL measurement
+ for the quantized KV cache system by processing a sequence of
+ non-overlapping patches of the reference text. Generation of the
consecutive symbols in each patch is governed (forced)
by the reference text.

- The initial context size for the system is set by the parameter
+ The initial context size for the system is set by the parameter
"--context-size".

- The number of output symbols to generate starting from a given
- context is set by the parameter "--sample-size". This variable also
+ The number of output symbols to generate starting from a given
+ context is set by the parameter "--sample-size". This variable also
defines the size of the individual patch.

- For the N-token reference text that is split into M patches with the
+ For the N-token reference text that is split into M patches with the
system's context size C it takes M*preload + (N-C)*generation time.

Quick correctness validation tips:

- Running llama-2-7b model
- (
- ./vllm/examples/P3L.py
- --model=meta-llama/Llama-2-7b-chat-hf
- --context-size=1024
+ Running llama-2-7b model
+ (
+ ./vllm/examples/P3L.py
+ --model=meta-llama/Llama-2-7b-chat-hf
+ --context-size=1024
--sample-size=512
)
should result in PPL ~ 6.524227946419175

- Running llama-2-7b model
- (
- ./vllm/examples/P3L.py
- --model=meta-llama/Llama-2-7b-chat-hf
- --context-size=1024
+ Running llama-2-7b model
+ (
+ ./vllm/examples/P3L.py
+ --model=meta-llama/Llama-2-7b-chat-hf
+ --context-size=1024
--sample-size=512
--patch-size=1
)
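To make the patching arithmetic described in the docstring concrete, here is a minimal standalone sketch of how the number of patches falls out of the reference-text length N, the context size C and the sample size S. The values below are hypothetical; the script derives the real ones from the tokenized text and the CLI flags:

import math

# Hypothetical sizes, for illustration only.
N = 250_000  # tokens in the tokenized reference text
C = 1024     # --context-size
S = 512      # --sample-size (tokens forced per patch)

# Number of non-overlapping patches, mirroring the script's
# my_n_patches = math.ceil((N - C - 1) / S) computation further down.
M = math.ceil((N - C - 1) / S)
print(M)  # 487 patches, each preloading a C-token prefix and forcing up to S tokens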
@@ -58,17 +58,20 @@
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.logger import init_logger
+ from vllm.utils import FlexibleArgumentParser

logger = init_logger(__name__)


def get_wikitext2_text(tokenizer):
    with tempfile.TemporaryDirectory() as tmpdirname:
-         hf_hub_download(repo_id='alexei-v-ivanov-amd/wiki',
-                         repo_type="dataset",
-                         filename='wiki.test.raw',
-                         local_dir=tmpdirname)
-         with open(os.path.join(tmpdirname, 'wiki.test.raw')) as f:
+         hf_hub_download(
+             repo_id="alexei-v-ivanov-amd/wiki",
+             repo_type="dataset",
+             filename="wiki.test.raw",
+             local_dir=tmpdirname,
+         )
+         with open(os.path.join(tmpdirname, "wiki.test.raw")) as f:
            test_text = "\n".join(line.strip() for line in f)
            test_enc = tokenizer(test_text)

@@ -79,15 +82,17 @@ def vllm_init(args):
    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))

-     sampling_params = SamplingParams(n=1,
-                                      temperature=0.0,
-                                      top_p=1,
-                                      ignore_eos=True,
-                                      ppl_measurement=True,
-                                      future_context=[],
-                                      prompt_logprobs=1,
-                                      logprobs=1,
-                                      presence_penalty=0.0)
+     sampling_params = SamplingParams(
+         n=1,
+         temperature=0.0,
+         top_p=1,
+         ignore_eos=True,
+         ppl_measurement=True,
+         future_context=[],
+         prompt_logprobs=1,
+         logprobs=1,
+         presence_penalty=0.0,
+     )

    return llm, sampling_params

@@ -98,7 +103,6 @@ def vllm_predict(CONT, llm, sampl_par):


def main(args: argparse.Namespace):
-
    MESSAGE = f"Initialising @ {datetime.datetime.now()}"
    logger.info(MESSAGE)
    print(MESSAGE)
@@ -112,14 +116,17 @@ def main(args: argparse.Namespace):

    my_n_samples = args.sample_size

-     if (args.context_size + my_n_samples) > \
-             my_llm.llm_engine.model_config.max_model_len:
-         MESSAGE = ("" \
-                    "Error! The total number of tokens:\n" \
-                    f" prefix ({args.context_size}) + " \
-                    f"to be generated ({my_n_samples})" \
-                    f" can't be bigger than the model limit " \
-                    f"({my_llm.llm_engine.model_config.max_model_len}).")
+     if (
+         args.context_size + my_n_samples
+     ) > my_llm.llm_engine.model_config.max_model_len:
+         MESSAGE = (
+             ""
+             "Error! The total number of tokens:\n"
+             f" prefix ({args.context_size}) + "
+             f"to be generated ({my_n_samples})"
+             f" can't be bigger than the model limit "
+             f"({my_llm.llm_engine.model_config.max_model_len})."
+         )
        logger.info(MESSAGE)
        print(MESSAGE)
        return
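The guard above only checks that the preloaded prefix plus the per-patch generation budget fits into the model's maximum sequence length. With the docstring's suggested settings and llama-2's 4096-token limit the check passes; a hypothetical standalone version of the same comparison:

context_size, sample_size, max_model_len = 1024, 512, 4096
assert context_size + sample_size <= max_model_len  # 1536 <= 4096, so the run proceeds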
@@ -128,26 +135,28 @@ def main(args: argparse.Namespace):
    logger.info("Loaded the test data.")

    my_n_patches = math.ceil(
-         (len(my_test_enc['input_ids']) - args.context_size - 1) / my_n_samples)
+         (len(my_test_enc["input_ids"]) - args.context_size - 1) / my_n_samples
+     )
    if args.patch_size is not None:
        my_n_patches = args.patch_size

    num_tokens_generated = 0
    starting_time = datetime.datetime.now()
-     MESSAGE = (f"Starting generation @ {starting_time}\n" \
-                " Have the test sample of "
-                f"{len(my_test_enc['input_ids'])} tokens" \
-                f" will try to process {my_n_patches} patche(s)," \
-                f" generating {my_n_samples} tokens in each patch" \
-                f" from the initial context of {args.context_size} tokens.")
+     MESSAGE = (
+         f"Starting generation @ {starting_time}\n"
+         " Have the test sample of "
+         f"{len(my_test_enc['input_ids'])} tokens"
+         f" will try to process {my_n_patches} patche(s),"
+         f" generating {my_n_samples} tokens in each patch"
+         f" from the initial context of {args.context_size} tokens."
+     )

    logger.info(MESSAGE)
    print(MESSAGE)

    my_batchsize = args.batch_size

    for c in range(0, my_n_patches, my_batchsize):
-
        CONTEXT = []
        my_sampl_par.future_context = []
        my_sampl_par.cntr = []
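The slices built in the next hunk map each patch index to two token ranges: a C-token prompt taken from the reference text and the following up-to-S tokens that are teacher-forced during generation. A hypothetical helper (not part of the script) showing the same index arithmetic:

def patch_bounds(i, context_size, sample_size, n_tokens):
    # Prompt: C tokens starting at i * S; continuation: the next up-to-S tokens.
    ctx_start = i * sample_size
    ctx_end = ctx_start + context_size
    gen_end = min(ctx_end + sample_size, n_tokens)
    return (ctx_start, ctx_end), (ctx_end, gen_end)

print(patch_bounds(0, 1024, 512, 250_000))  # ((0, 1024), (1024, 1536))
print(patch_bounds(1, 1024, 512, 250_000))  # ((512, 1536), (1536, 2048))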
@@ -156,53 +165,68 @@ def main(args: argparse.Namespace):
            if (c + b) < my_n_patches:
                upper_boundary = min(
                    (c + b + 1) * my_n_samples + args.context_size,
-                     len(my_test_enc['input_ids']))
+                     len(my_test_enc["input_ids"]),
+                 )
                CONTEXT.append(
-                     my_test_enc['input_ids'][(c + b) * my_n_samples:(c + b) *
-                                              my_n_samples + args.context_size])
+                     my_test_enc["input_ids"][
+                         (c + b) * my_n_samples : (c + b) * my_n_samples
+                         + args.context_size
+                     ]
+                 )

                my_sampl_par.future_context.append(
-                     my_test_enc['input_ids'][(c + b) * my_n_samples +
-                                              args.context_size:upper_boundary])
+                     my_test_enc["input_ids"][
+                         (c + b) * my_n_samples + args.context_size : upper_boundary
+                     ]
+                 )

                my_sampl_par.cntr.append(c + b)

        my_sampl_par.max_tokens = max(
-             len(my_sampl_par.future_context[b]) for b in range(len(CONTEXT)))
+             len(my_sampl_par.future_context[b]) for b in range(len(CONTEXT))
+         )

        LOGPROBS = vllm_predict(CONTEXT, my_llm, my_sampl_par)
        for b in range(len(CONTEXT)):
            num_tokens_generated += len(LOGPROBS[b].outputs[0].token_ids)
            my_ppl -= LOGPROBS[b].outputs[0].cumulative_logprob

-         if (num_tokens_generated < my_n_samples * len(CONTEXT)):
-             MESSAGE = (f"Warning: The number of generated tokens is" \
-                        f"less than requested ({num_tokens_generated}" \
-                        f" < {my_n_samples * len(CONTEXT)}).")
+         if num_tokens_generated < my_n_samples * len(CONTEXT):
+             MESSAGE = (
+                 f"Warning: The number of generated tokens is"
+                 f"less than requested ({num_tokens_generated}"
+                 f" < {my_n_samples * len(CONTEXT)})."
+             )
            logger.info(MESSAGE)
            print(MESSAGE)

-         MESSAGE = (f"Iterations {c + 1} through {c + len(CONTEXT)}" \
-                    f" of {my_n_patches} Intermediate " \
-                    "Estimates:\n" \
-                    f"\tCross-entropy_intermediate={my_ppl / num_tokens_generated}\n" \
-                    f"\tPerplexity_intermediate=" \
-                    f"{math.exp(my_ppl / num_tokens_generated)}")
+         MESSAGE = (
+             f"Iterations {c + 1} through {c + len(CONTEXT)}"
+             f" of {my_n_patches} Intermediate "
+             "Estimates:\n"
+             f"\tCross-entropy_intermediate={my_ppl / num_tokens_generated}\n"
+             f"\tPerplexity_intermediate="
+             f"{math.exp(my_ppl / num_tokens_generated)}"
+         )

        logger.info(MESSAGE)
        print(MESSAGE)

    ending_time = datetime.datetime.now()
-     MESSAGE = (f"Done @ {ending_time} after processing for" \
-                f" {ending_time - starting_time}" \
-                f" generated {num_tokens_generated} tokens.")
+     MESSAGE = (
+         f"Done @ {ending_time} after processing for"
+         f" {ending_time - starting_time}"
+         f" generated {num_tokens_generated} tokens."
+     )

    logger.info(MESSAGE)
    print(MESSAGE)

-     MESSAGE = (f"\tIntegral Cross-Entropy={my_ppl}\n\tAverage Cross-Entropy=" \
-                f"{my_ppl / num_tokens_generated}" \
-                f"\n\tPPL={math.exp(my_ppl / num_tokens_generated)}")
+     MESSAGE = (
+         f"\tIntegral Cross-Entropy={my_ppl}\n\tAverage Cross-Entropy="
+         f"{my_ppl / num_tokens_generated}"
+         f"\n\tPPL={math.exp(my_ppl / num_tokens_generated)}"
+     )

    if args.output_json:
        results = {
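The accumulation above (my_ppl -= cumulative_logprob) and the final report follow the usual relation between summed token log-probabilities, cross-entropy and perplexity. A self-contained sketch of the same computation with made-up log-probability values:

import math

token_logprobs = [-2.1, -0.3, -1.7, -0.9]  # made-up logprobs of the forced tokens
integral_cross_entropy = -sum(token_logprobs)             # the script's my_ppl
average_cross_entropy = integral_cross_entropy / len(token_logprobs)
ppl = math.exp(average_cross_entropy)                     # the reported PPL
print(integral_cross_entropy, average_cross_entropy, ppl) # 5.0 1.25 ~3.49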
@@ -219,17 +243,19 @@ def main(args: argparse.Namespace):


if __name__ == "__main__":
-     parser = argparse.ArgumentParser(
-         description='Measure the PPPL (P3L) score of a given model.')
-     parser.add_argument('--context-size', type=int, default=4096)
-     parser.add_argument('--sample-size', type=int, default=512)
-     parser.add_argument('--batch-size', type=int, default=1)
-     parser.add_argument('--patch-size', type=int, default=None)
+     parser = FlexibleArgumentParser(
+         description="Measure the PPPL (P3L) score of a given model."
+     )
+     parser.add_argument("--context-size", type=int, default=4096)
+     parser.add_argument("--sample-size", type=int, default=512)
+     parser.add_argument("--batch-size", type=int, default=1)
+     parser.add_argument("--patch-size", type=int, default=None)
    parser.add_argument(
-         '--output-json',
+         "--output-json",
        type=str,
        default=None,
-         help='Path to save the latency results in JSON format.')
+         help="Path to save the latency results in JSON format.",
+     )

    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args()