Commit 24c503a

add long context example (#304)
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
1 parent e9604ff commit 24c503a

File tree: 1 file changed (+186 −0 lines changed)

"""
This example exercises long context lengths.

Let's say you want to test the following configuration:

Prefill: Max_prompt = 4K, prefill batch-size = 1.
Generation: Max_context = 8K, Max_batch = 4.

Then the command line will be:

```
python long_context.py --max-num-seqs 4 --max-prompt-len 4096 \
    --max-model-len 8192
```

To compare the results with CPU, add `--compare-with-cpu`.

All sequences will run up to the max context length.
"""

import argparse
import os
import platform
import sys
import time

import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

parser = argparse.ArgumentParser()
parser.add_argument("--model",
                    type=str,
                    default="ibm-ai-platform/micro-g3.3-8b-instruct-1b")
parser.add_argument("--max_model_len",
                    "--max-model-len",
                    type=int,
                    default=2048)
parser.add_argument("--max_prompt_len",
                    "--max-prompt-len",
                    type=int,
                    default=1024)
parser.add_argument("--max_num_seqs", "--max-num-seqs", type=int, default=2)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--num-prompts", "-n", type=int, default=8)
parser.add_argument("--compare-with-cpu",
                    action=argparse.BooleanOptionalAction)
args = parser.parse_args()

max_num_seqs = args.max_num_seqs  # defines the max batch size
assert args.max_prompt_len < args.max_model_len

if platform.machine() == "arm64":
    print("Detected arm64 running environment. "
          "Setting HF_HUB_OFFLINE=1 otherwise vllm tries to download a "
          "different version of the model using HF API which might not work "
          "locally on arm64.")
    os.environ["HF_HUB_OFFLINE"] = "1"

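# Runtime configuration: the Dynamo backend falls back to 'eager' when not set
# explicitly; VLLM_SPYRE_USE_CB is assumed to enable continuous batching on the
# Spyre backend, and VLLM_USE_V1 selects vLLM's V1 engine.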
if "VLLM_SPYRE_DYNAMO_BACKEND" not in os.environ:
63+
os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = 'eager'
64+
os.environ['VLLM_SPYRE_USE_CB'] = '1'
65+
os.environ['VLLM_USE_V1'] = '1'
66+
67+
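# Each prompt asks the model to summarize the source code of a Python
# standard-library module; these files are conveniently long inputs.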
template = ("Summarize the following code: \n\n{}")


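# Search sys.path for the given standard-library source file and return its
# contents.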
def get_python_file(source_file):
    for path in sys.path:
        file_path = os.path.join(path, source_file)
        if os.path.isfile(file_path):
            with open(file_path, encoding="utf-8") as f:
                return f.read()
    raise Exception(f"File {source_file} not found")


example_files = [
    "os.py",
    "gzip.py",
    "inspect.py",
    "abc.py",
    "dataclasses.py",
    "enum.py",
    "functools.py",
    "io.py",
]

file_contents = [get_python_file(e) for e in example_files]

prompts = [template.format(c) for c in file_contents]

prompts = prompts * (args.num_prompts // len(prompts) + 1)
prompts = prompts[0:args.num_prompts]

tokenizer = AutoTokenizer.from_pretrained(args.model)

tokenized_prompts = tokenizer(prompts)["input_ids"]
tokenized_prompts = [p[:args.max_prompt_len] for p in tokenized_prompts]

prompt_lens = [len(p) for p in tokenized_prompts]

max_prompt = max(prompt_lens)
min_prompt = min(prompt_lens)

if max_prompt < args.max_prompt_len:
    print(f"Warning, none of the prompts reach the maximum length "
          f"({args.max_prompt_len})")

print(f"All prompts have lengths between {min_prompt} and {max_prompt}")


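# Round prompt lengths up to the next multiple of 64 (assumed to match the
# backend's block granularity) so that prompt + generated tokens exactly fill
# max_model_len for every sequence.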
def round_up(t):
    return ((t + 63) // 64) * 64


tokens_to_generate = [
    args.max_model_len - round_up(plen) for plen in prompt_lens
]

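# temperature=0.0 requests greedy decoding and ignore_eos=True forces each
# sequence to produce exactly max_tokens tokens, so every request really runs
# up to the full context length.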
sampling_params = [
    SamplingParams(max_tokens=t, temperature=0.0, ignore_eos=True)
    for t in tokens_to_generate
]

vllm_token_prompts = [
    TokensPrompt(prompt_token_ids=p) for p in tokenized_prompts
]

# Create an LLM.
llm = LLM(model=args.model,
          tokenizer=args.model,
          max_model_len=args.max_model_len,
          block_size=2048,
          max_num_seqs=max_num_seqs,
          tensor_parallel_size=args.tp)

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
print("=============== GENERATE")
t0 = time.time()
outputs = llm.generate(vllm_token_prompts, sampling_params)
print("Time elapsed for all prompts is %.2f sec" % (time.time() - t0))
print("===============")
for output, prompt in zip(outputs, prompts):
    generated_text = output.outputs[0].text[:100]
    prompt = prompt[:100]
    print(f"\nPrompt:\n {prompt!r}")
    print(f"\nGenerated text (truncated):\n {generated_text!r}\n")
    print("-----------------------------------")

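# Optionally cross-check the Spyre outputs against greedy decoding with
# Hugging Face transformers on CPU.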
if args.compare_with_cpu:
    print("Comparing results with HF on cpu")
    print("===============")
    any_differ = False

    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(args.model)

    for i in range(args.num_prompts):
        prompt = prompts[i]

        hf_input_tokens = torch.tensor(tokenized_prompts[i]).unsqueeze(0)
        hf_output = model.generate(hf_input_tokens,
                                   do_sample=False,
                                   min_new_tokens=tokens_to_generate[i],
                                   max_new_tokens=tokens_to_generate[i],
                                   return_dict_in_generate=True,
                                   output_scores=True)

        # decode output tokens after first removing input tokens (prompt)
        hf_generated_text = tokenizer.batch_decode(
            hf_output.sequences[:, len(hf_input_tokens[0]):])[0]

        if hf_generated_text != outputs[i].outputs[0].text:
            any_differ = True
            spyre_output = outputs[i].outputs[0].text
            print(f"Results for prompt {i} differ on cpu")
            print(f"\nPrompt:\n {prompt[:100]!r}")
            print(f"\nSpyre generated text:\n {spyre_output[:100]!r}\n")
            print(f"\nCPU generated text:\n {hf_generated_text[:100]!r}\n")
            print("-----------------------------------")

    if not any_differ:
        print("\nAll results match!\n")
