
Commit c044ddb

[moe training] add benchmarking script for grouped mm (#2490)
1 parent 1fd34e4 commit c044ddb
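
Usage note: the script exposes run() through fire.Fire, so each keyword argument of run() doubles as a command-line flag. Two illustrative invocations (the CSV filename is arbitrary; flag names follow the function signature in the file below):

    python benchmarks/float8/bench_grouped_mm.py --shape_gen_name llama4_17bx16e --out_filename grouped_mm.csv
    python benchmarks/float8/bench_grouped_mm.py --shape_gen_name custom --M 8192 --K 5120 --N 8192 --E 16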

File tree

3 files changed: +273 -34 lines changed

benchmarks/float8/bench_grouped_mm.py

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import random
from typing import Optional

import fire
import pandas as pd
import torch
from utils import do_benchmarks, get_name_to_moe_shapes_iter

from torchao.testing.training.roofline_utils import get_specs


@torch.inference_mode()
def run(
    n_limit: Optional[int] = None,
    out_filename: Optional[str] = None,
    M: Optional[int] = None,
    K: Optional[int] = None,
    N: Optional[int] = None,
    E: Optional[int] = None,  # dim 0 of B tensor (num experts)
    use_gpu_kernel_time: bool = True,
    shape_gen_name: str = "llama4_17bx16e",
    recipe: str = "rowwise",
):
    device = "cuda"

    assert recipe in ("rowwise",), "unsupported"

    specs = get_specs()
    bf16_peak_tops = specs["bf16_peak_tops"]
    fp8_peak_tops = specs["fp8_peak_tops"]
    print(f"gpu_name: {torch.cuda.get_device_name(0)}")
    print(f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}")
    # Columns: bf16 baseline time, fp8 grouped-mm time, and their ratio.
    headers = (
        "name",
        "recipe",
        "M",
        "K",
        "N",
        "E",
        "bf16_time_s",
        "fp8_time_s",
        "fp8_speedup",
    )
    results = []

    dtype = torch.bfloat16
    name_to_shapes = get_name_to_moe_shapes_iter(shape_gen_name, M, K, N, E)

    for idx, (name, (M, K, N, E)) in enumerate(name_to_shapes):
        if n_limit is not None and idx >= n_limit:
            break
        assert M % E == 0, (
            "tokens (M) must be evenly divisible by num experts (E) for this benchmark"
        )
        # Total FLOPs: the M tokens are partitioned across the E experts by
        # offs, so the op count is 2*M*K*N and does not scale with E.
        tops = 2 * M * N * K
        print("M, K, N, E:", M, K, N, E, f"tops: {tops:.2E}")

        # Run bf16 torch._grouped_mm baseline.
        A = torch.randn(M, K, device=device, dtype=dtype)
        B = torch.randn(E, K, N, device=device, dtype=dtype)
        offs = generate_jagged_offs(E, M)
        print(f"offs: {offs}")
        ref_time_sec, ref_tops_sec, ref_pct_top_peak = do_benchmarks(
            tops,
            bf16_peak_tops,
            use_gpu_kernel_time,
            torch._grouped_mm,
            A,
            B,
            offs,
        )
        print(
            f"{dtype} time_sec {ref_time_sec:.2E}, tops/sec {ref_tops_sec:.2E}, pct_peak {ref_pct_top_peak:.3f}"
        )
        del A
        del B

        # Run scaled_grouped_mm.
        A_hp = torch.randn(M, K, device=device)
        B_hp_t = (
            torch.randn(E, K, N, device=device)
            .transpose(-2, -1)
            .contiguous()
            .transpose(-2, -1)
        )

        if recipe == "rowwise":
            # TODO: add e5m2
            A = A_hp.to(torch.float8_e4m3fn)
            B = B_hp_t.to(torch.float8_e4m3fn)
            peak_tops = fp8_peak_tops
            scale_a = torch.ones(M, device=device)
            scale_b = torch.ones(E, N, device=device)
        else:
            assert False, f"unknown recipe {recipe}"

        def do_scaled_grouped_mm(A, B):
            nonlocal scale_a
            nonlocal scale_b
            nonlocal offs
            return torch._scaled_grouped_mm(A, B, scale_a, scale_b, offs=offs)

        if recipe == "rowwise":
            do_matmul = do_scaled_grouped_mm
        else:
            raise ValueError(f"unknown recipe {recipe}")

        time_sec, tops_sec, pct_top_peak = do_benchmarks(
            tops, peak_tops, use_gpu_kernel_time, do_matmul, A, B
        )
        print(
            f"time_sec {time_sec:.2E}, tops/sec {tops_sec:.2E}, pct_peak {pct_top_peak:.3f}"
        )

        del A, B
        if scale_a is not None:
            del scale_a
        if scale_b is not None:
            del scale_b

        results.append(
            [
                name,
                recipe,
                M,
                K,
                N,
                E,
                ref_time_sec,
                time_sec,
                ref_time_sec / time_sec,
            ]
        )

    data_df = pd.DataFrame(results, columns=headers)
    print(data_df)

    if out_filename is not None:
        data_df.to_csv(out_filename)


def generate_jagged_offs(E, M, dtype=torch.int32, device="cuda"):
    """
    Generates a tensor of length E, containing random values divisible by 16,
    from 0 to M, in sorted order, where the final value in the tensor is always M.

    Args:
        E (int): The length of the tensor.
        M (int): The maximum value in the tensor.

    Returns:
        torch.Tensor: A tensor of length E with the specified properties.
    """
    # Ensure M is divisible by 16
    if M % 16 != 0:
        raise ValueError("M must be divisible by 16")

    # Generate a list of possible values
    possible_values = [i for i in range(0, M + 1, 16)]

    # If E is larger than the number of possible values, raise an error
    if E > len(possible_values):
        raise ValueError("E cannot be larger than the number of possible values")

    # Randomly select E - 1 values from the possible values (excluding M)
    selected_values = torch.tensor(random.sample(possible_values[:-1], E - 1))

    # Append M to the selected values
    selected_values = torch.cat((selected_values, torch.tensor([M])))

    # Sort the selected values
    selected_values, _ = torch.sort(selected_values)

    return selected_values.to(dtype).to(device)


def main() -> None:
    fire.Fire(run)


if __name__ == "__main__":
    main()  # pragma: no cover
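
A note on the offsets used above: generate_jagged_offs returns cumulative group boundaries, not per-group sizes, which is the layout torch._grouped_mm consumes via offs. A hypothetical call, to show the invariants (the actual values are random):

    offs = generate_jagged_offs(E=4, M=128)
    # One possible result: tensor([ 16,  48,  96, 128], device='cuda:0', dtype=torch.int32)
    # Invariants: sorted, every entry a multiple of 16, final entry equal to M.
    # Expert i then multiplies rows offs[i-1]:offs[i] of A against B[i] (offs[-1] == M).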

benchmarks/float8/bench_matmul.py

Lines changed: 12 additions & 33 deletions
@@ -10,9 +10,8 @@
 import pandas as pd
 import torch
 import torch.nn as nn
-import torch.utils.benchmark as benchmark
 from utils import (
-    get_gpu_kernel_gemm_time_s,
+    do_benchmarks,
     get_name_to_shapes_iter,
 )

@@ -21,36 +20,6 @@
 from torchao.testing.training.roofline_utils import get_specs


-def benchmark_fn_in_sec(f, *args, **kwargs):
-    # Manual warmup
-    for _ in range(4):
-        f(*args, **kwargs)
-    t0 = benchmark.Timer(
-        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
-    )
-    measurement = t0.blocked_autorange()
-    return measurement.mean
-
-
-def do_benchmarks(
-    tops,
-    peak_tops,
-    use_gpu_kernel_time,
-    f,
-    *args,
-    **kwargs,
-):
-    if use_gpu_kernel_time:
-        # just the gemm GPU kernel
-        time_sec = get_gpu_kernel_gemm_time_s(f, *args, **kwargs)
-    else:
-        # e2e time including kernel launch overhead
-        time_sec = benchmark_fn_in_sec(f, *args, **kwargs)
-    tops_sec = float(tops) / time_sec
-    pct_top_peak = tops_sec / peak_tops
-    return time_sec, tops_sec, pct_top_peak
-
-
 @torch.inference_mode()
 def run(
     n_limit: Optional[int] = None,
@@ -76,7 +45,7 @@ def run(
     specs = get_specs()
     bf16_peak_tops = specs["bf16_peak_tops"]
     fp8_peak_tops = specs["fp8_peak_tops"]
-    fp4_peak_tops = specs["fp4_peak_tops"]
+    fp4_peak_tops = specs.get("fp4_peak_tops", 0.0)  # only on sm120
     print(f"gpu_name: {torch.cuda.get_device_name(0)}")
     print(
         f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}, fp4 {fp4_peak_tops:.2e}"
@@ -175,6 +144,16 @@ def do_matmul_nvfp4(A, B):
         nonlocal scale_b
         return torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=dtype)

+    def do_grouped_mm(A, B):
+        return torch._grouped_mm(A, B, use_fast_accum=fast_accum)
+
+    def do_scaled_grouped_mm(A, B):
+        nonlocal scale_a
+        nonlocal scale_b
+        return torch._scaled_grouped_mm(
+            A, B, scale_a, scale_b, use_fast_accum=fast_accum
+        )
+
     if recipe == "mxfp4_cutlass":
         do_matmul = do_matmul_mxfp4
     elif recipe == "nvfp4":

benchmarks/float8/utils.py

Lines changed: 74 additions & 1 deletion
@@ -9,6 +9,7 @@
 import re
 from typing import Optional

+import torch.utils.benchmark as benchmark
 from torch.profiler import ProfilerActivity, profile


@@ -211,6 +212,42 @@ def get_name_to_shapes_iter(
     raise AssertionError(f"unknown shape_gen_name {shape_gen_name}")


+def get_name_to_moe_shapes_iter(
+    shape_gen_name: str,
+    M: Optional[int] = None,
+    K: Optional[int] = None,
+    N: Optional[int] = None,
+    E: Optional[int] = None,
+):
+    M = 8192 if M is None else M
+    if shape_gen_name == "llama4_17bx16e":
+        # num_experts=16, dim=5120
+        names_to_shapes = {
+            # M, K, N, E
+            "moe.experts.w1": (M, 5120, 8192, 16),
+            "moe.experts.w2": (M, 8192, 5120, 16),
+        }
+        return names_to_shapes.items()
+    elif shape_gen_name == "llama4_17bx128e":
+        # num_experts=128, dim=5120
+        names_to_shapes = {
+            # M, K, N, E
+            "moe.experts.w1": (M, 5120, 8192, 128),
+            "moe.experts.w2": (M, 8192, 5120, 128),
+        }
+        return names_to_shapes.items()
+    elif shape_gen_name == "custom":
+        assert M is not None and K is not None and N is not None and E is not None, (
+            "M, K, N, E must be specified for custom shape_gen"
+        )
+        name_to_shapes = {
+            1: (M, K, N, E),
+        }
+        return name_to_shapes.items()
+
+    raise AssertionError(f"unknown shape_gen_name {shape_gen_name}")
+
+
 # copy-pasta from https://github.com/vkuzo/pytorch_scripts/blob/main/add_inductor_metadata_to_perf_trace.py
 def update_triton_kernels_in_prof_chome_trace_with_torch_logs(
     perf_trace_file: str,
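
For reference, the iterator added above can be exercised on its own; with the default M of 8192, the llama4_17bx16e preset yields exactly the two (M, K, N, E) entries listed in the hunk:

    from utils import get_name_to_moe_shapes_iter

    for name, (M, K, N, E) in get_name_to_moe_shapes_iter("llama4_17bx16e"):
        print(name, M, K, N, E)
    # moe.experts.w1 8192 5120 8192 16
    # moe.experts.w2 8192 8192 5120 16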
@@ -353,5 +390,41 @@ def get_gpu_kernel_gemm_time_s(f, *args, **kwargs):
     # there is only 1 key, aten::mm or aten::_scaled_mm, with unit nanoseconds
     assert len(data) == 1
     key, value = next(iter(data.items()))
-    assert key in ("aten::mm", "aten::_scaled_mm", "torchao::mx_fp4_bf16")
+    assert key in (
+        "aten::mm",
+        "aten::_scaled_mm",
+        "torchao::mx_fp4_bf16",
+        "aten::_grouped_mm",
+        "aten::_scaled_grouped_mm",
+    )
     return value / 1e6 / n_iter
+
+
+def benchmark_fn_in_sec(f, *args, **kwargs):
+    # Manual warmup
+    for _ in range(4):
+        f(*args, **kwargs)
+    t0 = benchmark.Timer(
+        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+    )
+    measurement = t0.blocked_autorange()
+    return measurement.mean
+
+
+def do_benchmarks(
+    tops,
+    peak_tops,
+    use_gpu_kernel_time,
+    f,
+    *args,
+    **kwargs,
+):
+    if use_gpu_kernel_time:
+        # just the gemm GPU kernel
+        time_sec = get_gpu_kernel_gemm_time_s(f, *args, **kwargs)
+    else:
+        # e2e time including kernel launch overhead
+        time_sec = benchmark_fn_in_sec(f, *args, **kwargs)
+    tops_sec = float(tops) / time_sec
+    pct_top_peak = tops_sec / peak_tops
+    return time_sec, tops_sec, pct_top_peak
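
With this move, do_benchmarks becomes the shared entry point for both benchmark scripts. A minimal standalone sketch, assuming a CUDA device (torch.mm lands in the aten::mm bucket the profiler-based timing path asserts on):

    import torch
    from utils import do_benchmarks
    from torchao.testing.training.roofline_utils import get_specs

    M = K = N = 4096
    A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    flops = 2 * M * K * N  # one multiply-add per output element per K step
    time_s, flops_per_s, pct_peak = do_benchmarks(
        flops, get_specs()["bf16_peak_tops"], True, torch.mm, A, B
    )
    print(f"{time_s:.2e} s, {pct_peak:.1%} of bf16 peak")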
