Skip to content

Commit 0bbac1c

Browse files
authored
[Bench] Add NVFP4 GEMM benchmark script (#20578)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent a3e4e85 commit 0bbac1c

File tree

1 file changed

+141
-0
lines changed

1 file changed

+141
-0
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import argparse
4+
import copy
5+
import itertools
6+
7+
import torch
8+
from weight_shapes import WEIGHT_SHAPES
9+
10+
from vllm import _custom_ops as ops
11+
from vllm.platforms import current_platform
12+
from vllm.scalar_type import scalar_types
13+
from vllm.triton_utils import triton
14+
15+
if not current_platform.has_device_capability(100):
16+
raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")
17+
18+
19+
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
20+
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
21+
22+
PROVIDER_CFGS = {
23+
"torch-bf16": dict(enabled=True),
24+
"nvfp4": dict(no_a_quant=False, enabled=True),
25+
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
26+
}
27+
28+
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
29+
30+
31+
def _quant_weight_nvfp4(b: torch.Tensor, device: str):
32+
# Compute global scale for weight
33+
b_amax = torch.abs(b).max().to(torch.float32)
34+
b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
35+
b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
36+
return b_fp4, scale_b_fp4, b_global_scale
37+
38+
39+
def build_nvfp4_runner(cfg, a, b, dtype, device):
40+
b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device)
41+
42+
# Compute global scale for activation
43+
# NOTE: This is generally provided ahead-of-time by the model checkpoint.
44+
a_amax = torch.abs(a).max().to(torch.float32)
45+
a_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
46+
47+
# Alpha for the GEMM operation
48+
alpha = 1.0 / (a_global_scale * b_global_scale)
49+
50+
if cfg["no_a_quant"]:
51+
# Pre-quantize activation
52+
a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
53+
54+
def run():
55+
return ops.cutlass_scaled_fp4_mm(
56+
a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
57+
)
58+
59+
return run
60+
61+
# Quantize activation on-the-fly
62+
def run():
63+
a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
64+
return ops.cutlass_scaled_fp4_mm(
65+
a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
66+
)
67+
68+
return run
69+
70+
71+
@triton.testing.perf_report(
72+
triton.testing.Benchmark(
73+
x_names=["batch_size"],
74+
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
75+
x_log=False,
76+
line_arg="provider",
77+
line_vals=_enabled,
78+
line_names=_enabled,
79+
ylabel="TFLOP/s (larger is better)",
80+
plot_name="BF16 vs NVFP4 GEMMs",
81+
args={},
82+
)
83+
)
84+
def benchmark(batch_size, provider, N, K):
85+
M = batch_size
86+
device = "cuda"
87+
dtype = torch.bfloat16
88+
89+
a = torch.randn((M, K), device=device, dtype=dtype)
90+
b = torch.randn((N, K), device=device, dtype=dtype)
91+
92+
quantiles = [0.5, 0.2, 0.8]
93+
94+
if provider == "torch-bf16":
95+
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
96+
lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
97+
)
98+
else:
99+
cfg = PROVIDER_CFGS[provider]
100+
run_quant = build_nvfp4_runner(cfg, a, b, dtype, device)
101+
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
102+
lambda: run_quant(), quantiles=quantiles
103+
)
104+
105+
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
106+
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
107+
108+
109+
def prepare_shapes(args):
110+
out = []
111+
for model, tp_size in itertools.product(args.models, args.tp_sizes):
112+
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
113+
KN[tp_dim] //= tp_size
114+
KN.append(model)
115+
out.append(KN)
116+
return out
117+
118+
119+
if __name__ == "__main__":
120+
parser = argparse.ArgumentParser()
121+
parser.add_argument(
122+
"--models",
123+
nargs="+",
124+
type=str,
125+
default=["meta-llama/Llama-3.1-8B-Instruct"],
126+
choices=list(WEIGHT_SHAPES.keys()),
127+
)
128+
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
129+
args = parser.parse_args()
130+
131+
for K, N, model in prepare_shapes(args):
132+
print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
133+
benchmark.run(
134+
print_data=True,
135+
show_plots=True,
136+
save_path=f"bench_nvfp4_res_n{N}_k{K}",
137+
N=N,
138+
K=K,
139+
)
140+
141+
print("Benchmark finished!")

0 commit comments

Comments
 (0)