Commit c788ee7

[1/x] mx roofline: make the script work on NVIDIA B200 (#1778)
Update [ghstack-poisoned]
1 parent 1ab1b77 commit c788ee7

File tree

3 files changed: +66 -26 lines changed

benchmarks/float8/float8_roofline.py

Lines changed: 21 additions & 10 deletions

@@ -65,6 +65,7 @@
     get_float8_mem_sympy,
     get_gemm_time_sympy,
 )
+from torchao.utils import is_sm_at_least_90, is_sm_at_least_100


 class LNLinearSigmoid(torch.nn.Module):
@@ -154,10 +155,13 @@ def do_matmul(A, B):

     f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)

-    scale_a = torch.ones(M, 1, device=device)
-    scale_b = torch.ones(1, N, device=device)
-    fast_accum = True  # for axiswise
-    f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+    if is_sm_at_least_90() and (not is_sm_at_least_100()):
+        scale_a = torch.ones(M, 1, device=device)
+        scale_b = torch.ones(1, N, device=device)
+        fast_accum = True  # for axiswise
+        f8_axs_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B)
+    else:
+        f8_axs_time_s = -1.0

     # save to cache if needed
     if cache_filename is not None:
@@ -298,17 +302,24 @@ def run(
     bf16_time_actual_s = get_gpu_kernel_time(m_bf16, x)

     # get the float8 dynamic scaling gpu kernel time
+
     torch._dynamo.reset()
     m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig))
     m_fp8_dyn = torch.compile(m_fp8_dyn)
     fp8_dyn_time_actual_s = get_gpu_kernel_time(m_fp8_dyn, x)

-    # get the float8 dynamic axiswise scaling gpu kernel time
-    torch._dynamo.reset()
-    config = Float8LinearConfig.from_recipe_name("rowwise")
-    m_fp8_dyn_axs = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
-    m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
-    fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
+    # get the float8 dynamic axiswise scaling gpu kernel time, if supported
+    # on current hardware
+    if is_sm_at_least_90() and (not is_sm_at_least_100()):
+        torch._dynamo.reset()
+        config = Float8LinearConfig.from_recipe_name("rowwise")
+        m_fp8_dyn_axs = convert_to_float8_training(
+            copy.deepcopy(m_orig), config=config
+        )
+        m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
+        fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)
+    else:
+        fp8_dyn_axs_time_actual_s = -1.0

     # get the lw recipe scaling gpu kernel time
     # TODO(future PR): enable below once basic performance issues
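
The gating above uses is_sm_at_least_90 and is_sm_at_least_100 from torchao.utils. As a hedged illustration only (not torchao's exact implementation), such checks typically compare torch.cuda.get_device_capability() against a minimum compute capability; the sm_at_least helper below is hypothetical:

# Hypothetical sketch of an SM-version gate; torchao.utils ships the real
# is_sm_at_least_90 / is_sm_at_least_100 helpers imported in the diff above.
import torch

def sm_at_least(major: int, minor: int = 0) -> bool:
    # compute capability is (9, 0) on H100 and (10, 0) on B200
    return (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (major, minor)
    )

# Mirror the script's logic: run the axiswise float8 benchmark on SM90-class
# GPUs only, and record -1.0 where that path is not yet supported (SM100+).
if sm_at_least(9) and not sm_at_least(10):
    print("benchmarking axiswise float8")
else:
    print("skipping axiswise float8, recording -1.0")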

benchmarks/float8/utils.py

Lines changed: 2 additions & 0 deletions

@@ -81,6 +81,8 @@ def profiler_output_to_filtered_time_by_kernel_name(
             continue
         elif e.key == "cudaDeviceSynchronize":
             continue
+        elif e.key == "Activity Buffer Request":
+            continue

         kernel_name_to_gpu_time_us[e.key] = e.self_device_time_total
     return kernel_name_to_gpu_time_us
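
For context, the helper being patched walks profiler events, skips bookkeeping entries by name ("Activity Buffer Request" is one more such entry), and keeps per-kernel GPU time. Below is a hedged sketch of the same filtering pattern, not the benchmark's actual code; kernel_times_us and SKIP_KEYS are made-up names:

# Illustrative only: filter profiler events by key and sum self GPU time,
# in the spirit of profiler_output_to_filtered_time_by_kernel_name.
import torch
from torch.profiler import ProfilerActivity, profile

SKIP_KEYS = {"cudaDeviceSynchronize", "Activity Buffer Request"}

def kernel_times_us(fn):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        fn()
    times = {}
    for e in prof.key_averages():
        if e.key in SKIP_KEYS:
            continue  # not a compute kernel, ignore it
        times[e.key] = e.self_device_time_total
    return times

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")
    print(kernel_times_us(lambda: x @ x))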

torchao/testing/float8/roofline_utils.py

Lines changed: 43 additions & 16 deletions

@@ -9,19 +9,43 @@
 BYTES_PER_EL_FLOAT8 = 1
 BYTES_PER_EL_BF16 = 2

-# https://www.nvidia.com/en-us/data-center/h100/, divide by 2 because no sparsity
-H100_BF16_PEAK_TOPS = 989e12
-H100_FP8_PEAK_TOPS = 1979e12
+gpu_name_to_specs = {
+    "NVIDIA H100": {
+        # https://www.nvidia.com/en-us/data-center/h100/, divide by 2 because no sparsity
+        "bf16_peak_tops": 989e12,
+        "fp8_peak_tops": 1979e12,
+        # 2.4 TB per second, custom to Meta's H100 variant
+        "peak_mem_bw_bytes_sec": 2.4e12,
+        # based on quick experimental observation with sample large inputs
+        "pct_achievable_gemm_tops": 0.6,
+        # based on previous experience looking at pointwise triton kernels with large inputs,
+        # which would hit about 2.2k GBPS on Meta's H100 variant
+        "pct_achievable_mem_bw": 0.92,
+    },
+    "NVIDIA B200": {
+        # https://resources.nvidia.com/en-us-blackwell-architecture, page 19,
+        # divide by 2 because no sparsity
+        "bf16_peak_tops": 2.25e15,
+        "fp8_peak_tops": 4.5e15,
+        "fp4_peak_tops": 9.0e15,
+        # https://resources.nvidia.com/en-us-blackwell-architecture, page 20
+        # 8.0 TB per second
+        "peak_mem_bw_bytes_sec": 8.0e12,
+        # for now, copy over from H100
+        # TODO(future): measure once we have the hardware
+        "pct_achievable_gemm_tops": 0.6,
+        # for now, copy over from H100
+        # TODO(future): measure once we have the hardware
+        "pct_achievable_mem_bw": 0.92,
+    },
+    # TODO(future): more GPU names
+}
+
+
+def get_specs():
+    gpu_name = torch.cuda.get_device_name(0)
+    return gpu_name_to_specs[gpu_name]

-# 2.4 TB per second, custom to Meta's H100 variant
-H100_PEAK_MEM_BW_BYTES_SEC = 2.4e12
-
-# based on quick experimental observation with sample large inputs
-H100_PCT_ACHIEVABLE_GEMM_TOPS = 0.6
-
-# based on previous experience looking at pointwise triton kernels with large inputs,
-# which would hit about 2.2k GBPS on Meta's H100 variant
-H100_PCT_ACHIEVABLE_MEM_BW = 0.92

 # Source: run a triton kernel with a single element read/write on an H100 and
 # measure GPU time from the trace
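
With the specs keyed by torch.cuda.get_device_name(0), the rest of the roofline model stays device-agnostic: callers ask get_specs() for peak throughput and bandwidth instead of hardcoding H100 constants. A small hedged usage sketch (names come from the diff; only the printing is new):

# Derive achievable peaks from the specs table introduced above.
from torchao.testing.float8.roofline_utils import gpu_name_to_specs

specs = gpu_name_to_specs["NVIDIA B200"]  # get_specs() picks this entry on a real B200

achievable_fp8_tops = specs["fp8_peak_tops"] * specs["pct_achievable_gemm_tops"]
achievable_mem_bw = specs["peak_mem_bw_bytes_sec"] * specs["pct_achievable_mem_bw"]

print(f"achievable fp8 throughput: {achievable_fp8_tops / 1e12:.0f} TFLOPS")  # 2700
print(f"achievable memory bandwidth: {achievable_mem_bw / 1e12:.2f} TB/s")  # 7.36
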
@@ -65,12 +89,13 @@ def get_tensor_memory_traffic_bytes(


 def get_gemm_time_sympy(M, K, N, dtype):
+    specs = get_specs()
     gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N
     if dtype is torch.bfloat16:
-        peak_tops = H100_BF16_PEAK_TOPS
+        peak_tops = specs["bf16_peak_tops"]
     elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        peak_tops = H100_FP8_PEAK_TOPS
-    gemm_time_s = gemm_ops / peak_tops / H100_PCT_ACHIEVABLE_GEMM_TOPS
+        peak_tops = specs["fp8_peak_tops"]
+    gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"]
     return gemm_time_s
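
To make the formula concrete: the patched get_gemm_time_sympy divides total gemm FLOPs (three gemms of 2*M*K*N each for a linear layer's fwd/bwd, per the comment in get_float8_mem_sympy) by the achievable fraction of peak. A hedged numeric example that mirrors that arithmetic with plain floats rather than sympy expressions:

# Roofline fp8 gemm time for M = K = N = 4096 using the B200 entry above.
from torchao.testing.float8.roofline_utils import gpu_name_to_specs

M = K = N = 4096
specs = gpu_name_to_specs["NVIDIA B200"]

gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N  # same expression as the diff
gemm_time_s = gemm_ops / specs["fp8_peak_tops"] / specs["pct_achievable_gemm_tops"]
print(f"{gemm_time_s * 1e6:.0f} us")  # roughly 153 us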

@@ -87,6 +112,8 @@ def get_float8_mem_sympy(
     assert scaling_type_weight in ("dynamic",), "unsupported"
     assert scaling_type_grad_output in ("dynamic",), "unsupported"

+    specs = get_specs()
+
     # there are three gemms in the fwd/bwd of a linear:
     #
     # input @ weight_t = output
@@ -148,7 +175,7 @@ def get_float8_mem_sympy(
     )
     fp8_total_mem = fwd_fp8_total_mem + bwd_fp8_total_mem
     fp8_mem_time_s = (
-        fp8_total_mem / H100_PEAK_MEM_BW_BYTES_SEC / H100_PCT_ACHIEVABLE_MEM_BW
+        fp8_total_mem / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
     )

     # Adjust final estimate for small kernel launches
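
The memory-bound side is analogous: total float8 bytes moved divided by achievable bandwidth. A hedged sketch with a placeholder traffic number (the real fp8_total_mem is built up symbolically earlier in get_float8_mem_sympy):

# Roofline time to move 1 GiB of float8 traffic using the B200 entry above.
from torchao.testing.float8.roofline_utils import gpu_name_to_specs

specs = gpu_name_to_specs["NVIDIA B200"]
fp8_total_mem = 2**30  # 1 GiB placeholder, not a number from the benchmark

fp8_mem_time_s = (
    fp8_total_mem / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"]
)
print(f"{fp8_mem_time_s * 1e6:.0f} us")  # about 146 us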

0 commit comments
