Commit 7f0621d

Update micro benchmarking code for AQT (#673)
Summary: Just benchmark a single linear module with an (m * k) * (k * n) problem size.

Test Plan: python benchmarks/benchmark_aq.py
1 parent 18e38f1 commit 7f0621d

File tree

1 file changed (+74, -21 lines)

benchmarks/benchmark_aq.py

Lines changed: 74 additions & 21 deletions
@@ -7,24 +7,60 @@
 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
 )
 from torchao.quantization.quant_api import (
+    int4_weight_only,
+    int8_weight_only,
+    int8_dynamic_activation_int8_weight,
+    quantize_,
     _replace_with_custom_fn_if_matches_filter,
 )
 import copy
 
+def _int8wo_api(mod, **kwargs):
+    if TORCH_VERSION_AT_LEAST_2_4:
+        quantize_(mod, int8_weight_only(**kwargs), set_inductor_config=False)
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_woqtensors(mod, **kwargs)
+
+def _int8da_int8w_api(mod, **kwargs):
+    if TORCH_VERSION_AT_LEAST_2_4:
+        quantize_(mod, int8_dynamic_activation_int8_weight(**kwargs), set_inductor_config=False)
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_dqtensors(mod, **kwargs)
+
+def _int4wo_api(mod, **kwargs):
+    if TORCH_VERSION_AT_LEAST_2_4:
+        kwargs_copy = kwargs.copy()
+        if "groupsize" in kwargs_copy:
+            kwargs_copy["group_size"] = kwargs_copy["groupsize"]
+            del kwargs_copy["groupsize"]
+        quantize_(mod, int4_weight_only(**kwargs_copy), set_inductor_config=False)
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int4_woqtensors(mod, **kwargs)
+
 class ToyLinearModel(torch.nn.Module):
-    def __init__(self, m=64, n=32, k=64):
+    """Single linear for m * k * n problem size
+    """
+    def __init__(self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"):
         super().__init__()
-        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
-        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
+        self.m = m
+        self.dtype = dtype
+        self.device = device
+        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(dtype=self.dtype, device=self.device)
 
-    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
-        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
+    def example_inputs(self):
+        return (torch.randn(self.m, self.linear.in_features, dtype=self.dtype, device=self.device),)
 
     def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
+        x = self.linear(x)
         return x
 
 def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
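
For reference, the reshaped model now maps an (M, K) activation through a single Linear(K, N), so the benchmarked matmul is (M x K) @ (K x N) -> (M x N), matching the problem size named in the summary. A minimal standalone sketch of that shape contract (CPU and float32 here only so it runs anywhere; not part of the diff):

import torch

# Shape contract exercised by the new ToyLinearModel: an (M, K) input
# through a Linear(K, N), whose weight is stored as (N, K), yields (M, N).
M, N, K = 20, 2048, 2048
linear = torch.nn.Linear(K, N, bias=True)
x = torch.randn(M, K)            # mirrors example_inputs()
y = linear(x)
assert linear.weight.shape == (N, K)
assert y.shape == (M, N)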
@@ -69,14 +105,17 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
 _ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
 
 
-def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
+torch._dynamo.config.cache_size_limit = 50000
+
+@torch.no_grad
+def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     if kwargs is None:
         kwargs = {}
 
-    m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
+    m = ToyLinearModel(M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda").eval()
+    m_bf16 = copy.deepcopy(m)
     m_ref = copy.deepcopy(m)
-    # setting batch_size to 20 to be compatible with the kernel
-    example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
+    example_inputs = m.example_inputs()
 
     api(m, **kwargs)
 
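
The timing in the next hunk leans on torchao.utils.benchmark_model: the script calls it once with WARMUP iterations to absorb torch.compile autotuning, then once with RUNS iterations for the reported number. As a rough illustration of what such a helper typically does (a hypothetical stand-in, not torchao's implementation), a CUDA-event timing loop looks like:

import torch

def time_forward(model, example_inputs, runs=100):
    # Hypothetical stand-in for torchao.utils.benchmark_model: average
    # CUDA wall-clock time per forward pass over `runs` iterations.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(runs):
        model(*example_inputs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / runs  # milliseconds per iteration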
@@ -91,27 +130,41 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
     # perf comparison
     from torchao.utils import benchmark_model
     # warmup
-    WARMUP = 5
+    WARMUP = 20
     RUNS = 100
-    m = torch.compile(m, mode='max-autotune', fullgraph=True)
-
-    benchmark_model(m, WARMUP, example_inputs)
-    elapsed_time = benchmark_model(m, RUNS, example_inputs)
 
     m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
     benchmark_model(m_ref, WARMUP, example_inputs)
     ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)
 
-    print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
-    assert elapsed_time < 1.05 * ref_elapsed_time
+    m = torch.compile(m, mode='max-autotune', fullgraph=True)
+    benchmark_model(m, WARMUP, example_inputs)
+    elapsed_time = benchmark_model(m, RUNS, example_inputs)
+
+
+    m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
+    benchmark_model(m_bf16, WARMUP, example_inputs)
+    bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
+
+    print(f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}")
 
 if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available():
+    all_shapes = [
+        (20, 2048, 2048),
+    ]
+
+    print("_int8da_int8w_api")
     from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
-    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_dqtensors, _ref_change_linear_weights_to_int8_dqtensors)
+    for M, N, K in all_shapes:
+        _bench_quantized_tensor_subclass_perf(_int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K)
 
+    print("_int8wo_api")
     from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
-    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_woqtensors, _ref_change_linear_weights_to_int8_woqtensors)
+    for M, N, K in all_shapes:
+        _bench_quantized_tensor_subclass_perf(_int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K)
 
+    print("_int4wo_api")
     kwargs = {"groupsize": 32}
     from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
-    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int4_woqtensors, _ref_change_linear_weights_to_int4_woqtensors, kwargs)
+    for M, N, K in all_shapes:
+        _bench_quantized_tensor_subclass_perf(_int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs)
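
Sweeping more problem sizes only requires appending (M, N, K) tuples to all_shapes before running python benchmarks/benchmark_aq.py. The extra shapes below are hypothetical examples, not part of this commit:

# Each tuple is (M, N, K) for the (M x K) @ (K x N) linear above.
all_shapes = [
    (20, 2048, 2048),    # the shape benchmarked in this commit
    (1, 4096, 4096),     # hypothetical batch-1, decode-style GEMM
    (128, 4096, 11008),  # hypothetical larger prefill-style GEMM
]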

0 commit comments