
Commit 34483a3

Fix P3L Arg parser
1 parent 222fa01 commit 34483a3

File tree

2 files changed: +226 -172 lines changed


benchmarks/P3L.py

Lines changed: 104 additions & 78 deletions
@@ -3,38 +3,38 @@
 """
 Patch-Perplexity (P3L)

-This is a script that produces a realistic PPL measurement
-for the quantized KV cache system by processing a sequence of
-non-overlapping patches of the reference text. Generation of the
+This is a script that produces a realistic PPL measurement
+for the quantized KV cache system by processing a sequence of
+non-overlapping patches of the reference text. Generation of the
 consecutive symbols in each patch is governed (forced)
 by the reference text.

-The initial context size for the system is set by the parameter
+The initial context size for the system is set by the parameter
 "--context-size".

-The number of output symbols to generate starting from a given
-context is set by the parameter "--sample-size". This variable also
+The number of output symbols to generate starting from a given
+context is set by the parameter "--sample-size". This variable also
 defines the size of the individual patch.

-For the N-token reference text that is split into M patches with the
+For the N-token reference text that is split into M patches with the
 system's context size C it takes M*preload + (N-C)*generation time.

 Quick correctness validation tips:

-Running llama-2-7b model
-(
-./vllm/examples/P3L.py
---model=meta-llama/Llama-2-7b-chat-hf
---context-size=1024
+Running llama-2-7b model
+(
+./vllm/examples/P3L.py
+--model=meta-llama/Llama-2-7b-chat-hf
+--context-size=1024
 --sample-size=512
 )
 should result in PPL ~ 6.524227946419175

-Running llama-2-7b model
-(
-./vllm/examples/P3L.py
---model=meta-llama/Llama-2-7b-chat-hf
---context-size=1024
+Running llama-2-7b model
+(
+./vllm/examples/P3L.py
+--model=meta-llama/Llama-2-7b-chat-hf
+--context-size=1024
 --sample-size=512
 --patch-size=1
 )
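
For reference, the patch split described in the docstring above can be sanity-checked with a few lines of Python. This is a sketch with made-up token counts, mirroring the my_n_patches computation further down in the file:

import math

# Hypothetical sizes: N reference tokens, context C, sample/patch size S.
N = 245_000   # total tokens in the tokenized reference text (made-up number)
C = 1024      # --context-size
S = 512       # --sample-size

# Number of non-overlapping patches, as in my_n_patches below.
M = math.ceil((N - C - 1) / S)
print(M)      # each patch forces S reference tokens after a C-token prefix
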
@@ -58,17 +58,20 @@
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
 from vllm.logger import init_logger
+from vllm.utils import FlexibleArgumentParser

 logger = init_logger(__name__)


 def get_wikitext2_text(tokenizer):
     with tempfile.TemporaryDirectory() as tmpdirname:
-        hf_hub_download(repo_id='alexei-v-ivanov-amd/wiki',
-                        repo_type="dataset",
-                        filename='wiki.test.raw',
-                        local_dir=tmpdirname)
-        with open(os.path.join(tmpdirname, 'wiki.test.raw')) as f:
+        hf_hub_download(
+            repo_id="alexei-v-ivanov-amd/wiki",
+            repo_type="dataset",
+            filename="wiki.test.raw",
+            local_dir=tmpdirname,
+        )
+        with open(os.path.join(tmpdirname, "wiki.test.raw")) as f:
             test_text = "\n".join(line.strip() for line in f)
         test_enc = tokenizer(test_text)

@@ -79,15 +82,17 @@ def vllm_init(args):
     engine_args = EngineArgs.from_cli_args(args)
     llm = LLM(**dataclasses.asdict(engine_args))

-    sampling_params = SamplingParams(n=1,
-                                     temperature=0.0,
-                                     top_p=1,
-                                     ignore_eos=True,
-                                     ppl_measurement=True,
-                                     future_context=[],
-                                     prompt_logprobs=1,
-                                     logprobs=1,
-                                     presence_penalty=0.0)
+    sampling_params = SamplingParams(
+        n=1,
+        temperature=0.0,
+        top_p=1,
+        ignore_eos=True,
+        ppl_measurement=True,
+        future_context=[],
+        prompt_logprobs=1,
+        logprobs=1,
+        presence_penalty=0.0,
+    )

     return llm, sampling_params

@@ -98,7 +103,6 @@ def vllm_predict(CONT, llm, sampl_par):


 def main(args: argparse.Namespace):
-
     MESSAGE = f"Initialising @ {datetime.datetime.now()}"
     logger.info(MESSAGE)
     print(MESSAGE)
@@ -112,14 +116,17 @@ def main(args: argparse.Namespace):

     my_n_samples = args.sample_size

-    if (args.context_size+my_n_samples) > \
-        my_llm.llm_engine.model_config.max_model_len:
-        MESSAGE = ("" \
-                   "Error! The total number of tokens:\n" \
-                   f" prefix ({args.context_size}) + " \
-                   f"to be generated ({my_n_samples})" \
-                   f" can't be bigger than the model limit " \
-                   f"({my_llm.llm_engine.model_config.max_model_len}).")
+    if (
+        args.context_size + my_n_samples
+    ) > my_llm.llm_engine.model_config.max_model_len:
+        MESSAGE = (
+            ""
+            "Error! The total number of tokens:\n"
+            f" prefix ({args.context_size}) + "
+            f"to be generated ({my_n_samples})"
+            f" can't be bigger than the model limit "
+            f"({my_llm.llm_engine.model_config.max_model_len})."
+        )
         logger.info(MESSAGE)
         print(MESSAGE)
         return
@@ -128,26 +135,28 @@ def main(args: argparse.Namespace):
     logger.info("Loaded the test data.")

     my_n_patches = math.ceil(
-        (len(my_test_enc['input_ids']) - args.context_size - 1) / my_n_samples)
+        (len(my_test_enc["input_ids"]) - args.context_size - 1) / my_n_samples
+    )
     if args.patch_size is not None:
         my_n_patches = args.patch_size

     num_tokens_generated = 0
     starting_time = datetime.datetime.now()
-    MESSAGE = (f"Starting generation @ {starting_time}\n" \
-               " Have the test sample of "
-               f"{len(my_test_enc['input_ids'])} tokens" \
-               f" will try to process {my_n_patches} patche(s)," \
-               f" generating {my_n_samples} tokens in each patch" \
-               f" from the initial context of {args.context_size} tokens.")
+    MESSAGE = (
+        f"Starting generation @ {starting_time}\n"
+        " Have the test sample of "
+        f"{len(my_test_enc['input_ids'])} tokens"
+        f" will try to process {my_n_patches} patche(s),"
+        f" generating {my_n_samples} tokens in each patch"
+        f" from the initial context of {args.context_size} tokens."
+    )

     logger.info(MESSAGE)
     print(MESSAGE)

     my_batchsize = args.batch_size

     for c in range(0, my_n_patches, my_batchsize):
-
         CONTEXT = []
         my_sampl_par.future_context = []
         my_sampl_par.cntr = []
@@ -156,53 +165,68 @@
             if (c + b) < my_n_patches:
                 upper_boundary = min(
                     (c + b + 1) * my_n_samples + args.context_size,
-                    len(my_test_enc['input_ids']))
+                    len(my_test_enc["input_ids"]),
+                )
                 CONTEXT.append(
-                    my_test_enc['input_ids'][(c + b) * my_n_samples:(c + b) *
-                                             my_n_samples + args.context_size])
+                    my_test_enc["input_ids"][
+                        (c + b) * my_n_samples : (c + b) * my_n_samples
+                        + args.context_size
+                    ]
+                )

                 my_sampl_par.future_context.append(
-                    my_test_enc['input_ids'][(c + b) * my_n_samples +
-                                             args.context_size:upper_boundary])
+                    my_test_enc["input_ids"][
+                        (c + b) * my_n_samples + args.context_size : upper_boundary
+                    ]
+                )

                 my_sampl_par.cntr.append(c + b)

         my_sampl_par.max_tokens = max(
-            len(my_sampl_par.future_context[b]) for b in range(len(CONTEXT)))
+            len(my_sampl_par.future_context[b]) for b in range(len(CONTEXT))
+        )

         LOGPROBS = vllm_predict(CONTEXT, my_llm, my_sampl_par)
         for b in range(len(CONTEXT)):
             num_tokens_generated += len(LOGPROBS[b].outputs[0].token_ids)
             my_ppl -= LOGPROBS[b].outputs[0].cumulative_logprob

-        if (num_tokens_generated < my_n_samples * len(CONTEXT)):
-            MESSAGE = (f"Warning: The number of generated tokens is" \
-                       f"less than requested ({num_tokens_generated}" \
-                       f" < {my_n_samples*len(CONTEXT)}).")
+        if num_tokens_generated < my_n_samples * len(CONTEXT):
+            MESSAGE = (
+                f"Warning: The number of generated tokens is"
+                f"less than requested ({num_tokens_generated}"
+                f" < {my_n_samples * len(CONTEXT)})."
+            )
             logger.info(MESSAGE)
             print(MESSAGE)

-        MESSAGE = (f"Iterations {c+1} through {c+len(CONTEXT)}" \
-                   f" of {my_n_patches} Intermediate " \
-                   "Estimates:\n" \
-                   f"\tCross-entropy_intermediate={my_ppl/num_tokens_generated}\n" \
-                   f"\tPerplexity_intermediate=" \
-                   f"{math.exp(my_ppl/num_tokens_generated)}")
+        MESSAGE = (
+            f"Iterations {c + 1} through {c + len(CONTEXT)}"
+            f" of {my_n_patches} Intermediate "
+            "Estimates:\n"
+            f"\tCross-entropy_intermediate={my_ppl / num_tokens_generated}\n"
+            f"\tPerplexity_intermediate="
+            f"{math.exp(my_ppl / num_tokens_generated)}"
+        )

         logger.info(MESSAGE)
         print(MESSAGE)

     ending_time = datetime.datetime.now()
-    MESSAGE = (f"Done @ {ending_time} after processing for" \
-               f" {ending_time-starting_time}" \
-               f" generated {num_tokens_generated} tokens.")
+    MESSAGE = (
+        f"Done @ {ending_time} after processing for"
+        f" {ending_time - starting_time}"
+        f" generated {num_tokens_generated} tokens."
+    )

     logger.info(MESSAGE)
     print(MESSAGE)

-    MESSAGE = (f"\tIntegral Cross-Entropy={my_ppl}\n\tAverage Cross-Entropy=" \
-               f"{my_ppl/num_tokens_generated}" \
-               f"\n\tPPL={math.exp(my_ppl/num_tokens_generated)}")
+    MESSAGE = (
+        f"\tIntegral Cross-Entropy={my_ppl}\n\tAverage Cross-Entropy="
+        f"{my_ppl / num_tokens_generated}"
+        f"\n\tPPL={math.exp(my_ppl / num_tokens_generated)}"
+    )

     if args.output_json:
         results = {
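
The estimates assembled in this hunk follow the standard definitions: the script sums the negative cumulative log-probabilities of the forced tokens into my_ppl, divides by the number of generated tokens for the cross-entropy, and exponentiates for the PPL. A minimal sketch of that arithmetic with made-up per-token log-probabilities:

import math

# Hypothetical log-probabilities of four forced reference tokens.
logprobs = [-2.1, -0.4, -3.0, -1.2]

integral_ce = -sum(logprobs)            # plays the role of my_ppl
avg_ce = integral_ce / len(logprobs)    # "Average Cross-Entropy"
ppl = math.exp(avg_ce)                  # the reported PPL
print(avg_ce, ppl)
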
@@ -219,17 +243,19 @@ def main(args: argparse.Namespace):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description='Measure the PPPL (P3L) score of a given model.')
-    parser.add_argument('--context-size', type=int, default=4096)
-    parser.add_argument('--sample-size', type=int, default=512)
-    parser.add_argument('--batch-size', type=int, default=1)
-    parser.add_argument('--patch-size', type=int, default=None)
+    parser = FlexibleArgumentParser(
+        description="Measure the PPPL (P3L) score of a given model."
+    )
+    parser.add_argument("--context-size", type=int, default=4096)
+    parser.add_argument("--sample-size", type=int, default=512)
+    parser.add_argument("--batch-size", type=int, default=1)
+    parser.add_argument("--patch-size", type=int, default=None)
     parser.add_argument(
-        '--output-json',
+        "--output-json",
         type=str,
         default=None,
-        help='Path to save the latency results in JSON format.')
+        help="Path to save the latency results in JSON format.",
+    )

     parser = EngineArgs.add_cli_args(parser)
     args = parser.parse_args()
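
This last hunk is the fix named in the commit title: the script-level options are now registered on vllm.utils.FlexibleArgumentParser (imported earlier in the diff) rather than a plain argparse.ArgumentParser, presumably so that EngineArgs.add_cli_args keeps working with the parser type vLLM expects. A minimal, self-contained sketch of the same wiring; the flag values are only examples:

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

# Script-level flags go on vLLM's parser type, then the engine flags are attached.
parser = FlexibleArgumentParser(description="P3L-style argument wiring (sketch).")
parser.add_argument("--context-size", type=int, default=4096)
parser.add_argument("--sample-size", type=int, default=512)

parser = EngineArgs.add_cli_args(parser)  # adds --model, --dtype, and friends
args = parser.parse_args(["--model", "meta-llama/Llama-2-7b-chat-hf"])
engine_args = EngineArgs.from_cli_args(args)
print(engine_args.model, args.context_size)
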
