Commit 152c23c

mx cast performance bench (#1907)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent 151bd03 commit 152c23c

File tree

1 file changed: +155 -0 lines changed


benchmarks/mx_formats/cast_bench.py

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
# Microbenchmark for MX casts: reports kernel time and achieved memory bandwidth
# for reference dim0/dim1 blockwise scaling and the to_mx fp8 cast.
from typing import Callable, Tuple

import fire
import torch
import triton
from torch._inductor.utils import do_bench_using_profiling

from torchao.prototype.mx_formats.mx_tensor import to_mx

torch.manual_seed(0)

bytes_per_el_bf16 = 2
bytes_per_el_fp8 = 1


def scale_dim0_reference(x_hp, block_size) -> Tuple[torch.Tensor, torch.Tensor]:
    # Blockwise amax scaling along dim0: normalize each contiguous block of
    # `block_size` elements by its absolute maximum.
    assert x_hp.is_contiguous()
    x_hp_d0_block = x_hp.reshape(-1, block_size)
    x_hp_d0_block_abs = x_hp_d0_block.abs()
    amax_dim0 = torch.amax(x_hp_d0_block_abs, dim=1).unsqueeze(1)
    x_hp_d0_block_normalized = x_hp_d0_block / amax_dim0
    x_hp_d0_normalized = x_hp_d0_block_normalized.reshape(x_hp.shape)
    return x_hp_d0_normalized, amax_dim0


def scale_dim1_reference(x_hp, block_size) -> Tuple[torch.Tensor, torch.Tensor]:
    # Same scaling, but along dim1 (via a transpose + contiguous copy).
    assert x_hp.is_contiguous()
    x_hp_d1 = x_hp.t().contiguous()
    x_hp_d1_block = x_hp_d1.reshape(-1, block_size)
    x_hp_d1_block_abs = x_hp_d1_block.abs()
    amax_dim1 = torch.amax(x_hp_d1_block_abs, dim=1).unsqueeze(1)
    x_hp_d1_block_normalized = x_hp_d1_block / amax_dim1
    x_hp_d1_normalized = x_hp_d1_block_normalized.reshape(x_hp_d1.shape)
    return x_hp_d1_normalized, amax_dim1


def scale_dim0_dim1_reference(
    x_hp: torch.Tensor, block_size
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # normalize across dim0
    x_hp_d0_normalized, amax_dim0 = scale_dim0_reference(x_hp, block_size)
    # normalize across dim1
    x_hp_d1_normalized, amax_dim1 = scale_dim1_reference(x_hp, block_size)
    return x_hp_d0_normalized, x_hp_d1_normalized.t(), amax_dim0, amax_dim1


def to_mx_dim0_reference(x_hp, block_size):
    scale_d0, data_d0 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
    return data_d0, scale_d0


def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling; returns time in microseconds."""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3


def run(
    M: int = 16384,
    K: int = 16384,
    BLOCK_SIZE: int = 32,
    mode: str = "dim0",
):
    print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"torch version: {torch.__version__}")
    print(f"triton version: {triton.__version__}")
    print(f"mode: {mode}")
    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx")

    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

    if mode == "dim0":
        scale_dim0_reference_c = torch.compile(scale_dim0_reference)
        y_d0, s_d0 = scale_dim0_reference_c(x, BLOCK_SIZE)

        # warmup, then benchmark
        for _ in range(2):
            __ = scale_dim0_reference_c(x, BLOCK_SIZE)
        time_us = benchmark_cuda_function_in_microseconds(
            lambda x, b: scale_dim0_reference_c(x, BLOCK_SIZE),
            x,
            BLOCK_SIZE,
        )

        assert y_d0.dtype == torch.bfloat16
        assert s_d0.dtype == torch.bfloat16
        # input, output, and scales are all bf16
        bytes_rw = sum(t.numel() for t in [x, y_d0, s_d0]) * bytes_per_el_bf16
        bps = bytes_rw / (time_us / 1e6)

    elif mode == "dim1":
        scale_dim1_reference_c = torch.compile(scale_dim1_reference)
        y_d1, s_d1 = scale_dim1_reference_c(x, BLOCK_SIZE)

        # warmup, then benchmark
        for _ in range(2):
            __ = scale_dim1_reference_c(x, BLOCK_SIZE)
        time_us = benchmark_cuda_function_in_microseconds(
            lambda x, b: scale_dim1_reference_c(x, BLOCK_SIZE),
            x,
            BLOCK_SIZE,
        )

        assert y_d1.dtype == torch.bfloat16
        assert s_d1.dtype == torch.bfloat16
        bytes_rw = sum(t.numel() for t in [x, y_d1, s_d1]) * bytes_per_el_bf16
        bps = bytes_rw / (time_us / 1e6)

    elif mode == "dim0_dim1":
        scale_dim0_dim1_reference_c = torch.compile(scale_dim0_dim1_reference)
        y_d0, y_d1, s_d0, s_d1 = scale_dim0_dim1_reference_c(x, BLOCK_SIZE)

        # warmup, then benchmark
        for _ in range(2):
            __ = scale_dim0_dim1_reference_c(x, BLOCK_SIZE)
        time_us = benchmark_cuda_function_in_microseconds(
            lambda x, b: scale_dim0_dim1_reference_c(x, BLOCK_SIZE),
            x,
            BLOCK_SIZE,
        )

        assert y_d0.dtype == torch.bfloat16
        assert s_d0.dtype == torch.bfloat16
        assert y_d1.dtype == torch.bfloat16
        assert s_d1.dtype == torch.bfloat16
        bytes_rw = (
            sum(t.numel() for t in [x, y_d0, y_d1, s_d0, s_d1]) * bytes_per_el_bf16
        )
        bps = bytes_rw / (time_us / 1e6)

    elif mode == "dim0_mx":
        to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
        y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE)

        # warmup, then benchmark
        for _ in range(2):
            __ = to_mx_dim0_reference_c(x, BLOCK_SIZE)
        time_us = benchmark_cuda_function_in_microseconds(
            lambda x, b: to_mx_dim0_reference_c(x, BLOCK_SIZE),
            x,
            BLOCK_SIZE,
        )

        assert y_d0.dtype == torch.float8_e4m3fn
        assert s_d0.dtype == torch.uint8
        # read the bf16 input; write fp8 data plus one uint8 scale per block
        bytes_r = x.numel() * bytes_per_el_bf16
        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
        bps = (bytes_r + bytes_w) / (time_us / 1e6)

    else:
        raise AssertionError(f"unknown mode {mode}")

    print("time_us", time_us)
    print("mem_bw_gbps", bps / 1e9)


if __name__ == "__main__":
    fire.Fire(run)
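Since the script hands run() to fire.Fire, each keyword argument of run() becomes a command-line flag. A minimal usage sketch follows; the invocation path and the programmatic import are illustrative assumptions, not part of the commit.

# Usage sketch (assumed paths; requires a CUDA GPU). python-fire maps run()'s
# keyword arguments to CLI flags, e.g.:
#   python benchmarks/mx_formats/cast_bench.py --M 16384 --K 16384 --BLOCK_SIZE 32 --mode dim0_mx
# The same sweep can also be driven programmatically:
from cast_bench import run  # hypothetical import; assumes the file is on sys.path

for mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx"):
    run(M=16384, K=16384, BLOCK_SIZE=32, mode=mode)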

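For context on the reported number, mem_bw_gbps comes from a simple traffic model; in the dim0_mx branch it counts one bf16 read of the input plus one write of the fp8 data and one uint8 scale per block. A worked sketch of that arithmetic at the default sizes, mirroring the formulas in the script (illustrative only):

# Traffic model for mode="dim0_mx" at the defaults M = K = 16384, BLOCK_SIZE = 32.
M = K = 16384
BLOCK_SIZE = 32
bytes_r = M * K * 2                              # bf16 input: 2 bytes per element
bytes_w = M * K * 1 + (M * K // BLOCK_SIZE) * 1  # fp8 data + one uint8 scale per block
total_bytes = bytes_r + bytes_w                  # ~0.81 GB moved per cast
# mem_bw_gbps = total_bytes / (time_us / 1e6) / 1e9 once time_us is measured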