#!/usr/bin/env python3

import argparse
import json
import logging
import subprocess
from time import sleep, time
from typing import Optional

import datasets
import matplotlib.pyplot as plt
import numpy as np
import requests
from tqdm.contrib.concurrent import thread_map


logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger("server-bench")


def get_prompts(n_prompts: int) -> list[str]:
    logger.info("Loading MMLU dataset...")
    ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
    if n_prompts >= 0:
        ret = ret[:n_prompts]
    return ret


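# Start the llama.cpp HTTP server as a subprocess and block until its /health endpoint responds.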
def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
    logger.info("Starting the llama.cpp server...")
    address = f"http://localhost:{port}"

    popen_args: list[str] = [
        path_server,
        "--flash-attn",
        "--n-gpu-layers", str(n_gpu_layers),
        "--parallel", str(parallel),
        "--ctx-size", str(parallel * ctx_size),
        "--model", path_model,
        "--port", str(port),
        "--swa-full",  # FIXME performance bad otherwise
        # "--attn-streams",
    ]
    # Write the server output to the requested log file, otherwise discard it.
    fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
    process = subprocess.Popen(popen_args, stdout=fout, stderr=subprocess.STDOUT)

    n_failures: int = 0
    while True:
        try:
            sleep(1.0)
            exit_code = process.poll()
            if exit_code is not None:
                raise RuntimeError(f"llama.cpp server for {path_model} exited unexpectedly with exit code {exit_code}")
            response = requests.get(f"{address}/health")
            if response.status_code == 200:
                break
        except requests.ConnectionError:
            n_failures += 1
            if n_failures >= 10:
                raise RuntimeError(f"llama.cpp server for {path_model} is not healthy after 10 seconds")

    return {"process": process, "address": address, "fout": fout}


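# Run the prompt through the server's chat template and return the length of the result in tokens.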
def get_prompt_length(data: dict) -> int:
    session = data["session"]
    server_address: str = data["server_address"]

    response = session.post(
        f"{server_address}/apply-template",
        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    prompt: str = json.loads(response.text)["prompt"]
    response = session.post(
        f"{server_address}/tokenize",
        json={"content": prompt, "add_special": True}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    tokens: list[int] = json.loads(response.text)["tokens"]
    return len(tokens)


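# Send one templated prompt to /completion with streaming and record the wall-clock arrival time of each token.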
def send_prompt(data: dict) -> tuple[float, list[float]]:
    session = data["session"]
    server_address: str = data["server_address"]

    response = session.post(
        f"{server_address}/apply-template",
        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    prompt: str = json.loads(response.text)["prompt"]

    json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
    response = session.post(f"{server_address}/completion", json=json_data, stream=True)

    last_valid_line: str = ""
    token_arrival_times: list[float] = []
    for line in response.iter_lines(decode_unicode=True):
        if not line.startswith("data: "):
            continue
        last_valid_line = line
        token_arrival_times.append(time())
    # The last "data: " line is the final summary carrying the timings, not a generated token.
    token_arrival_times = token_arrival_times[:-1]

    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    timings: dict = json.loads(last_valid_line[6:])["timings"]

    return (timings["prompt_ms"], token_arrival_times)


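# Full benchmark run: start the server, measure prompt lengths, send all prompts in parallel, then report statistics and save plots.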
def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
    num_workers: int = parallel + 1
    prompts: list[str] = get_prompts(n_prompts)

    server: Optional[dict] = None
    session = None
    try:
        server = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
        server_address: str = server["address"]

        adapter = requests.adapters.HTTPAdapter(pool_connections=num_workers, pool_maxsize=num_workers)  # type: ignore
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        data: list[dict] = []
        for i, p in enumerate(prompts):
            data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})

        logger.info("Getting the prompt lengths...")
        prompt_n = [get_prompt_length(d) for d in data]

        logger.info("Starting the benchmark...\n")
        t0 = time()
        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=num_workers, chunksize=1)
    finally:
        if server is not None:
            server["process"].terminate()
            server["process"].wait()
        if session is not None:
            session.close()

    prompt_ms = []
    token_t = []
    depth_sum: int = 0
    for pn, (pms, tat) in zip(prompt_n, results):
        prompt_ms.append(pms)
        token_t += tat
        n_tokens: int = len(tat)
        depth_sum += n_tokens * pn
        depth_sum += n_tokens * (n_tokens + 1) // 2
    prompt_n = np.array(prompt_n, dtype=np.int64)
    prompt_ms = np.array(prompt_ms, dtype=np.float64)
    token_t = np.array(token_t, dtype=np.float64)

    token_t -= t0
    token_t_last = np.max(token_t)

    logger.info("")
    logger.info(f"Benchmark duration: {token_t_last:.2f} s")
    logger.info(f"Request throughput: {len(prompts) / token_t_last:.2f} requests/s = {len(prompts) / (token_t_last/60):.2f} requests/min")
    logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
    logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
    logger.info(f"Average prompt latency: {np.mean(prompt_ms):.2f} ms")
    logger.info(f"Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
    logger.info(f"Total generated tokens: {token_t.shape[0]}")
    logger.info(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
    logger.info(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
    logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")

    plt.figure()
    plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
    plt.xlim(0, 1.05 * np.max(prompt_n))
    plt.ylim(0, 1.05 * np.max(prompt_ms))
    plt.title(path_model)
    plt.xlabel("Prompt length [tokens]")
    plt.ylabel("Time to first token [ms]")
    plt.savefig("prompt_time.png", dpi=240)

    bin_max = np.ceil(token_t_last) + 1
    plt.figure()
    plt.hist(token_t, np.arange(0, bin_max))
    plt.xlim(0, bin_max + 1)
    plt.title(path_model)
    plt.xlabel("Time [s]")
    plt.ylabel("Num. tokens generated per second")
    plt.savefig("gen_rate.png", dpi=240)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
        "Results are printed to console and visualized as plots (saved to current working directory).")
    parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
    parser.add_argument("--path_model", type=str, required=True, help="Path to the model to use for the benchmark")
    parser.add_argument("--path_log", type=str, default=None, help="Path to the file for writing the server log (if not set, the server output is discarded)")
    parser.add_argument("--port", type=int, default=18725, help="Port to use for the server during the benchmark")
    parser.add_argument("--n_gpu_layers", type=int, default=999, help="Number of GPU layers for the server")
    parser.add_argument("--parallel", type=int, default=16, help="Number of slots for the server")
    parser.add_argument("--ctx_size", type=int, default=4096, help="Server context size per slot")
    parser.add_argument("--n_prompts", type=int, default=1000, help="Number of prompts to evaluate")
    parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
    args = parser.parse_args()
    benchmark(**vars(args))
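
# Example invocation (the script name and file paths below are illustrative; adjust them to your setup):
#   python server-bench.py --path_model ./models/model.gguf --parallel 16 --ctx_size 4096 \
#       --n_prompts 1000 --path_log server.log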