Commit 41cf067

fix CI
1 parent 8e8b5e4 commit 41cf067

2 files changed: +21 -17 lines changed

requirements/requirements-server-bench.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-datasets~=3.6.0
+datasets
 matplotlib~=3.10.0
 numpy~=1.26.4
 requests~=2.32.3

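For reference, the dropped "~=3.6.0" pin is a PEP 440 compatible-release specifier, equivalent to ">=3.6.0, ==3.6.*"; removing it lets pip resolve any datasets release. A minimal sketch of the difference in semantics, using the packaging library (the version numbers below are illustrative, not taken from this commit):

import logging  # noqa: F401  (kept consistent with the script's imports)
from packaging.specifiers import SpecifierSet
from packaging.version import Version

# Sketch: semantics of the old pin vs. the new unpinned requirement.
pinned = SpecifierSet("~=3.6.0")  # compatible release: >=3.6.0, ==3.6.*
unpinned = SpecifierSet("")       # no constraint: any release is acceptable

for v in ("3.6.0", "3.6.2", "3.7.0", "4.0.0"):
    ok_pinned = Version(v) in pinned      # True only for 3.6.x releases
    ok_unpinned = Version(v) in unpinned  # always True
    print(f"{v}: pinned={ok_pinned} unpinned={ok_unpinned}")
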
scripts/server-bench.py

Lines changed: 20 additions & 16 deletions
@@ -7,22 +7,27 @@
 from typing import Optional
 
 import datasets
+import logging
 import matplotlib.pyplot as plt
 import numpy as np
 import requests
 from tqdm.contrib.concurrent import thread_map
 
 
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("server-bench")
+
+
 def get_prompts(n_prompts: int) -> list[str]:
-    print("Loading MMLU dataset...")
+    logger.info(" Loading MMLU dataset...")
     ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]
     if n_prompts >= 0:
         ret = ret[:n_prompts]
     return ret
 
 
 def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
-    print("Starting the llama.cpp server...")
+    logger.info(" Starting the llama.cpp server...")
     address = f"http://localhost:{port}"
 
     popen_args: list[str] = [
@@ -121,11 +126,10 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
         for i, p in enumerate(prompts):
             data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})
 
-        print("Getting the prompt lengths...")
+        logger.info(" Getting the prompt lengths...")
         prompt_n: list[int] = [get_prompt_length(d) for d in data]
 
-        print("Starting the benchmark...")
-        print()
+        logger.info(" Starting the benchmark...\n")
         t0 = time()
         results: list[tuple[int, list[float]]] = thread_map(send_prompt, data, max_workers=parallel + 1, chunksize=1)
     finally:
@@ -149,17 +153,17 @@ def benchmark(path_server: str, path_model: str, path_log: Optional[str], port:
     token_t -= t0
     token_t_last = np.max(token_t)
 
-    print()
-    print(f"Benchmark duration: {token_t_last:.2f} s")
-    print(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
-    print(f"Total prompt length: {np.sum(prompt_n)} tokens")
-    print(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
-    print(f"Average prompt latency: {np.mean(prompt_ms):.2f} ms")
-    print(f"Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
-    print(f"Total generated tokens: {token_t.shape[0]}")
-    print(f"Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
-    print(f"Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
-    print(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
+    logger.info("")
+    logger.info(f" Benchmark duration: {token_t_last:.2f} s")
+    logger.info(f" Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
+    logger.info(f" Total prompt length: {np.sum(prompt_n)} tokens")
+    logger.info(f" Average prompt length: {np.mean(prompt_n):.2f} tokens")
+    logger.info(f" Average prompt latency: {np.mean(prompt_ms):.2f} ms")
+    logger.info(f" Average prompt speed: {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
+    logger.info(f" Total generated tokens: {token_t.shape[0]}")
+    logger.info(f" Average generation depth: {depth_sum / token_t.shape[0]:.2f} tokens")
+    logger.info(f" Average total generation speed: {token_t.shape[0] / token_t_last:.2f} tokens/s")
+    logger.info(f" Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")
 
     plt.figure()
     plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
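
The print-to-logging switch is mechanical, but the output format changes: logging.basicConfig(level=logging.INFO) installs a root handler whose default format prefixes each record with its level and logger name, which is presumably why the new messages carry a leading space. A minimal, self-contained sketch of the pattern this commit introduces (the messages are illustrative):

import logging

# Same setup as the patched script: a root handler at INFO level plus a
# named logger, so each record shows which component emitted it.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("server-bench")

# With basicConfig's default format "%(levelname)s:%(name)s:%(message)s",
# this prints: INFO:server-bench: Loading MMLU dataset...
logger.info(" Loading MMLU dataset...")

# Unlike print(), records below the configured level are dropped, so
# verbose diagnostics can stay in the code without cluttering CI logs.
logger.debug("not shown at INFO level")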
