Skip to content

Commit f7cad30

Browse files
authored
[Feature] Add speculative decoding simulation benchmark. (#2751)
* Add speculative decoding simulation benchmark * Fix the name of the parameter
1 parent 6b10c19 commit f7cad30

File tree

8 files changed

+246
-7
lines changed

8 files changed

+246
-7
lines changed

benchmarks/README.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,30 @@ python benchmark_serving.py \
105105
--save-result > infer_log.txt 2>&1 &
106106
```
107107

108+
### 投机解码性能测试工具
109+
110+
#### 使用方式:
111+
112+
```bash
113+
python benchmarks/benchmark_mtp.py \
114+
--host 127.0.0.1 --port 8000 \
115+
--max-concurrency 16 32 64 96 --num-prompts 256 \
116+
--acceptance-rate 0.8 --draft-token-steps 1 2 3 \
117+
--s_itl-base-model 15.88 22.84 16.47 16.93 \
118+
--dataset-name EBChat \
119+
--dataset-path ./filtered_sharedgpt_2000_input_1136_output_200_fd.json
120+
```
121+
122+
#### 参数说明
123+
124+
```bash
125+
--host:服务ip地址,用于组url
126+
--port:服务HTTP端口,用于组url
127+
--max-concurrency:测试并发数
128+
--num-prompts:总计发送多少条请求
129+
--acceptance-rate:投机解码的模拟接受率
130+
--draft-token-steps:投机解码的步数
131+
--s_itl-base-model:主模型的解码延迟,可由上述的性能压测工具获得,数量与顺序须与 --max-concurrency 的取值一一对应
132+
--dataset-name:指定数据集类,指定为"EBChat"可读取转存的FD格式数据集
133+
--dataset-path:测试数据集路径
134+
```

benchmarks/benchmark_mtp.py

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
"""
2+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
17+
import argparse
18+
import asyncio
19+
import contextlib
20+
import os
21+
import signal
22+
import socket
23+
import subprocess
24+
import time
25+
from typing import Union
26+
27+
import openai
28+
import yaml
29+
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
30+
from benchmark_serving import benchmark
31+
32+
33+
def prepare_input_requests(
    num_prompts: int, dataset_name: str, dataset_path: str
) -> Union[EBDataset, EBChatDataset]:
    """Load the selected dataset and sample the requests for the benchmark.

    Args:
        num_prompts: Number of requests to sample; forwarded to the dataset's
            ``sample()`` as ``num_requests``.
        dataset_name: Which dataset loader to use, ``"EB"`` or ``"EBChat"``.
        dataset_path: Path forwarded to the dataset loader as ``dataset_path``.

    Returns:
        The value returned by the selected dataset's ``sample()`` call.
        NOTE(review): the annotation names the dataset classes themselves,
        while the value returned is the result of ``sample()`` -- confirm the
        intended return type.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    # Factories are lazy so that only the requested dataset is constructed.
    dataset_mapping = {
        "EB": lambda: EBDataset(dataset_path=dataset_path).sample(
            num_requests=num_prompts
        ),
        "EBChat": lambda: EBChatDataset(dataset_path=dataset_path).sample(
            num_requests=num_prompts
        ),
    }

    # Guard only the name lookup: a KeyError raised while loading or sampling
    # the dataset must surface as-is instead of being misreported as an
    # unknown dataset name.
    try:
        build_requests = dataset_mapping[dataset_name]
    except KeyError as err:
        raise ValueError(f"Unknown dataset: {dataset_name}") from err

    return build_requests()
51+
52+
53+
class FakeTokenizer:
    """Stand-in tokenizer whose ``encode`` produces no tokens for any input."""

    def encode(self, text: str, add_special_tokens: bool = False):
        """Ignore both arguments and return a new empty list."""
        no_tokens = list()
        return no_tokens
56+
57+
58+
def send_one_batch(base_url, max_concurrency, input_requests, disable_tqdm):
    """Run one benchmark pass and report its mean ``s_itl`` value.

    Args:
        base_url: Root URL of the service; ``/v1/chat/completions`` is
            appended to form the request URL.
        max_concurrency: Forwarded to ``benchmark`` as ``max_concurrency``.
        input_requests: Requests forwarded to ``benchmark``.
        disable_tqdm: Forwarded to ``benchmark`` as ``disable_tqdm``.

    Returns:
        A dict with the single key ``"mean_s_itl_ms"`` copied from the
        benchmark results.
    """
    # Every option other than the four function arguments is fixed here.
    benchmark_kwargs = dict(
        backend="openai-chat",
        api_url=f"{base_url}/v1/chat/completions",
        base_url=base_url,
        model_id="default",
        model_name="default",
        input_requests=input_requests,
        hyper_parameters={},
        logprobs=None,
        request_rate=float("inf"),
        burstiness=1.0,
        disable_tqdm=disable_tqdm,
        profile=False,
        selected_percentile_metrics=["s_itl"],
        selected_percentiles=[],
        ignore_eos=False,
        goodput_config_dict=None,
        max_concurrency=max_concurrency,
        lora_modules=None,
        extra_body=None,
    )

    # Run benchmark
    outcome = asyncio.run(benchmark(**benchmark_kwargs))

    return {"mean_s_itl_ms": outcome["mean_s_itl_ms"]}
91+
92+
93+
def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
    """Estimate the speedup for a simulated acceptance rate.

    With ``S = sum(acceptance_rate ** k for k in 1..draft_token_step)`` and
    ``r_ac = S / (1 + S)``, the result is ``t_ori / ((1 - r_ac) * t_mtp)``.

    Args:
        acceptance_rate: Base of the geometric terms summed into ``S``.
        draft_token_step: Number of terms in the sum; ``0`` gives ``S = 0``.
        t_ori: Numerator latency (the caller passes the base-model value).
        t_mtp: Denominator latency (the caller passes the measured mean).

    Returns:
        The computed speedup ratio as a float.
    """
    # Accumulate term by term, in the same order, to keep float results exact.
    accepted = 0.0
    for power in range(1, draft_token_step + 1):
        accepted += acceptance_rate ** power

    accept_ratio = accepted / (1 + accepted)
    remaining = 1 - accept_ratio

    return t_ori / (remaining * t_mtp)
102+
103+
104+
def main(args):
    """Benchmark each concurrency level and print the simulated speedups.

    For every ``(max_concurrency, s_itl)`` pair, a warmup batch is sent with
    stdout suppressed, then the full request set is benchmarked and one
    speedup line is printed per entry in ``args.draft_token_steps``.

    Args:
        args: Parsed CLI namespace providing ``host``, ``port``,
            ``num_prompts``, ``dataset_name``, ``dataset_path``,
            ``max_concurrency``, ``s_itl_base_model``, ``acceptance_rate``
            and ``draft_token_steps``.

    Raises:
        ValueError: If ``--s_itl-base-model`` was not supplied, or if its
            length differs from that of ``--max-concurrency``.
    """
    # Validate up front so a bad invocation fails before the dataset is loaded
    # or any request is sent. The flag has no default, so it may be None.
    if args.s_itl_base_model is None:
        raise ValueError(
            "--s_itl-base-model is required: pass one base-model latency "
            "per --max-concurrency value"
        )
    if len(args.max_concurrency) != len(args.s_itl_base_model):
        raise ValueError(
            "--max-concurrency should be same length as --s_itl-base-model"
        )

    base_url = f"http://{args.host}:{args.port}"

    input_requests = prepare_input_requests(
        args.num_prompts, args.dataset_name, args.dataset_path
    )

    for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
        # Warmup: send one request per concurrent slot and discard the
        # benchmark's stdout so it does not mix with the real results.
        print("Starting warmup...")
        with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
            send_one_batch(
                base_url, max_concurrency, input_requests[0:max_concurrency], True
            )

        # Benchmark
        record = send_one_batch(base_url, max_concurrency, input_requests, False)

        print("{s:{c}^{n}}".format(s="Speed up", n=50, c="-"))
        for draft_token_step in args.draft_token_steps:
            speedup = calculate_speedup(
                args.acceptance_rate,
                draft_token_step,
                s_itl,
                record["mean_s_itl_ms"],
            )
            print(
                "{:<40} {:<10.2f}".format(
                    f"Speed up on {draft_token_step} steps draft", speedup
                )
            )
        print("=" * 50)
139+
140+
141+
if __name__ == "__main__":
    # CLI entry point: parse the benchmark options and hand them to main().
    parser = argparse.ArgumentParser(
        description="Speculative decoding simulation benchmark."
    )
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Service IP address, used to build the request URL.",
    )
    parser.add_argument(
        "--port",
        type=str,
        default="8000",
        help="Service HTTP port, used to build the request URL.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        nargs="+",
        default=(1, 2, 4, 8, 16, 32),
        help="Concurrency levels to benchmark, one run per value.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=128,
        help="Total number of requests to send.",
    )
    parser.add_argument(
        "--acceptance-rate",
        type=float,
        default=0.8,
        help="Simulated acceptance rate for speculative decoding.",
    )
    parser.add_argument(
        "--draft-token-steps",
        type=int,
        nargs="+",
        default=(1, 2),
        help="Numbers of speculative decoding steps to report speedups for.",
    )
    # There is no sensible default for the base-model latencies, and main()
    # cannot run without them, so the flag is mandatory.
    parser.add_argument(
        "--s_itl-base-model",
        type=float,
        nargs="+",
        required=True,
        help=(
            "Decode latency of the base model, one value per "
            "--max-concurrency entry, in the same order."
        ),
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="EBChat",
        help='Dataset class to use, "EB" or "EBChat".',
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        help="Path to the benchmark dataset.",
    )
    args = parser.parse_args()

    main(args)

custom_ops/gpu_ops/speculate_decoding/speculate_verify.cu

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ __global__ void speculate_verify(
7373
const int *output_cum_offsets, const int *actual_candidate_len,
7474
const int real_bsz, const int max_draft_tokens, const int end_length,
7575
const int max_seq_len, const int max_candidate_len, const int verify_window,
76-
const bool prefill_one_step_stop) {
76+
const bool prefill_one_step_stop, const bool benchmark_mode) {
7777
const int bid = threadIdx.x;
7878
// verify and set stop flags
7979
int accept_num_now = 1;
@@ -95,6 +95,9 @@ __global__ void speculate_verify(
9595
// printf("seq_lens_this_time[%d]-1: %d \n",bid,
9696
// seq_lens_this_time[bid]-1);
9797
for (; i < seq_lens_this_time[bid] - 1; i++) {
98+
if (benchmark_mode) {
99+
break;
100+
}
98101
if (seq_lens_encoder[bid] != 0) {
99102
break;
100103
}
@@ -246,7 +249,7 @@ void SpeculateVerify(
246249
const paddle::Tensor &output_cum_offsets,
247250
const paddle::Tensor &actual_candidate_len,
248251
const paddle::Tensor &actual_draft_token_nums, const paddle::Tensor &topp,
249-
int max_seq_len, int verify_window, bool enable_topp) {
252+
int max_seq_len, int verify_window, bool enable_topp, bool benchmark_mode) {
250253
// printf("Enter speculate update\n");
251254
auto bsz = accept_tokens.shape()[0];
252255
int real_bsz = seq_lens_this_time.shape()[0];
@@ -301,7 +304,7 @@ void SpeculateVerify(
301304
is_block_step.data<bool>(), output_cum_offsets.data<int>(),
302305
actual_candidate_len.data<int>(), real_bsz, max_draft_tokens,
303306
end_length, max_seq_len, max_candidate_len, verify_window,
304-
prefill_one_step_stop);
307+
prefill_one_step_stop, benchmark_mode);
305308
} else {
306309
speculate_verify<false, true>
307310
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
@@ -317,7 +320,7 @@ void SpeculateVerify(
317320
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
318321
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
319322
real_bsz, max_draft_tokens, end_length, max_seq_len,
320-
max_candidate_len, verify_window, prefill_one_step_stop);
323+
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
321324
}
322325
} else {
323326
if (enable_topp) {
@@ -335,7 +338,7 @@ void SpeculateVerify(
335338
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
336339
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
337340
real_bsz, max_draft_tokens, end_length, max_seq_len,
338-
max_candidate_len, verify_window, prefill_one_step_stop);
341+
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
339342
} else {
340343
speculate_verify<false, false>
341344
<<<1, BlockSize, 0, accept_tokens.stream()>>>(
@@ -351,7 +354,7 @@ void SpeculateVerify(
351354
end_tokens.data<int64_t>(), is_block_step.data<bool>(),
352355
output_cum_offsets.data<int>(), actual_candidate_len.data<int>(),
353356
real_bsz, max_draft_tokens, end_length, max_seq_len,
354-
max_candidate_len, verify_window, prefill_one_step_stop);
357+
max_candidate_len, verify_window, prefill_one_step_stop, benchmark_mode);
355358
}
356359
}
357360

@@ -366,7 +369,7 @@ PD_BUILD_STATIC_OP(speculate_verify)
366369
"actual_candidate_len", "actual_draft_token_nums", "topp"})
367370
.Outputs({"accept_tokens_out", "accept_num_out", "step_idx_out",
368371
"stop_flags_out"})
369-
.Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool"})
372+
.Attrs({"max_seq_len: int", "verify_window: int", "enable_topp: bool", "benchmark_mode: bool"})
370373
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
371374
{"accept_num", "accept_num_out"},
372375
{"step_idx", "step_idx_out"},

fastdeploy/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,10 @@ class SpeculativeConfig:
238238
# A trick method is currently used to enable this sharing.
239239
# This will be replaced with a more standardized solution in the future.
240240
sharing_model = None
241+
# During benchmarking, we need to enforce that the number of accepted tokens is 1.
242+
# This means no tokens from MTP are accepted.
243+
# This ensures that the specified simulation acceptance rate is not affected.
244+
benchmark_mode: bool = False
241245

242246

243247
@dataclass

fastdeploy/engine/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@ class SpeculativeConfig:
337337
model_name_or_path (Optional[str]): Path of the model.
338338
quantization (str): Quantization method for draft model, default is WINT8.
339339
max_model_len: Optional[int]: Maximum model length for draft model.
340+
benchmark_mode (bool): Whether to use benchmark mode.
340341
"""
341342

342343
def __init__(self,
@@ -345,12 +346,14 @@ def __init__(self,
345346
model: Optional[str] = None,
346347
quantization: Optional[str] = "WINT8",
347348
max_model_len: Optional[int] = None,
349+
benchmark_mode: bool = False,
348350
**kwargs):
349351
self.model_name_or_path = model
350352
self.method = method
351353
self.num_speculative_tokens = num_speculative_tokens
352354
self.quantization = quantization
353355
self.max_model_len = max_model_len
356+
self.benchmark_mode = benchmark_mode
354357
# Fixed now
355358
self.num_gpu_block_expand_ratio = 1
356359
self.num_extra_cache_layer = 0

fastdeploy/engine/engine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,7 @@ def _start_worker_service(self):
10301030
f" --speculative_max_draft_token_num {self.cfg.speculative_config.num_speculative_tokens}"
10311031
f" --speculative_model_name_or_path {self.cfg.speculative_config.model_name_or_path}"
10321032
f" --speculative_model_quantization {self.cfg.speculative_config.quantization}"
1033+
f" --speculative_benchmark_mode {self.cfg.speculative_config.benchmark_mode}"
10331034
f" --max_capture_batch_size {self.cfg.max_capture_batch_size}"
10341035
f" --guided_decoding_backend {self.cfg.guided_decoding_backend}"
10351036
f" --load_strategy {self.cfg.model_config.load_strategy}")

fastdeploy/model_executor/layers/sample/sampler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def __init__(self, fd_config: FDConfig):
235235
raise NotImplementedError()
236236
self.speculative_verify_window = fd_config.speculative_config.verify_window
237237
self.speculative_max_candidate_len = fd_config.speculative_config.max_candidate_len
238+
self.speculative_benchmark_mode = fd_config.speculative_config.benchmark_mode
238239

239240
def pre_process(self, skip_idx_list: List[int] = []):
240241
""" pre process before running """
@@ -309,6 +310,7 @@ def forward_cuda(
309310
max_model_len,
310311
self.speculative_verify_window,
311312
True, # enable_topp
313+
self.speculative_benchmark_mode,
312314
)
313315

314316
return None

fastdeploy/worker/worker_process.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,11 @@ def parse_args():
494494
default="WINT8",
495495
type=str,
496496
)
497+
parser.add_argument(
498+
"--speculative_benchmark_mode",
499+
default="false",
500+
type=str,
501+
)
497502
parser.add_argument("--max_num_batched_tokens",
498503
type=int,
499504
default=2048,
@@ -625,6 +630,9 @@ def initialize_fd_config(config_or_args) -> FDConfig:
625630
speculative_config.num_speculative_tokens = getattr(config_or_args, 'speculative_max_draft_token_num', 0)
626631
speculative_config.model_name_or_path = getattr(config_or_args, 'speculative_model_name_or_path', None)
627632
speculative_config.quantization = getattr(config_or_args, 'speculative_model_quantization', None)
633+
speculative_config.benchmark_mode = (
634+
getattr(config_or_args, "speculative_benchmark_mode", "false").lower() == "true"
635+
)
628636

629637
# Update parallel config
630638
parallel_config.engine_pid = getattr(config_or_args, 'engine_pid', None)

0 commit comments

Comments
 (0)