
Commit 7bd4c37

Authored by pavanimajety, wenscarl, and mgoin

[Core] Add Flashinfer TRTLLM Backend for Flashinfer decode path (SM100). (#19825)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: shuw <shuw@nvidia.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
1 parent 8020e98 commit 7bd4c37
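
The two files rendered in this diff are a micro-benchmark and a unit test for FlashInfer's trtllm_batch_decode_with_kv_cache kernel. Both use the TRT-LLM paged KV-cache layout noted in their header comments, kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim). The following is a standalone indexing sketch, illustrative only and not part of the commit; it assumes, as in FlashInfer's combined paged layout, that index 0 of the size-2 dimension holds K and index 1 holds V:

import torch

# Toy dimensions for illustration only.
num_blocks, num_kv_heads, page_size, head_dim = 4, 2, 16, 128
kv_cache = torch.zeros(num_blocks, 2, num_kv_heads, page_size, head_dim)

block_id, head, slot = 3, 1, 7
k_page = kv_cache[block_id, 0]             # all K vectors in this block
v_vec = kv_cache[block_id, 1, head, slot]  # one V vector
assert k_page.shape == (num_kv_heads, page_size, head_dim)
assert v_vec.shape == (head_dim,)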

File tree

8 files changed: +667, -56 lines changed

Lines changed: 240 additions & 0 deletions
@@ -0,0 +1,240 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import csv
import os
import random
from datetime import datetime

import flashinfer
import torch

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8

# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@torch.no_grad()
def benchmark_decode(
    num_seqs,
    max_seq_len,
    page_size=16,
    dtype=torch.bfloat16,
    kv_layout="HND",
    num_kv_heads=8,
    kv_cache_dtype="auto",
    head_dim=128,
    warmup=10,
    trials=20,
):
    torch.set_default_device("cuda")
    device = "cuda"
    torch.manual_seed(0)

    # Currently only HEAD_GRP_SIZE == 8 is supported
    HEAD_GRP_SIZE = 8
    MAX_SEQ_LEN = max_seq_len

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = int(256000 / page_size)

    workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)

    # For decode, batch_size is num_decode_token
    num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
    sm_scale = float(1.0 / (head_dim**0.5))
    q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
    kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]

    max_kv_len = max(kv_lens)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
    max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size

    block_tables = torch.randint(
        0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
    kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
    k_scale = v_scale = 1.0

    if kv_cache_dtype.startswith("fp8"):
        kv_cache, _ = to_float8(kv_cache)

    # Benchmark TRT decode
    def trt_decode():
        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            q,
            kv_cache,
            workspace_buffer,
            num_qo_heads,
            num_kv_heads,
            sm_scale,
            block_tables,
            kv_lens_tensor,
            page_size,
            max_kv_len,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    def time_fn(fn, warmup=10, trials=20):
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for i in range(warmup):
            fn()
        for i in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    # TRT Decode
    trt_mean, trt_std = time_fn(trt_decode)

    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = kv_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + page_size - 1) // page_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % page_size
        if kv_last_page_len == 0:
            kv_last_page_len = page_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
    )

    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        "NONE",
        q_data_type=dtype,
        kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
    )

    def baseline_decode():
        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)

    baseline_mean, baseline_std = time_fn(baseline_decode)

    # Calculate percentage speedup (positive means TRT is faster)
    speedup_percent = (baseline_mean - trt_mean) / baseline_mean

    print(
        f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
        f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
    )

    # Return results for CSV writing
    return {
        "num_seqs": num_seqs,
        "trt_mean": trt_mean,
        "trt_std": trt_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(dtype),
        "kv_cache_dtype": kv_cache_dtype,
        "page_size": page_size,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "max_seq_len": max_seq_len,
    }


def write_results_to_csv(results, filename=None):
    """Write benchmark results to CSV file."""
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
        "num_seqs",
        "trt_mean",
        "trt_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "page_size",
        "num_kv_heads",
        "head_dim",
        "max_seq_len",
    ]

    file_exists = os.path.exists(filename)

    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        if not file_exists:
            writer.writeheader()

        for result in results:
            writer.writerow(result)

    print(f"Results written to {filename}")


if __name__ == "__main__":
    num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    print("Running benchmark for kv_cache_dtype: bfloat16")
    print(
        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
    )
    for max_seq_len in max_seq_lens:
        for bs in num_seqs:
            result = benchmark_decode(
                bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="auto"
            )
            all_results.append(result)

    print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8")
    print(
        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
    )
    for max_seq_len in max_seq_lens:
        for bs in num_seqs:
            result = benchmark_decode(
                bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8"
            )
            all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)
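
The kv_indptr / kv_indices / kv_last_page_lens construction above (repeated in the test below) maps each sequence onto its pages for the FlashInfer baseline wrapper. A tiny worked example, standalone and not part of the commit, assuming page_size = 16 and two sequences of 18 and 16 tokens:

page_size = 16
kv_lens = [18, 16]

kv_indptr = [0]
kv_last_page_lens = []
for seq_len in kv_lens:
    # Pages needed for this sequence (ceiling division).
    num_blocks = (seq_len + page_size - 1) // page_size
    kv_indptr.append(kv_indptr[-1] + num_blocks)
    # Tokens resident in the final page; a full page stays page_size, not 0.
    last = seq_len % page_size
    kv_last_page_lens.append(last if last != 0 else page_size)

assert kv_indptr == [0, 2, 3]        # seq 0 uses 2 pages, seq 1 uses 1 page
assert kv_last_page_lens == [2, 16]  # 18 % 16 = 2; 16 fills its last page exactly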
Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import flashinfer
import pytest
import torch

from vllm.platforms import current_platform

if not current_platform.is_device_capability(100):
    pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
                allow_module_level=True)

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8

# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)

NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)]
HEAD_SIZES = [128]
BLOCK_SIZES = [16, 32]
DTYPES = [torch.float16, torch.bfloat16]
NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
SOFT_CAPS = [None, 30.0, 50.0]


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("kv_layout", ["HND"])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
    kv_lens: list[int],
    num_heads: tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    block_size: int,
    soft_cap: Optional[float],
    kv_layout: str,
) -> None:
    torch.set_default_device("cuda")
    current_platform.seed_everything(0)
    num_seqs = len(kv_lens)
    num_query_heads = num_heads[0]
    num_kv_heads = num_heads[1]

    assert num_query_heads % num_kv_heads == 0
    max_kv_len = max(kv_lens)
    scale = head_size**-0.5

    query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
    kv_cache_shape = None
    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")
    key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)

    max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
    block_tables = torch.randint(0,
                                 NUM_BLOCKS,
                                 (num_seqs, max_num_blocks_per_seq),
                                 dtype=torch.int32)
    k_scale = v_scale = 1.0
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = kv_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
    wrapper = flashinfer.\
        BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout,
                                           use_tensor_cores=(
                                               (num_query_heads//num_kv_heads) > 4)
                                           )
    wrapper.plan(kv_indptr,
                 kv_indices,
                 kv_last_page_lens,
                 num_query_heads,
                 num_kv_heads,
                 head_size,
                 block_size,
                 "NONE",
                 q_data_type=dtype,
                 kv_data_type=dtype,
                 logits_soft_cap=soft_cap)

    output = wrapper.run(query, key_value_cache, scale)

    # TRTLLM Decode
    max_kv_len = max(kv_lens)
    kv_lens_tensor = torch.tensor(kv_lens,
                                  dtype=torch.int,
                                  device=query.device)
    output_trtllm = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
        query.contiguous(),
        key_value_cache,
        workspace_buffer,
        num_query_heads,
        num_kv_heads,
        scale,
        block_tables,
        kv_lens_tensor,
        block_size,
        max_kv_len,
        "auto",
        k_scale,
        v_scale,
    )

    torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - output_trtllm))}"
