Skip to content

Commit 494c589

Browse files
scripts: benchmark for HTTP server throughput (#14668)
* scripts: benchmark for HTTP server throughput * fix server connection reset
1 parent 0f4c6ec commit 494c589

File tree

4 files changed

+218
-0
lines changed

4 files changed

+218
-0
lines changed

requirements/requirements-all.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
-r ../tools/server/tests/requirements.txt
44

55
-r ./requirements-compare-llama-bench.txt
6+
-r ./requirements-server-bench.txt
67
-r ./requirements-pydantic.txt
78
-r ./requirements-test-tokenizer-random.txt
89

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
datasets~=3.2.0
2+
matplotlib~=3.10.0
3+
numpy~=1.26.4
4+
requests~=2.32.3
5+
tqdm~=4.67.1

scripts/server-bench.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
#!/usr/bin/env python3
2+
3+
import argparse
4+
import json
5+
import subprocess
6+
from time import sleep, time
7+
from typing import Optional
8+
9+
import datasets
10+
import logging
11+
import matplotlib.pyplot as plt
12+
import numpy as np
13+
import requests
14+
from tqdm.contrib.concurrent import thread_map
15+
16+
17+
logging.basicConfig(level=logging.INFO, format='%(message)s')
18+
logger = logging.getLogger("server-bench")
19+
20+
21+
def get_prompts(n_prompts: int) -> list[str]:
22+
logger.info("Loading MMLU dataset...")
23+
ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"] # type: ignore
24+
if n_prompts >= 0:
25+
ret = ret[:n_prompts]
26+
return ret
27+
28+
29+
def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
30+
logger.info("Starting the llama.cpp server...")
31+
address = f"http://localhost:{port}"
32+
33+
popen_args: list[str] = [
34+
path_server,
35+
"--flash-attn",
36+
"--n-gpu-layers", str(n_gpu_layers),
37+
"--parallel", str(parallel),
38+
"--ctx-size", str(parallel * ctx_size),
39+
"--model", path_model,
40+
"--port", str(port),
41+
"--swa-full", # FIXME performance bad otherwise
42+
# "--attn-streams",
43+
]
44+
fout = open("bench.log", "w") if path_log is not None else subprocess.DEVNULL
45+
process = subprocess.Popen(popen_args, stdout=fout, stderr=subprocess.STDOUT)
46+
47+
n_failures: int = 0
48+
while True:
49+
try:
50+
sleep(1.0)
51+
exit_code = process.poll()
52+
if exit_code is not None:
53+
raise RuntimeError(f"llama.cpp server for {path_model} exited unexpectedly with exit code {exit_code}")
54+
response = requests.get(f"{address}/health")
55+
if response.status_code == 200:
56+
break
57+
except requests.ConnectionError:
58+
n_failures += 1
59+
if n_failures >= 10:
60+
raise RuntimeError(f"llama.cpp server for {path_model} is not healthy after 10 seconds")
61+
62+
return {"process": process, "address": address, "fout": fout}
63+
64+
65+
def get_prompt_length(data: dict) -> int:
66+
session = data["session"]
67+
server_address: str = data["server_address"]
68+
69+
response = session.post(
70+
f"{server_address}/apply-template",
71+
json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
72+
)
73+
if response.status_code != 200:
74+
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
75+
prompt: str = json.loads(response.text)["prompt"]
76+
response = session.post(
77+
f"{server_address}/tokenize",
78+
json={"content": prompt, "add_special": True}
79+
)
80+
if response.status_code != 200:
81+
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
82+
tokens: list[str] = json.loads(response.text)["tokens"]
83+
return len(tokens)
84+
85+
86+
def send_prompt(data: dict) -> tuple[float, list[float]]:
87+
session = data["session"]
88+
server_address: str = data["server_address"]
89+
90+
response = session.post(
91+
f"{server_address}/apply-template",
92+
json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
93+
)
94+
if response.status_code != 200:
95+
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
96+
prompt: str = json.loads(response.text)["prompt"]
97+
98+
json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
99+
response = session.post(f"{server_address}/completion", json=json_data, stream=True)
100+
101+
last_valid_line: str = ""
102+
token_arrival_times: list[float] = []
103+
for line in response.iter_lines(decode_unicode=True):
104+
if not line.startswith("data: "):
105+
continue
106+
last_valid_line = line
107+
token_arrival_times.append(time())
108+
token_arrival_times = token_arrival_times[:-1]
109+
110+
if response.status_code != 200:
111+
raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
112+
timings: dict = json.loads(last_valid_line[6:])["timings"]
113+
114+
return (timings["prompt_ms"], token_arrival_times)
115+
116+
117+
def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
118+
num_workers: int = parallel + 1
119+
prompts: list[str] = get_prompts(n_prompts)
120+
121+
server: Optional[dict] = None
122+
session = None
123+
try:
124+
server = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
125+
server_address: str = server["address"]
126+
127+
adapter = requests.adapters.HTTPAdapter(pool_connections=num_workers, pool_maxsize=num_workers) # type: ignore
128+
session = requests.Session()
129+
session.mount("http://", adapter)
130+
session.mount("https://", adapter)
131+
132+
data: list[dict] = []
133+
for i, p in enumerate(prompts):
134+
data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})
135+
136+
logger.info("Getting the prompt lengths...")
137+
prompt_n = [get_prompt_length(d) for d in data]
138+
139+
logger.info("Starting the benchmark...\n")
140+
t0 = time()
141+
results: list[tuple[int, list[float]]] = thread_map(send_prompt, data, max_workers=num_workers, chunksize=1)
142+
finally:
143+
if server is not None:
144+
server["process"].terminate()
145+
server["process"].wait()
146+
if session is not None:
147+
session.close()
148+
149+
prompt_ms = []
150+
token_t = []
151+
depth_sum: int = 0
152+
for pn, (pms, tat) in zip(prompt_n, results):
153+
prompt_ms.append(pms)
154+
token_t += tat
155+
n_tokens: int = len(tat)
156+
depth_sum += n_tokens * pn
157+
depth_sum += n_tokens * (n_tokens + 1) // 2
158+
prompt_n = np.array(prompt_n, dtype=np.int64)
159+
prompt_ms = np.array(prompt_ms, dtype=np.float64)
160+
token_t = np.array(token_t, dtype=np.float64)
161+
162+
token_t -= t0
163+
token_t_last = np.max(token_t)
164+
165+
logger.info("")
166+
logger.info(f"Benchmark duration: {token_t_last:.2f} s")
167+
logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
168+
logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
169+
logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
170+
logger.info(f"Average prompt latency: {np.mean(prompt_ms):.2f} ms")
171+
logger.info(f"Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
172+
logger.info(f"Total generated tokens: {token_t.shape[0]}")
173+
logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
174+
logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
175+
logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
176+
177+
plt.figure()
178+
plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
179+
plt.xlim(0, 1.05 * np.max(prompt_n))
180+
plt.ylim(0, 1.05 * np.max(prompt_ms))
181+
plt.title(path_model)
182+
plt.xlabel("Prompt length [tokens]")
183+
plt.ylabel("Time to first token [ms]")
184+
plt.savefig("prompt_time.png", dpi=240)
185+
186+
bin_max = np.ceil(token_t_last) + 1
187+
plt.figure()
188+
plt.hist(token_t, np.arange(0, bin_max))
189+
plt.xlim(0, bin_max + 1)
190+
plt.title(path_model)
191+
plt.xlabel("Time [s]")
192+
plt.ylabel("Num. tokens generated per second")
193+
plt.savefig("gen_rate.png", dpi=240)
194+
195+
196+
if __name__ == "__main__":
197+
parser = argparse.ArgumentParser(
198+
description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
199+
"Results are printed to console and visualized as plots (saved to current working directory).")
200+
parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
201+
parser.add_argument("--path_model", type=str, required=True, help="Path to the model to use for the benchmark")
202+
parser.add_argument("--path_log", type=str, default=None, help="Path to the model to use for the benchmark")
203+
parser.add_argument("--port", type=int, default=18725, help="Port to use for the server during the benchmark")
204+
parser.add_argument("--n_gpu_layers", type=int, default=999, help="Number of GPU layers for the server")
205+
parser.add_argument("--parallel", type=int, default=16, help="Number of slots for the server")
206+
parser.add_argument("--ctx_size", type=int, default=4096, help="Server context size per slot")
207+
parser.add_argument("--n_prompts", type=int, default=1000, help="Number of prompts to evaluate")
208+
parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
209+
args = parser.parse_args()
210+
benchmark(**vars(args))

tools/server/utils.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
// increase max payload length to allow use of larger context size
1313
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
14+
// increase backlog size to avoid connection resets for >> 1 slots
15+
#define CPPHTTPLIB_LISTEN_BACKLOG 512
1416
// disable Nagle's algorithm
1517
#define CPPHTTPLIB_TCP_NODELAY true
1618
#include <cpp-httplib/httplib.h>

0 commit comments

Comments
 (0)