|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +import csv |
| 5 | +import os |
| 6 | +import random |
| 7 | +from datetime import datetime |
| 8 | + |
| 9 | +import flashinfer |
| 10 | +import torch |
| 11 | + |
| 12 | +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 |
| 13 | + |
| 14 | +# KV Cache Layout for TRT-LLM |
| 15 | +# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) |
| 16 | + |
| 17 | + |
| 18 | +def to_float8(x, dtype=torch.float8_e4m3fn): |
| 19 | + finfo = torch.finfo(dtype) |
| 20 | + min_val, max_val = x.aminmax() |
| 21 | + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) |
| 22 | + scale = finfo.max / amax * 0.1 |
| 23 | + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) |
| 24 | + return x_scl_sat.to(dtype), scale.float().reciprocal() |
| 25 | + |
| 26 | + |
| 27 | +@torch.no_grad() |
| 28 | +def benchmark_decode( |
| 29 | + num_seqs, |
| 30 | + max_seq_len, |
| 31 | + page_size=16, |
| 32 | + dtype=torch.bfloat16, |
| 33 | + kv_layout="HND", |
| 34 | + num_kv_heads=8, |
| 35 | + kv_cache_dtype="auto", |
| 36 | + head_dim=128, |
| 37 | + warmup=10, |
| 38 | + trials=20, |
| 39 | +): |
| 40 | + torch.set_default_device("cuda") |
| 41 | + device = "cuda" |
| 42 | + torch.manual_seed(0) |
| 43 | + |
| 44 | + # Currently only HEAD_GRP_SIZE == 8 is supported |
| 45 | + HEAD_GRP_SIZE = 8 |
| 46 | + MAX_SEQ_LEN = max_seq_len |
| 47 | + |
| 48 | + # large number to reduce kv_cache reuse |
| 49 | + NUM_BLOCKS = int(256000 / page_size) |
| 50 | + |
| 51 | + workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device) |
| 52 | + |
| 53 | + # For decode, batch_size is num_decode_token |
| 54 | + num_qo_heads = num_kv_heads * HEAD_GRP_SIZE |
| 55 | + sm_scale = float(1.0 / (head_dim**0.5)) |
| 56 | + q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype) |
| 57 | + kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] |
| 58 | + |
| 59 | + max_kv_len = max(kv_lens) |
| 60 | + kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device) |
| 61 | + max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size |
| 62 | + |
| 63 | + block_tables = torch.randint( |
| 64 | + 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 |
| 65 | + ) |
| 66 | + |
| 67 | + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) |
| 68 | + kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype) |
| 69 | + k_scale = v_scale = 1.0 |
| 70 | + |
| 71 | + if kv_cache_dtype.startswith("fp8"): |
| 72 | + kv_cache, _ = to_float8(kv_cache) |
| 73 | + |
| 74 | + # Benchmark TRT decode |
| 75 | + def trt_decode(): |
| 76 | + return flashinfer.decode.trtllm_batch_decode_with_kv_cache( |
| 77 | + q, |
| 78 | + kv_cache, |
| 79 | + workspace_buffer, |
| 80 | + num_qo_heads, |
| 81 | + num_kv_heads, |
| 82 | + sm_scale, |
| 83 | + block_tables, |
| 84 | + kv_lens_tensor, |
| 85 | + page_size, |
| 86 | + max_kv_len, |
| 87 | + kv_cache_dtype, |
| 88 | + k_scale, |
| 89 | + v_scale, |
| 90 | + ) |
| 91 | + |
| 92 | + def time_fn(fn, warmup=10, trials=20): |
| 93 | + torch.cuda.synchronize() |
| 94 | + start = torch.cuda.Event(enable_timing=True) |
| 95 | + end = torch.cuda.Event(enable_timing=True) |
| 96 | + times = [] |
| 97 | + for i in range(warmup): |
| 98 | + fn() |
| 99 | + for i in range(trials): |
| 100 | + start.record() |
| 101 | + fn() |
| 102 | + end.record() |
| 103 | + torch.cuda.synchronize() |
| 104 | + times.append(start.elapsed_time(end)) # ms |
| 105 | + return sum(times) / len(times), torch.std(torch.tensor(times)) |
| 106 | + |
| 107 | + # TRT Decode |
| 108 | + trt_mean, trt_std = time_fn(trt_decode) |
| 109 | + |
| 110 | + kv_indptr = [0] |
| 111 | + kv_indices = [] |
| 112 | + kv_last_page_lens = [] |
| 113 | + for i in range(num_seqs): |
| 114 | + seq_len = kv_lens[i] |
| 115 | + assert seq_len > 0 |
| 116 | + num_blocks = (seq_len + page_size - 1) // page_size |
| 117 | + kv_indices.extend(block_tables[i, :num_blocks]) |
| 118 | + kv_indptr.append(kv_indptr[-1] + num_blocks) |
| 119 | + kv_last_page_len = seq_len % page_size |
| 120 | + if kv_last_page_len == 0: |
| 121 | + kv_last_page_len = page_size |
| 122 | + kv_last_page_lens.append(kv_last_page_len) |
| 123 | + |
| 124 | + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) |
| 125 | + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) |
| 126 | + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) |
| 127 | + |
| 128 | + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( |
| 129 | + workspace_buffer, |
| 130 | + kv_layout, |
| 131 | + use_tensor_cores=((num_qo_heads // num_kv_heads) > 4), |
| 132 | + ) |
| 133 | + |
| 134 | + wrapper.plan( |
| 135 | + kv_indptr, |
| 136 | + kv_indices, |
| 137 | + kv_last_page_lens, |
| 138 | + num_qo_heads, |
| 139 | + num_kv_heads, |
| 140 | + head_dim, |
| 141 | + page_size, |
| 142 | + "NONE", |
| 143 | + q_data_type=dtype, |
| 144 | + kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype, |
| 145 | + ) |
| 146 | + |
| 147 | + def baseline_decode(): |
| 148 | + return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale) |
| 149 | + |
| 150 | + baseline_mean, baseline_std = time_fn(baseline_decode) |
| 151 | + |
| 152 | + # Calculate percentage speedup (positive means TRT is faster) |
| 153 | + speedup_percent = (baseline_mean - trt_mean) / baseline_mean |
| 154 | + |
| 155 | + print( |
| 156 | + f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}" |
| 157 | + f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}" |
| 158 | + ) |
| 159 | + |
| 160 | + # Return results for CSV writing |
| 161 | + return { |
| 162 | + "num_seqs": num_seqs, |
| 163 | + "trt_mean": trt_mean, |
| 164 | + "trt_std": trt_std.item(), |
| 165 | + "baseline_mean": baseline_mean, |
| 166 | + "baseline_std": baseline_std.item(), |
| 167 | + "speedup_percent": speedup_percent, |
| 168 | + "q_dtype": str(dtype), |
| 169 | + "kv_cache_dtype": kv_cache_dtype, |
| 170 | + "page_size": page_size, |
| 171 | + "num_kv_heads": num_kv_heads, |
| 172 | + "head_dim": head_dim, |
| 173 | + "max_seq_len": max_seq_len, |
| 174 | + } |
| 175 | + |
| 176 | + |
| 177 | +def write_results_to_csv(results, filename=None): |
| 178 | + """Write benchmark results to CSV file.""" |
| 179 | + if filename is None: |
| 180 | + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") |
| 181 | + filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" |
| 182 | + |
| 183 | + fieldnames = [ |
| 184 | + "num_seqs", |
| 185 | + "trt_mean", |
| 186 | + "trt_std", |
| 187 | + "baseline_mean", |
| 188 | + "baseline_std", |
| 189 | + "speedup_percent", |
| 190 | + "q_dtype", |
| 191 | + "kv_cache_dtype", |
| 192 | + "page_size", |
| 193 | + "num_kv_heads", |
| 194 | + "head_dim", |
| 195 | + "max_seq_len", |
| 196 | + ] |
| 197 | + |
| 198 | + file_exists = os.path.exists(filename) |
| 199 | + |
| 200 | + with open(filename, "a", newline="") as csvfile: |
| 201 | + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) |
| 202 | + |
| 203 | + if not file_exists: |
| 204 | + writer.writeheader() |
| 205 | + |
| 206 | + for result in results: |
| 207 | + writer.writerow(result) |
| 208 | + |
| 209 | + print(f"Results written to {filename}") |
| 210 | + |
| 211 | + |
| 212 | +if __name__ == "__main__": |
| 213 | + num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] |
| 214 | + max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] |
| 215 | + all_results = [] |
| 216 | + |
| 217 | + print("Running benchmark for kv_cache_dtype: bfloat16") |
| 218 | + print( |
| 219 | + "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent" |
| 220 | + ) |
| 221 | + for max_seq_len in max_seq_lens: |
| 222 | + for bs in num_seqs: |
| 223 | + result = benchmark_decode( |
| 224 | + bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="auto" |
| 225 | + ) |
| 226 | + all_results.append(result) |
| 227 | + |
| 228 | + print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8") |
| 229 | + print( |
| 230 | + "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent" |
| 231 | + ) |
| 232 | + for max_seq_len in max_seq_lens: |
| 233 | + for bs in num_seqs: |
| 234 | + result = benchmark_decode( |
| 235 | + bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8" |
| 236 | + ) |
| 237 | + all_results.append(result) |
| 238 | + |
| 239 | + # Write all results to CSV |
| 240 | + write_results_to_csv(all_results) |
0 commit comments