
Commit e4827f2

Lint benchmark folder (#1519)
* Torchao folder linted
* Torchao/prototype folder linted
* Benchmarks folder linted
* Lint fixes
1 parent: e1d8899 · commit: e4827f2

26 files changed: 656 lines added, 429 lines removed

benchmarks/bench_galore_fused_kernels.py

Lines changed: 0 additions & 2 deletions
@@ -8,9 +8,7 @@
 def run(args):
     dtype = getattr(torch, args.dtype)
     allow_tf32 = args.allow_tf32
-    fp8_fast_accum = False
     torch.backends.cuda.matmul.allow_tf32 = allow_tf32
-    kernel = args.kernel
     M, N = args.M, args.N
     rank = args.rank
 

benchmarks/benchmark_aq.py

Lines changed: 83 additions & 35 deletions
@@ -1,23 +1,26 @@
-"""Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs
-"""
+"""Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs"""
+
+import copy
+
 import torch
+
+from torchao.quantization.quant_api import (
+    _replace_with_custom_fn_if_matches_filter,
+    int4_weight_only,
+    int8_dynamic_activation_int8_weight,
+    int8_weight_only,
+    quantize_,
+)
 from torchao.quantization.subclass import (
-    Int8WeightOnlyQuantizedLinearWeight,
     Int4WeightOnlyQuantizedLinearWeight,
+    Int8WeightOnlyQuantizedLinearWeight,
 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
+    unwrap_tensor_subclass,
 )
-from torchao.quantization.quant_api import (
-    int4_weight_only,
-    int8_weight_only,
-    int8_dynamic_activation_int8_weight,
-    quantize_,
-    _replace_with_custom_fn_if_matches_filter,
-)
-import copy
-from torchao.utils import unwrap_tensor_subclass
+
 
 def _int8wo_api(mod, **kwargs):
     if TORCH_VERSION_AT_LEAST_2_4:
@@ -27,14 +30,20 @@ def _int8wo_api(mod, **kwargs):
     else:
         change_linear_weights_to_int8_woqtensors(mod, **kwargs)
 
+
 def _int8da_int8w_api(mod, **kwargs):
     if TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(mod, int8_dynamic_activation_int8_weight(**kwargs), set_inductor_config=False)
+        quantize_(
+            mod,
+            int8_dynamic_activation_int8_weight(**kwargs),
+            set_inductor_config=False,
+        )
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int8_dqtensors(mod, **kwargs)
 
+
 def _int4wo_api(mod, **kwargs):
     if TORCH_VERSION_AT_LEAST_2_4:
         kwargs_copy = kwargs.copy()
@@ -47,31 +56,43 @@ def _int4wo_api(mod, **kwargs):
     else:
         change_linear_weights_to_int4_woqtensors(mod, **kwargs)
 
+
 class ToyLinearModel(torch.nn.Module):
-    """Single linear for m * k * n problem size
-    """
-    def __init__(self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"):
+    """Single linear for m * k * n problem size"""
+
+    def __init__(
+        self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"
+    ):
         super().__init__()
         self.m = m
         self.dtype = dtype
         self.device = device
-        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(dtype=self.dtype, device=self.device)
+        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(
+            dtype=self.dtype, device=self.device
+        )
 
     def example_inputs(self):
-        return (torch.randn(self.m, self.linear.in_features, dtype=self.dtype, device=self.device),)
+        return (
+            torch.randn(
+                self.m, self.linear.in_features, dtype=self.dtype, device=self.device
+            ),
+        )
 
     def forward(self, x):
         x = self.linear(x)
         return x
 
+
 def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
     """
     The deprecated implementation for int8 dynamic quant API, used as a reference for
     numerics and performance
     """
-    from torchao.quantization.quant_api import _in_features_greater_than_16
-    from torchao.quantization.quant_api import _is_linear
-    from torchao.quantization.quant_api import _get_subclass_inserter
+    from torchao.quantization.quant_api import (
+        _get_subclass_inserter,
+        _in_features_greater_than_16,
+        _is_linear,
+    )
     from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight
 
     if filter_fn is None:
@@ -80,40 +101,54 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs
         )
 
     _replace_with_custom_fn_if_matches_filter(
-        model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
+        model,
+        _get_subclass_inserter(
+            Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs
+        ),
+        filter_fn,
     )
 
+
 def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
     def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
         """
         The deprecated implementation for weight only quant API, used as a reference for
         numerics and performance
         """
-        from torchao.quantization.quant_api import _is_linear
-        from torchao.quantization.quant_api import _get_subclass_inserter
+        from torchao.quantization.quant_api import _get_subclass_inserter, _is_linear
 
         filter_fn = kwargs.pop("filter_fn", _is_linear)
 
         _replace_with_custom_fn_if_matches_filter(
             model,
-            _get_subclass_inserter(deprecated_tenosr_subclass, enable_parametrization=True, **kwargs),
+            _get_subclass_inserter(
+                deprecated_tenosr_subclass, enable_parametrization=True, **kwargs
+            ),
             filter_fn,
         )
 
     return _ref_change_linear_weights_to_woqtensors
 
-_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
-_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
+
+_ref_change_linear_weights_to_int8_woqtensors = (
+    _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
+)
+_ref_change_linear_weights_to_int4_woqtensors = (
+    _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
+)
 
 
 torch._dynamo.config.cache_size_limit = 50000
 
+
 @torch.no_grad
 def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     if kwargs is None:
         kwargs = {}
 
-    m = ToyLinearModel(M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda").eval()
+    m = ToyLinearModel(
+        M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda"
+    ).eval()
     m_bf16 = copy.deepcopy(m)
     m_ref = copy.deepcopy(m)
     example_inputs = m.example_inputs()
@@ -130,26 +165,30 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
 
     # perf comparison
     from torchao.utils import benchmark_model
+
     # warmup
     WARMUP = 20
     RUNS = 100
 
     torch._dynamo.reset()
-    m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
+    m_ref = torch.compile(m_ref, mode="max-autotune", fullgraph=True)
     benchmark_model(m_ref, WARMUP, example_inputs)
     ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)
 
     torch._dynamo.reset()
-    m = torch.compile(m, mode='max-autotune', fullgraph=True)
+    m = torch.compile(m, mode="max-autotune", fullgraph=True)
     benchmark_model(m, WARMUP, example_inputs)
     elapsed_time = benchmark_model(m, RUNS, example_inputs)
 
     torch._dynamo.reset()
-    m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
+    m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True)
     benchmark_model(m_bf16, WARMUP, example_inputs)
     bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
 
-    print(f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}")
+    print(
+        f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}"
+    )
+
 
 if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available():
     all_shapes = [
@@ -158,16 +197,25 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
 
     print("_int8da_int8w_api")
     from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
+
     for M, N, K in all_shapes:
-        _bench_quantized_tensor_subclass_perf(_int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K)
+        _bench_quantized_tensor_subclass_perf(
+            _int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K
+        )
 
     print("_int8wo_api")
    from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
+
     for M, N, K in all_shapes:
-        _bench_quantized_tensor_subclass_perf(_int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K)
+        _bench_quantized_tensor_subclass_perf(
+            _int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K
+        )
 
     print("_int4wo_api")
     kwargs = {"groupsize": 32}
     from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
+
     for M, N, K in all_shapes:
-        _bench_quantized_tensor_subclass_perf(_int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs)
+        _bench_quantized_tensor_subclass_perf(
+            _int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs
+        )
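
For readers skimming this diff, the following is a minimal sketch (not part of the commit) of the quantize_ API that benchmark_aq.py exercises. The layer sizes, the bfloat16 dtype, and the int8 weight-only config are illustrative assumptions; a CUDA device and a torchao build where TORCH_VERSION_AT_LEAST_2_4 holds are assumed.

import torch
from torchao.quantization.quant_api import int8_weight_only, quantize_

# A single linear layer standing in for ToyLinearModel above (shapes are illustrative).
model = torch.nn.Linear(64, 32).to(dtype=torch.bfloat16, device="cuda").eval()

# Replace the linear weight with an int8 weight-only quantized tensor subclass,
# similar to what the benchmark's _int8wo_api wrapper does on newer torch versions.
quantize_(model, int8_weight_only(), set_inductor_config=False)

x = torch.randn(16, 64, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y = model(x)  # forward pass now runs against the quantized weight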

benchmarks/benchmark_fp6.py

Lines changed: 28 additions & 11 deletions
@@ -1,17 +1,22 @@
-import torch
 import pandas as pd
+import torch
 import torch.nn.functional as F
+from tqdm import tqdm
+
 from torchao.dtypes import to_affine_quantized_fpx
-from torchao.dtypes.floatx import FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayout
+from torchao.dtypes.floatx import FloatxTensorCoreLayout
 from torchao.utils import benchmark_torch_function_in_microseconds
-from tqdm import tqdm
 
 
 def benchmark(m: int, k: int, n: int):
     float_data_fp16 = torch.randn(n, k, dtype=torch.float16, device="cuda")
     float_data_bf16 = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
-    fp6_weight_fp16 = to_affine_quantized_fpx(float_data_fp16, FloatxTensorCoreLayout(3, 2))
-    fp6_weight_bf16 = to_affine_quantized_fpx(float_data_bf16, FloatxTensorCoreLayout(3, 2))
+    fp6_weight_fp16 = to_affine_quantized_fpx(
+        float_data_fp16, FloatxTensorCoreLayout(3, 2)
+    )
+    fp6_weight_bf16 = to_affine_quantized_fpx(
+        float_data_bf16, FloatxTensorCoreLayout(3, 2)
+    )
     fp16_weight = fp6_weight_fp16.dequantize(torch.float16)
     bf16_weight = fp6_weight_bf16.dequantize(torch.bfloat16)
 
@@ -22,15 +27,27 @@ def benchmark(m: int, k: int, n: int):
     fp16_output = F.linear(fp16_act, fp16_weight)
     bf16_output = F.linear(bf16_act, bf16_weight)
 
-    fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight)
-    bf16_time = benchmark_torch_function_in_microseconds(F.linear, bf16_act, bf16_weight)
-    fp6_time_fp16 = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight_fp16)
-    fp6_time_bf16 = benchmark_torch_function_in_microseconds(F.linear, bf16_act, fp6_weight_bf16)
+    fp16_time = benchmark_torch_function_in_microseconds(
+        F.linear, fp16_act, fp16_weight
+    )
+    bf16_time = benchmark_torch_function_in_microseconds(
+        F.linear, bf16_act, bf16_weight
+    )
+    fp6_time_fp16 = benchmark_torch_function_in_microseconds(
+        F.linear, fp16_act, fp6_weight_fp16
+    )
+    fp6_time_bf16 = benchmark_torch_function_in_microseconds(
+        F.linear, bf16_act, fp6_weight_bf16
+    )
 
     # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py
     # doesn't seem to be the right way to check for correctness
-    correct_fp16 = (fp6_output_fp16 - fp16_output).abs().mean() / fp16_output.abs().mean() < 1e-3
-    correct_bf16 = (fp6_output_bf16 - bf16_output).abs().mean() / bf16_output.abs().mean() < 1e-2
+    correct_fp16 = (
+        fp6_output_fp16 - fp16_output
+    ).abs().mean() / fp16_output.abs().mean() < 1e-3
+    correct_bf16 = (
+        fp6_output_bf16 - bf16_output
+    ).abs().mean() / bf16_output.abs().mean() < 1e-2
 
     return {
         "m": m,
