Commit ee6ce03

add more generic kernel for fp8 blockwise scaling
stack-info: PR: #2592, branch: danielvegamyhre/stack/15
1 parent 0e00df3 commit ee6ce03

File tree: 8 files changed (+568, -76 lines)

benchmarks/benchmark_blockwise_scaled_linear_triton.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 from triton.testing import do_bench
 
 from torchao.float8.float8_utils import compute_error
-from torchao.prototype.blockwise_fp8.blockwise_quantization import (
+from torchao.prototype.blockwise_fp8.kernels import (
     blockwise_fp8_gemm,
     fp8_blockwise_act_quant,
     fp8_blockwise_weight_quant,
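
For context, a minimal usage sketch of the renamed `torchao.prototype.blockwise_fp8.kernels` module; the shapes and block size are illustrative, and the quantizer signatures mirror how the new benchmark script below calls them (the gemm entry point is imported above but not exercised here).

```python
# Illustrative sketch only: quantize a bf16 activation and weight with the
# blockwise fp8 quantizers exposed by the new kernels module.
import torch

from torchao.prototype.blockwise_fp8.kernels import (
    fp8_blockwise_act_quant,     # (1 x block_size) groups per row
    fp8_blockwise_weight_quant,  # (block_size x block_size) tiles
)

A = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16)
W = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)

# Each call returns the fp8 tensor plus its per-block scales.
A_fp8, A_scales = fp8_blockwise_act_quant(A, block_size=128)
W_fp8, W_scales = fp8_blockwise_weight_quant(W, block_size=128)
```
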
Lines changed: 169 additions & 0 deletions (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
import argparse
import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm
from utils import benchmark_microseconds

from torchao.prototype.blockwise_fp8.kernels import (
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
    torch_blockwise_scale_act_quant,
    torch_blockwise_scale_weight_quant,
    triton_quantize_fp8_block,
)

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    A_shape: tuple[int]
    block_m: int
    block_k: int


@dataclass(frozen=True)
class ExperimentResult:
    torch_us: float
    fbgemm_us: float
    deepgemm_us: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    A_shapes = [
        (1024, 1024),
        (2048, 2048),
        (4096, 4096),
        (8192, 8192),
        (16384, 16384),
        (32768, 32768),
    ]
    block_m_opts = [1, 128]
    block_k_opts = [
        128,
    ]
    configs = []
    for A_shape, block_m, block_k in itertools.product(
        A_shapes,
        block_m_opts,
        block_k_opts,
    ):
        configs.append(
            ExperimentConfig(
                A_shape=A_shape,
                block_m=block_m,
                block_k=block_k,
            )
        )
    return configs


def run_experiment(
    config: ExperimentConfig, args: argparse.Namespace
) -> ExperimentResult:
    A = torch.randn(
        *config.A_shape,
        dtype=torch.bfloat16,
        device=device,
    )

    # Torch and DeepGEMM implementations are specific to activation quantization (1 x block_size)
    # and weight quantization (block_size x block_size)
    if config.block_m == 1:
        torch_func = torch.compile(torch_blockwise_scale_act_quant)
        deepgemm_func = fp8_blockwise_act_quant
    else:
        torch_func = torch.compile(torch_blockwise_scale_weight_quant)
        deepgemm_func = fp8_blockwise_weight_quant

    # Validate output shapes and strides
    torch_out, torch_scale = torch_func(A, tile_size=config.block_k)
    deepgemm_out, deepgemm_scale = deepgemm_func(A, block_size=config.block_k)
    fbgemm_out, fbgemm_scale = triton_quantize_fp8_block(
        A, block_m=config.block_m, block_k=config.block_k, k_major=True
    )
    assert torch_out.shape == deepgemm_out.shape == fbgemm_out.shape
    assert torch_out.stride() == deepgemm_out.stride() == fbgemm_out.stride()
    assert torch_scale.shape == deepgemm_scale.shape == fbgemm_scale.shape
    assert torch_scale.stride() == deepgemm_scale.stride() == fbgemm_scale.stride()

    # Do benchmarking
    torch_us = benchmark_microseconds(torch_func, A, tile_size=config.block_k)
    # Benchmark the DeepGEMM-style kernel selected above (act vs. weight quant),
    # rather than always timing the activation quantizer.
    deepgemm_us = benchmark_microseconds(
        deepgemm_func, A, block_size=config.block_k
    )
    fbgemm_us = benchmark_microseconds(
        triton_quantize_fp8_block,
        A,
        block_m=config.block_m,
        block_k=config.block_k,
        k_major=True,
    )

    return ExperimentResult(
        torch_us=round(torch_us, 3),
        fbgemm_us=round(fbgemm_us, 3),
        deepgemm_us=round(deepgemm_us, 3),
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "A_shape",
        "block_shape",
        "torch_us",
        "fbgemm_us",
        "deepgemm_us",
    ]
    rows = []
    for experiment in experiments:
        A_shape = f"({experiment.config.A_shape[0]}, {experiment.config.A_shape[1]})"
        block_shape = f"({experiment.config.block_m},{experiment.config.block_k})"
        rows.append(
            [
                A_shape,
                block_shape,
                experiment.result.torch_us,
                experiment.result.fbgemm_us,
                experiment.result.deepgemm_us,
            ]
        )
    print(tabulate(rows, headers=headers))


def main(args: argparse.Namespace):
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config, args)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--compile", action="store_true")
    args = arg_parser.parse_args()
    main(args)
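
The torch, fbgemm, and DeepGEMM implementations benchmarked above all compute the same basic transform. As a point of reference, here is a hedged pure-PyTorch sketch of the (1 x block_k) activation case: per-row groups of block_k contiguous elements, one scale per group. The helper name is hypothetical, and the scale dtype, rounding, and output layout of the actual kernels may differ.

```python
# Hypothetical reference (not one of the benchmarked kernels): quantize each
# row of `x` in contiguous groups of `block_k` elements to float8_e4m3fn,
# producing one dequantization scale per (row, group).
import torch

def blockwise_act_quant_reference(x: torch.Tensor, block_k: int = 128):
    M, K = x.shape
    assert K % block_k == 0, "sketch assumes K is a multiple of block_k"
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    x_blocks = x.reshape(M, K // block_k, block_k).float()
    amax = x_blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / fp8_max  # one scale per (row, block)
    x_fp8 = (x_blocks / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_fp8.reshape(M, K), scale.squeeze(-1)
```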

benchmarks/float8/utils.py

Lines changed: 10 additions & 0 deletions

@@ -11,6 +11,7 @@
 
 import torch.utils.benchmark as benchmark
 from torch.profiler import ProfilerActivity, profile
+from triton.testing import do_bench
 
 
 def profiler_output_to_filtered_time_by_kernel_name(
@@ -428,3 +429,12 @@ def do_benchmarks(
     tops_sec = float(tops) / time_sec
     pct_top_peak = tops_sec / peak_tops
     return time_sec, tops_sec, pct_top_peak
+
+
+def benchmark_microseconds(f, *args, warmup=25, rep=100, **kwargs):
+    return (
+        do_bench(
+            lambda: f(*args, **kwargs), warmup=warmup, rep=rep, return_mode="median"
+        )
+        * 1e3
+    )
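
A hypothetical usage example of the new helper (the matmul and shapes are illustrative, not part of this commit): it returns the median runtime of `f(*args, **kwargs)` in microseconds, since triton's `do_bench` reports milliseconds.

```python
# Illustrative only: time a bf16 matmul with the helper added above.
import torch

from benchmarks.float8.utils import benchmark_microseconds  # import path is an assumption

a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

mm_us = benchmark_microseconds(torch.matmul, a, b)
print(f"median bf16 matmul time: {mm_us:.1f} us")
```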

test/prototype/test_blockwise_triton.py

Lines changed: 0 additions & 72 deletions
This file was deleted.
