# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools

import torch
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.triton_utils import triton

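# vLLM encodes device capability as major * 10 + minor, so 100 == SM 10.0.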
if not current_platform.has_device_capability(100):
    raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")


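# NVFP4 scaling is two-level: each 16-element block carries an FP8 (E4M3)
# scale, and each tensor carries one FP32 global scale. Mapping amax to
# FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX (448 * 6) keeps both levels in range.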
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "nvfp4": dict(no_a_quant=False, enabled=True),
    "nvfp4-noquant": dict(no_a_quant=True, enabled=True),
}
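# "nvfp4" quantizes the activation inside the timed region; "nvfp4-noquant"
# pre-quantizes it, so only the FP4 GEMM itself is measured.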

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def _quant_weight_nvfp4(b: torch.Tensor, device: str):
    # Compute global scale for weight
    b_amax = torch.abs(b).max().to(torch.float32)
    b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
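    # scaled_fp4_quant returns the packed FP4 weight (two values per byte)
    # and the per-16-element-block E4M3 scales.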
    b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
    return b_fp4, scale_b_fp4, b_global_scale


def build_nvfp4_runner(cfg, a, b, dtype, device):
    b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device)

    # Compute global scale for activation
    # NOTE: This is generally provided ahead-of-time by the model checkpoint.
    a_amax = torch.abs(a).max().to(torch.float32)
    a_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax

    # Alpha for the GEMM epilogue: the kernel multiplies scaled FP4 operands,
    # so 1 / (a_global_scale * b_global_scale) restores the output magnitude.
    alpha = 1.0 / (a_global_scale * b_global_scale)

    if cfg["no_a_quant"]:
        # Pre-quantize activation
        a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)

        def run():
            return ops.cutlass_scaled_fp4_mm(
                a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
            )

        return run

    # Quantize activation on-the-fly
    def run():
        a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
        return ops.cutlass_scaled_fp4_mm(
            a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
        )

    return run
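
# Illustrative sanity check (tolerances assumed, not run by this script):
# the NVFP4 path should track a BF16 matmul up to FP4 quantization error.
#   run = build_nvfp4_runner(PROVIDER_CFGS["nvfp4"], a, b, torch.bfloat16, "cuda")
#   torch.testing.assert_close(run(), a @ b.t(), rtol=0.25, atol=1.0)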


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs NVFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

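    # do_bench_cudagraph replays the callable inside a CUDA graph, keeping
    # kernel-launch overhead out of the timings.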
    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_nvfp4_runner(cfg, a, b, dtype, device)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), quantiles=quantiles
        )

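    # One GEMM is M * N * K multiply-accumulates = 2 * M * N * K FLOPs;
    # convert milliseconds to TFLOP/s.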
    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


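# WEIGHT_SHAPES maps each model to ([K, N], tp_dim) pairs; divide the
# tensor-parallel dimension by tp_size and tag each shape with its model,
# yielding [K, N, model] entries.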
def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


if __name__ == "__main__":
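    # Example invocation (script name assumed):
    #   python bench_nvfp4_gemm.py --models meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1 2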
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path=f"bench_nvfp4_res_n{N}_k{K}",
            N=N,
            K=K,
        )

    print("Benchmark finished!")