Skip to content

Commit 8d72ccc

Browse files
committed
resolve conflict
2 parents a74ff1d + 888780f commit 8d72ccc

File tree

141 files changed

+11925
-1145
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

141 files changed

+11925
-1145
lines changed

benchmarks/README.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,30 @@ python benchmark_serving.py \
105105
--save-result > infer_log.txt 2>&1 &
106106
```
107107

108+
### 投机解码性能测试工具
109+
110+
#### 使用方式:
111+
112+
```bash
113+
python benchmarks/benchmark_mtp.py \
114+
--host 127.0.0.1 --port 8000 \
115+
--max-concurrency 16 32 64 96 --num-prompts 256 \
116+
--acceptance-rate 0.8 --draft-token-steps 1 2 3 \
117+
--s_itl-base-model 15.88 22.84 16.47 16.93 \
118+
--dataset-name EBChat \
119+
--dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json
120+
```
121+
122+
#### 参数说明
123+
124+
```bash
125+
--host:服务ip地址,用于组url
126+
--port:服务HTTP端口,用于组url
127+
--max-concurrency:测试并发数
128+
--num-prompts:总计发送多少条请求
129+
--acceptance-rate:投机解码的模拟接受率
130+
--draft-token-steps:投机解码的步数
131+
--s_itl-base-model:主模型的解码延迟,可由上述的性能压测工具获得,与batch-size一一对应
132+
--dataset-name:指定数据集类,指定为"EBChat"可读取转存的FD格式数据集
133+
--dataset-path:测试数据集路径
134+
```

benchmarks/benchmark_mtp.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
"""
2+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
17+
import argparse
18+
import asyncio
19+
import contextlib
20+
import os
21+
import signal
22+
import socket
23+
import subprocess
24+
import time
25+
from typing import Union
26+
27+
import openai
28+
import yaml
29+
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
30+
from benchmark_serving import benchmark
31+
32+
33+
def prepare_input_requests(
    num_prompts: int, dataset_name: str, dataset_path: str
) -> Union[EBDataset, EBChatDataset]:
    """Sample ``num_prompts`` requests from the dataset selected by name.

    Args:
        num_prompts: Forwarded to ``sample`` as ``num_requests``.
        dataset_name: Either ``"EB"`` or ``"EBChat"``.
        dataset_path: Forwarded to the dataset constructor.

    Returns:
        Whatever the chosen dataset's ``sample`` call returns.
        NOTE(review): the annotation names the dataset classes, but the
        value returned is the output of ``sample`` -- confirm and adjust.

    Raises:
        ValueError: If a ``KeyError`` is raised while selecting or sampling
            the dataset (in particular for an unknown ``dataset_name``).
    """
    dataset_classes = {
        "EB": EBDataset,
        "EBChat": EBChatDataset,
    }

    try:
        # Construction and sampling stay inside the ``try`` so that any
        # KeyError they raise is reported the same way as an unknown name.
        dataset_cls = dataset_classes[dataset_name]
        sampled = dataset_cls(dataset_path=dataset_path).sample(
            num_requests=num_prompts
        )
    except KeyError as err:
        raise ValueError(f"Unknown dataset: {dataset_name}") from err

    return sampled
51+
52+
53+
class FakeTokenizer:
    """Placeholder tokenizer whose ``encode`` never produces any tokens."""

    def encode(self, text: str, add_special_tokens: bool = False):
        """Ignore both arguments and return a new empty list."""
        no_tokens = []
        return no_tokens
56+
57+
58+
def send_one_batch(base_url, max_concurrency, input_requests, disable_tqdm):
    """Run one ``benchmark`` pass and return its mean ``s_itl`` figure.

    Args:
        base_url: Server base URL; ``/v1/chat/completions`` is appended to
            form the API URL.
        max_concurrency: Forwarded to ``benchmark`` as ``max_concurrency``.
        input_requests: Requests forwarded to ``benchmark``.
        disable_tqdm: Forwarded to ``benchmark`` as ``disable_tqdm``.

    Returns:
        A dict with the single key ``"mean_s_itl_ms"`` copied from the
        benchmark results.
    """
    benchmark_kwargs = dict(
        backend="openai-chat",
        api_url=f"{base_url}/v1/chat/completions",
        base_url=base_url,
        model_id="default",
        model_name="default",
        input_requests=input_requests,
        hyper_parameters={},
        logprobs=None,
        # presumably an infinite rate means "no client-side throttling";
        # verify against benchmark_serving.benchmark
        request_rate=float("inf"),
        burstiness=1.0,
        disable_tqdm=disable_tqdm,
        profile=False,
        # Only the s_itl metric is requested, with no percentiles.
        selected_percentile_metrics=["s_itl"],
        selected_percentiles=[],
        ignore_eos=False,
        goodput_config_dict=None,
        max_concurrency=max_concurrency,
        lora_modules=None,
        extra_body=None,
    )

    # Drive the async benchmark to completion from synchronous code.
    results = asyncio.run(benchmark(**benchmark_kwargs))

    return {"mean_s_itl_ms": results["mean_s_itl_ms"]}
91+
92+
93+
def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
    """Estimate the speed-up of speculative decoding over plain decoding.

    With ``a = acceptance_rate`` and ``k = draft_token_step`` this computes
    ``s = a + a**2 + ... + a**k``, ``r_ac = s / (1 + s)`` and returns
    ``t_ori / ((1 - r_ac) * t_mtp)``.

    Args:
        acceptance_rate: Simulated draft-token acceptance rate, in ``[0, 1]``.
        draft_token_step: Number of draft steps ``k``; a value of ``0`` (or
            less) contributes no accepted tokens, so the result is simply
            ``t_ori / t_mtp``.
        t_ori: Latency of the base model (``main`` passes the value given
            via ``--s_itl-base-model``).
        t_mtp: Latency measured by the benchmark run (``main`` passes
            ``mean_s_itl_ms``). Must be positive.

    Returns:
        The estimated speed-up factor as a float.

    Raises:
        ValueError: If ``acceptance_rate`` is outside ``[0, 1]`` or ``t_mtp``
            is not positive. Without these checks the formula divides by
            zero or yields a meaningless figure.
    """
    if not 0 <= acceptance_rate <= 1:
        raise ValueError(
            f"acceptance_rate must be within [0, 1], got {acceptance_rate}"
        )
    if t_mtp <= 0:
        # A zero mean latency typically means no request produced tokens.
        raise ValueError(f"t_mtp must be positive, got {t_mtp}")

    accepted = sum(
        pow(acceptance_rate, step) for step in range(1, draft_token_step + 1)
    )
    r_ac = accepted / (1 + accepted)

    return t_ori / ((1 - r_ac) * t_mtp)
102+
103+
104+
def main(args):
    """Benchmark each concurrency level and print estimated speed-ups.

    For every ``(max_concurrency, s_itl)`` pair, a silenced warm-up batch is
    sent first, then the full request set is benchmarked, and one speed-up
    line is printed per entry of ``args.draft_token_steps``.

    Args:
        args: Parsed command-line namespace providing ``host``, ``port``,
            ``num_prompts``, ``dataset_name``, ``dataset_path``,
            ``max_concurrency``, ``s_itl_base_model``, ``acceptance_rate``
            and ``draft_token_steps``.

    Raises:
        ValueError: If ``--s_itl-base-model`` was not supplied, or its
            length differs from that of ``--max-concurrency``.
    """
    # Validate the command line first so a bad invocation fails before the
    # dataset is loaded or any request is sent.
    if args.s_itl_base_model is None:
        raise ValueError(
            "--s_itl-base-model is required: "
            "give one base-model latency per --max-concurrency value"
        )
    if len(args.max_concurrency) != len(args.s_itl_base_model):
        raise ValueError(
            "--max-concurrency should be same length as --s_itl-base-model"
        )

    base_url = f"http://{args.host}:{args.port}"

    input_requests = prepare_input_requests(
        args.num_prompts, args.dataset_name, args.dataset_path
    )

    for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
        # Warmup: one request per concurrent slot, with stdout discarded so
        # that only the real run's output is shown.
        print("Starting warmup...")
        with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
            send_one_batch(
                base_url, max_concurrency, input_requests[:max_concurrency], True
            )

        # Benchmark
        record = send_one_batch(base_url, max_concurrency, input_requests, False)

        metric_header = "Speed up"
        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
        for draft_token_step in args.draft_token_steps:
            speedup = calculate_speedup(
                args.acceptance_rate,
                draft_token_step,
                s_itl,
                record["mean_s_itl_ms"],
            )
            print(
                "{:<40} {:<10.2f}".format(
                    f"Speed up on {draft_token_step} steps draft", speedup
                )
            )
        print("=" * 50)
139+
140+
141+
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Service IP address, used to build the request URL.",
    )
    parser.add_argument(
        "--port",
        type=str,
        default="8000",
        help="Service HTTP port, used to build the request URL.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        nargs="+",
        default=(1, 2, 4, 8, 16, 32),
        help="Concurrency levels to benchmark, one run per value.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=128,
        help="Total number of requests to send.",
    )
    parser.add_argument(
        "--acceptance-rate",
        type=float,
        default=0.8,
        help="Simulated acceptance rate of speculative decoding.",
    )
    parser.add_argument(
        "--draft-token-steps",
        type=int,
        nargs="+",
        default=(1, 2),
        help="Numbers of speculative-decoding draft steps to evaluate.",
    )
    parser.add_argument(
        "--s_itl-base-model",
        type=float,
        nargs="+",
        # There is no sensible default: main() needs one latency per
        # --max-concurrency value, so a missing option is rejected here.
        required=True,
        help=(
            "Decoding latency of the base model, one value per "
            "--max-concurrency entry (obtainable from the serving benchmark)."
        ),
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="EBChat",
        help='Dataset class to use: "EB" or "EBChat".',
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        help="Path to the test dataset file.",
    )
    args = parser.parse_args()

    main(args)

build.sh

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ BUILD_WHEEL=${1:-1}
1818
PYTHON_VERSION=${2:-"python"}
1919
export python=$PYTHON_VERSION
2020
FD_CPU_USE_BF16=${3:-"false"}
21+
# FD_BUILDING_ARCS: Specify target CUDA architectures for custom ops, e.g., "[80, 90, 100]".
22+
# For SM90 (Hopper), use 90. For SM100 (Blackwell), use 100.
23+
# These will be translated to 90a / 100a in setup_ops.py for specific features.
2124
FD_BUILDING_ARCS=${4:-""}
2225

2326

@@ -74,8 +77,10 @@ function copy_ops(){
7477
is_rocm=`$python -c "import paddle; print(paddle.is_compiled_with_rocm())"`
7578
if [ "$is_rocm" = "True" ]; then
7679
DEVICE_TYPE="rocm"
80+
mkdir -p ../fastdeploy/model_executor/ops/base
81+
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
7782
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gpu
78-
echo -e "ROCM ops have been copy to fastdeploy"
83+
echo -e "BASE and ROCM ops have been copy to fastdeploy"
7984
return
8085
fi
8186
mkdir -p ../fastdeploy/model_executor/ops/base
@@ -104,6 +109,23 @@ function copy_ops(){
104109
return
105110
fi
106111

112+
if_corex=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device(\"iluvatar_gpu\"))"`
113+
if [ "$if_corex" = "True" ]; then
114+
DEVICE_TYPE="iluvatar-gpu"
115+
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
116+
cp -r ./${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/iluvatar
117+
echo -e "BASE and Iluvatar ops have been copy to fastdeploy"
118+
return
119+
fi
120+
121+
is_gcu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('gcu'))"`
122+
if [ "$is_gcu" = "True" ]; then
123+
DEVICE_TYPE="gcu"
124+
cp -r ${OPS_TMP_DIR}/${WHEEL_NAME}/* ../fastdeploy/model_executor/ops/gcu
125+
echo -e "gcu ops have been copy to fastdeploy"
126+
return
127+
fi
128+
107129
DEVICE_TYPE="cpu"
108130
cp -r ./${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
109131
cd ../../../../
@@ -163,25 +185,17 @@ function build_and_install() {
163185
exit 1
164186
fi
165187
echo -e "${BLUE}[build]${NONE} ${GREEN}build fastdeploy wheel success${NONE}\n"
166-
167-
echo -e "${BLUE}[install]${NONE} installing fastdeploy..."
168-
cd $DIST_DIR
169-
find . -name "fastdeploy*.whl" | xargs ${python} -m pip install --force-reinstall --no-cache-dir
170-
if [ $? -ne 0 ]; then
171-
cd ..
172-
echo -e "${RED}[FAIL]${NONE} install fastdeploy wheel failed"
173-
exit 1
174-
fi
175-
echo -e "${BLUE}[install]${NONE} ${GREEN}fastdeploy install success${NONE}\n"
176-
cd ..
177188
}
178189

179190
function version_info() {
180191
output_file="fastdeploy/version.txt"
181192
fastdeploy_git_commit_id=$(git rev-parse HEAD)
182193
paddle_version=$(${python} -c "import paddle; print(paddle.__version__)")
183194
paddle_git_commit_id=$(${python} -c "import paddle; print(paddle.__git_commit__)")
184-
cuda_version=$(nvcc -V | grep -Po "(?<=release )[\d.]+(?=, V)")
195+
cuda_version="nvcc-not-installed"
196+
if command -v nvcc &> /dev/null; then
197+
cuda_version=$(nvcc -V | grep -Po "(?<=release )[\d.]+(?=, V)")
198+
fi
185199
cxx_version=$(g++ --version | head -n 1 | grep -Po "(?<=\) )[\d.]+")
186200

187201
echo "fastdeploy GIT COMMIT ID: $fastdeploy_git_commit_id" > $output_file
@@ -246,7 +260,7 @@ if [ "$BUILD_WHEEL" -eq 1 ]; then
246260
echo -e "${GREEN}wheel saved under${NONE} ${RED}${BOLD}./dist${NONE}"
247261

248262
# install wheel
249-
${python} -m pip install ./dist/fastdeploy*.whl
263+
${python} -m pip install ./dist/fastdeploy*.whl --force-reinstall --no-cache-dir
250264
echo -e "${GREEN}wheel install success${NONE}\n"
251265

252266
trap : 0

0 commit comments

Comments
 (0)