Commit 8e8b5e4

scripts: benchmark for HTTP server throughput
1 parent 982e347 commit 8e8b5e4

File tree: 3 files changed (+203, -0 lines)


requirements/requirements-all.txt

Lines changed: 1 addition & 0 deletions

@@ -3,6 +3,7 @@
 -r ../tools/server/tests/requirements.txt

 -r ./requirements-compare-llama-bench.txt
+-r ./requirements-server-bench.txt
 -r ./requirements-pydantic.txt
 -r ./requirements-test-tokenizer-random.txt

requirements/requirements-server-bench.txt

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+datasets~=3.6.0
+matplotlib~=3.10.0
+numpy~=1.26.4
+requests~=2.32.3
+tqdm~=4.67.1
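
As an aside, a minimal sketch (not part of the commit) for confirming that these pinned dependencies are present after running pip install -r requirements/requirements-server-bench.txt; it only uses importlib.metadata from the standard library:

# Hypothetical helper, not part of this commit: report the installed versions
# of the packages pinned in requirements-server-bench.txt.
from importlib.metadata import PackageNotFoundError, version

for pkg in ("datasets", "matplotlib", "numpy", "requests", "tqdm"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")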

scripts/server-bench.py

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@

#!/usr/bin/env python3

import argparse
import json
import subprocess
from time import sleep, time
from typing import Optional

import datasets
import matplotlib.pyplot as plt
import numpy as np
import requests
from tqdm.contrib.concurrent import thread_map


def get_prompts(n_prompts: int) -> list[str]:
    print("Loading MMLU dataset...")
    ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]
    if n_prompts >= 0:
        ret = ret[:n_prompts]
    return ret


def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
    print("Starting the llama.cpp server...")
    address = f"http://localhost:{port}"

    popen_args: list[str] = [
        path_server,
        "--flash-attn",
        "--n-gpu-layers", str(n_gpu_layers),
        "--parallel", str(parallel),
        "--ctx-size", str(parallel * ctx_size),
        "--model", path_model,
        "--port", str(port),
        "--swa-full",  # FIXME performance bad otherwise
        # "--attn-streams",
    ]
    fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
    process = subprocess.Popen(popen_args, stdout=fout, stderr=subprocess.STDOUT)

    # Poll the /health endpoint until the server is ready or has clearly failed.
    n_failures: int = 0
    while True:
        try:
            sleep(1.0)
            exit_code = process.poll()
            if exit_code is not None:
                raise RuntimeError(f"llama.cpp server for {path_model} exited unexpectedly with exit code {exit_code}")
            response = requests.get(f"{address}/health")
            if response.status_code == 200:
                break
        except requests.ConnectionError:
            n_failures += 1
            if n_failures >= 10:
                raise RuntimeError(f"llama.cpp server for {path_model} is not healthy after 10 seconds")

    return {"process": process, "address": address, "fout": fout}


def get_prompt_length(data: dict) -> int:
    session = data["session"]
    server_address: str = data["server_address"]

    response = session.post(
        f"{server_address}/apply-template",
        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    prompt: str = json.loads(response.text)["prompt"]
    response = session.post(
        f"{server_address}/tokenize",
        json={"content": prompt, "add_special": True}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    tokens: list[str] = json.loads(response.text)["tokens"]
    return len(tokens)


def send_prompt(data: dict) -> tuple[float, list[float]]:
    session = data["session"]
    server_address: str = data["server_address"]

    response = session.post(
        f"{server_address}/apply-template",
        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    prompt: str = json.loads(response.text)["prompt"]

    json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
    response = session.post(f"{server_address}/completion", json=json_data, stream=True)

    # Record the wall-clock arrival time of each streamed "data: " line; the final
    # line carries the timings object rather than a token, so it is dropped below.
    token_arrival_times: list[float] = []
    for line in response.iter_lines(decode_unicode=True):
        if not line.startswith("data: "):
            continue
        last_valid_line = line
        token_arrival_times.append(time())
    token_arrival_times = token_arrival_times[:-1]

    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    timings: dict = json.loads(last_valid_line[6:])["timings"]

    return (timings["prompt_ms"], token_arrival_times)


def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
    prompts: list[str] = get_prompts(n_prompts)

    server = None
    try:
        server: dict = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
        server_address: str = server["address"]

        with requests.Session() as session:
            data: list[dict] = []
            for i, p in enumerate(prompts):
                data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})

            print("Getting the prompt lengths...")
            prompt_n: list[int] = [get_prompt_length(d) for d in data]

            print("Starting the benchmark...")
            print()
            t0 = time()
            results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel + 1, chunksize=1)
    finally:
        if server is not None:
            server["process"].terminate()
            server["process"].wait()

    prompt_ms = []
    token_t = []
    depth_sum: int = 0
    for pn, (pms, tat) in zip(prompt_n, results):
        prompt_ms.append(pms)
        token_t += tat
        n_tokens: int = len(tat)
        depth_sum += n_tokens * pn
        depth_sum += n_tokens * (n_tokens + 1) // 2
    prompt_n = np.array(prompt_n, dtype=np.int64)
    prompt_ms = np.array(prompt_ms, dtype=np.float64)
    token_t = np.array(token_t, dtype=np.float64)

    token_t -= t0
    token_t_last = np.max(token_t)

    print()
    print(f"Benchmark duration: {token_t_last:.2f} s")
    print(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
    print(f"Total prompt length: {np.sum(prompt_n)} tokens")
    print(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
    print(f"Average prompt latency: {np.mean(prompt_ms):.2f} ms")
    print(f"Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
    print(f"Total generated tokens: {token_t.shape[0]}")
    print(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
    print(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
    print(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")

    plt.figure()
    plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
    plt.xlim(0, 1.05 * np.max(prompt_n))
    plt.ylim(0, 1.05 * np.max(prompt_ms))
    plt.title(path_model)
    plt.xlabel("Prompt length [tokens]")
    plt.ylabel("Time to first token [ms]")
    plt.savefig("prompt_time.png", dpi=240)

    bin_max = np.ceil(token_t_last) + 1
    plt.figure()
    plt.hist(token_t, np.arange(0, bin_max))
    plt.xlim(0, bin_max + 1)
    plt.title(path_model)
    plt.xlabel("Time [s]")
    plt.ylabel("Num. tokens generated per second")
    plt.savefig("gen_rate.png", dpi=240)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
        "Results are printed to console and visualized as plots (saved to current working directory).")
    parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
    parser.add_argument("--path_model", type=str, required=True, help="Path to the model to use for the benchmark")
    parser.add_argument("--path_log", type=str, default=None, help="Path to the file for writing the server log (default: server output is discarded)")
    parser.add_argument("--port", type=int, default=18725, help="Port to use for the server during the benchmark")
    parser.add_argument("--n_gpu_layers", type=int, default=999, help="Number of GPU layers for the server")
    parser.add_argument("--parallel", type=int, default=16, help="Number of slots for the server")
    parser.add_argument("--ctx_size", type=int, default=4096, help="Server context size per slot")
    parser.add_argument("--n_prompts", type=int, default=1000, help="Number of prompts to evaluate")
    parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
    args = parser.parse_args()
    benchmark(**vars(args))
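
For reference, a hypothetical way to drive the new script end to end; the binary and model paths below are placeholders, not values taken from the commit. Since the script is a standalone CLI tool, the sketch simply shells out to it:

# Hypothetical usage sketch, not part of this commit: run the benchmark against
# a local model; all paths and counts here are illustrative placeholders.
import subprocess

subprocess.run([
    "python3", "scripts/server-bench.py",
    "--path_server", "./build/bin/llama-server",   # placeholder server binary location
    "--path_model", "./models/example-7b.gguf",    # placeholder model file
    "--n_prompts", "100",
    "--n_predict", "512",
], check=True)

As written, the script starts the server itself, measures prompt lengths via /tokenize, streams completions from parallel + 1 worker threads, prints the throughput summary, and saves prompt_time.png and gen_rate.png to the current working directory.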
