
Commit 36b8502

remove extra files
1 parent 3d2f0f0 commit 36b8502

6 files changed: +486 -2 lines changed

.gitignore

Lines changed: 1 addition & 2 deletions
@@ -146,6 +146,7 @@ venv.bak/
 
 # mkdocs documentation
 /site
+docs/argparse
 docs/examples
 
 # mypy
@@ -202,5 +203,3 @@ shellcheck*/
 
 # Ignore moe/marlin_moe gen code
 csrc/moe/marlin_moe_wna16/kernel_*
-local/
-*.patch

benchmarks/benchmark_dataset.py

Lines changed: 3 additions & 0 deletions
@@ -324,6 +324,9 @@ def sample(
         input_low = int(real_input_len * (1 - range_ratio))
         input_high = int(real_input_len * (1 + range_ratio))
         output_low = int(output_len * (1 - range_ratio))
+        # Ensure the lower bound for output length is at least 1 to prevent
+        # sampling 0 tokens, which can cause request failures.
+        output_low = max(output_low, 1)
         output_high = int(output_len * (1 + range_ratio))
 
         # Add logging for debugging
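
For context, a minimal standalone sketch of the bound computation this hunk fixes (not the benchmark's own sampler; output_len = 128 and range_ratio = 1.0 are hypothetical values chosen to make the zero-length case visible, and the final randint draw only illustrates how the bounds are typically consumed):

import numpy as np

output_len, range_ratio = 128, 1.0
output_low = int(output_len * (1 - range_ratio))    # 0 without the fix
output_low = max(output_low, 1)                     # the added clamp: never request 0 output tokens
output_high = int(output_len * (1 + range_ratio))   # 256

# e.g. one sampled output length per request, inclusive of both bounds
lengths = np.random.randint(output_low, output_high + 1, size=4)
print(output_low, output_high, lengths)
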
Lines changed: 141 additions & 0 deletions
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools

import torch
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.triton_utils import triton

if not current_platform.has_device_capability(100):
    raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")


FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

PROVIDER_CFGS = {
    "torch-bf16": dict(enabled=True),
    "nvfp4": dict(no_a_quant=False, enabled=True),
    "nvfp4-noquant": dict(no_a_quant=True, enabled=True),
}

_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def _quant_weight_nvfp4(b: torch.Tensor, device: str):
    # Compute global scale for weight
    b_amax = torch.abs(b).max().to(torch.float32)
    b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
    b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
    return b_fp4, scale_b_fp4, b_global_scale


def build_nvfp4_runner(cfg, a, b, dtype, device):
    b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device)

    # Compute global scale for activation
    # NOTE: This is generally provided ahead-of-time by the model checkpoint.
    a_amax = torch.abs(a).max().to(torch.float32)
    a_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax

    # Alpha for the GEMM operation
    alpha = 1.0 / (a_global_scale * b_global_scale)

    if cfg["no_a_quant"]:
        # Pre-quantize activation
        a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)

        def run():
            return ops.cutlass_scaled_fp4_mm(
                a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
            )

        return run

    # Quantize activation on-the-fly
    def run():
        a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
        return ops.cutlass_scaled_fp4_mm(
            a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
        )

    return run


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs NVFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch-bf16":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
        )
    else:
        cfg = PROVIDER_CFGS[provider]
        run_quant = build_nvfp4_runner(cfg, a, b, dtype, device)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: run_quant(), quantiles=quantiles
        )

    to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)


def prepare_shapes(args):
    out = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_dim] //= tp_size
            KN.append(model)
            out.append(KN)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path=f"bench_nvfp4_res_n{N}_k{K}",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
Lines changed: 98 additions & 0 deletions
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Callable

import torch

from vllm import _custom_ops as ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton


# TODO(luka): use standalone_compile utility
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
    def inner(*args):
        torch._dynamo.mark_dynamic(args[arg_index], dim_index)
        return fn(*args)

    return inner


torch._dynamo.config.recompile_limit = 8888
compilation_config = CompilationConfig(custom_ops=["none"])
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
    torch_per_token_quant_fp8 = torch.compile(
        QuantFP8(False, GroupShape.PER_TOKEN),
        fullgraph=True,
        dynamic=False,  # recompile for different shapes
    )

# First dim is explicitly dynamic to simulate vLLM usage
torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)


def cuda_per_token_quant_fp8(
    input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    return ops.scaled_fp8_quant(input)


def calculate_diff(batch_size: int, seq_len: int):
    """Calculate difference between Triton and CUDA implementations."""
    device = torch.device("cuda")
    x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)

    torch_out, torch_scale = torch_per_token_quant_fp8(x)
    cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)

    if torch.allclose(
        cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
    ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]

configs = list(itertools.product(batch_size_range, seq_len_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "cuda"],
        line_names=["Torch", "CUDA"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        args={},
    )
)
def benchmark_quantization(batch_size, seq_len, provider):
    dtype = torch.float16
    device = torch.device("cuda")

    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "torch":
        fn = lambda: torch_per_token_quant_fp8(x.clone())
    elif provider == "cuda":
        fn = lambda: cuda_per_token_quant_fp8(x.clone())

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    calculate_diff(batch_size=4, seq_len=4096)
    benchmark_quantization.run(print_data=True)
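
For readers unfamiliar with per-token dynamic FP8 quantization, the sketch below is an illustrative reference only (it is not vLLM's QuantFP8 or ops.scaled_fp8_quant, whose exact numerics may differ): each token, i.e. each row, gets its own scale so that the row's absolute maximum maps onto the FP8-E4M3 range.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def per_token_quant_ref(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # one scale per row; the clamp avoids division by zero on all-zero rows
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12).float()
    scale = amax / FP8_MAX
    x_q = (x.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_q, scale

x = torch.randn(4, 16, dtype=torch.float16)
x_q, scale = per_token_quant_ref(x)
print(x_q.shape, scale.shape)  # torch.Size([4, 16]) torch.Size([4, 1])
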

benchmarks/kernels/benchmark_moe.py

Lines changed: 3 additions & 0 deletions
@@ -86,6 +86,9 @@ def benchmark_config(
             (num_experts, 2 * shard_intermediate_size), dtype=torch.float32
         )
         w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
+    if use_deep_gemm:
+        # we use the default block shape for deepgemm
+        block_quant_shape = [128, 128]
     if use_fp8_w8a8:
         if block_quant_shape:
             block_n, block_k = block_quant_shape[0], block_quant_shape[1]
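
A rough illustration of why the [128, 128] default matters for the FP8 block-quant path right below it: with block-wise quantization there is one scale per (block_n x block_k) tile of the weight, so the block shape fixes how many scales are generated per expert. The shapes below are hypothetical:

import math

hidden_size, shard_intermediate_size = 4096, 14336  # hypothetical model dimensions
block_n, block_k = 128, 128  # the DeepGEMM default set by this change

n_tiles = math.ceil(2 * shard_intermediate_size / block_n)
k_tiles = math.ceil(hidden_size / block_k)
print(f"w1 scale tiles per expert: {n_tiles} x {k_tiles}")  # 224 x 32
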
