Skip to content

Commit efb6514

Browse files
authored
Autotuner for int mm Triton kernels (#41)
1 parent 6473aab commit efb6514

17 files changed

+1201
-65
lines changed

.lintrunner.toml

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
merge_base_with = "origin/main"
2+
3+
[[linter]]
4+
code = 'FLAKE8'
5+
include_patterns = ['**/*.py']
6+
exclude_patterns = [
7+
'third-party/**',
8+
'**/third-party/**',
9+
]
10+
command = [
11+
'python',
12+
'-m',
13+
'lintrunner_adapters',
14+
'run',
15+
'flake8_linter',
16+
'--',
17+
'@{{PATHSFILE}}'
18+
]
19+
init_command = [
20+
'python',
21+
'-m',
22+
'lintrunner_adapters',
23+
'run',
24+
'pip_init',
25+
'--dry-run={{DRYRUN}}',
26+
'--requirement=requirements-lintrunner.txt',
27+
]
28+
29+
# Black + usort
30+
[[linter]]
31+
code = 'UFMT'
32+
include_patterns = [
33+
'**/*.py',
34+
'**/*.pyi',
35+
]
36+
exclude_patterns = [
37+
'third-party/**',
38+
'**/third-party/**',
39+
]
40+
command = [
41+
'python',
42+
'-m',
43+
'lintrunner_adapters',
44+
'run',
45+
'ufmt_linter',
46+
'--',
47+
'@{{PATHSFILE}}'
48+
]
49+
init_command = [
50+
'python',
51+
'-m',
52+
'lintrunner_adapters',
53+
'run',
54+
'pip_init',
55+
'--dry-run={{DRYRUN}}',
56+
'--no-black-binary',
57+
'--requirement=requirements-lintrunner.txt',
58+
]
59+
is_formatter = true

benchmarks/intmm.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import argparse
2+
import csv
3+
import itertools
4+
import math
5+
import pathlib
6+
7+
import torch
8+
import torch.nn.functional as F
9+
import torch.utils.benchmark as benchmark
10+
from torchao.kernel.intmm_triton import int_matmul, int_scaled_matmul
11+
12+
torch._dynamo.config.cache_size_limit = 128
13+
torch._dynamo.config.accumulated_cache_size_limit = 128
14+
15+
dtype = torch.float16
16+
device = "cuda"
17+
18+
19+
def benchmark_in_ms(warmup, iters, f, *args, **kwargs):
    """Return the average time of one ``f(*args, **kwargs)`` call.

    ``f`` is first called ``warmup`` times untimed, then ``iters`` times
    between two CUDA timing events. The elapsed time between the events is
    divided by ``iters``.

    Args:
        warmup: number of untimed calls made before measuring.
        iters: number of timed calls.
        f: callable to benchmark.
        *args, **kwargs: forwarded unchanged to every call of ``f``.

    Returns:
        ``start.elapsed_time(end) / iters`` as a float.
    """

    def invoke(times):
        for _ in range(times):
            f(*args, **kwargs)

    invoke(warmup)
    # Finish the warm-up work before the timed region starts.
    torch.cuda.synchronize()

    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)

    tic.record()
    invoke(iters)
    toc.record()

    # Synchronize before reading the elapsed time from the events.
    torch.cuda.synchronize()
    total_elapsed = tic.elapsed_time(toc)
    return total_elapsed / float(iters)
33+
34+
35+
@torch.compile(mode="max-autotune")
def compiled_mm(x, w):
    """Return ``torch.mm(x, w)``; the function is compiled in max-autotune mode."""
    product = torch.mm(x, w)
    return product
38+
39+
40+
@torch.compile(mode="max-autotune")
def compiled_int_mm(x, w):
    """Return ``torch._int_mm(x, w)``; the function is compiled in max-autotune mode."""
    product = torch._int_mm(x, w)
    return product
43+
44+
45+
def run_int_mm_benchmark(x, w, b):
    """Time ``torch.mm(x, w)`` against ``int_matmul`` on the int8 casts of x and w.

    Both measurements use 10 warm-up calls and 100 timed calls.

    Args:
        x: left operand passed to ``torch.mm``.
        w: right operand passed to ``torch.mm``.
        b: unused here; accepted so both benchmark runners share one signature.

    Returns:
        ``(fp_time, int_mm_time)`` as reported by ``benchmark_in_ms``.
    """
    warmup, iters = 10, 100
    fp_time = benchmark_in_ms(warmup, iters, torch.mm, x, w)
    # Cast only after the floating-point run, in the same order x then w.
    int8_operands = [t.to(dtype=torch.int8) for t in (x, w)]
    int_mm_time = benchmark_in_ms(warmup, iters, int_matmul, *int8_operands)
    return fp_time, int_mm_time
51+
52+
53+
def run_int_scaled_mm_benchmark(x, w, b):
    """Time ``torch.mm(x, w) * scales`` against ``int_scaled_matmul``.

    ``scales`` is ``x.sum(-1, keepdim=True)``. The integer path receives the
    int8 casts of x and w together with the same ``scales``. Both measurements
    use 10 warm-up calls and 100 timed calls.

    Args:
        x: left operand.
        w: right operand.
        b: unused here; accepted so both benchmark runners share one signature.

    Returns:
        ``(fp_time, int_scaled_mm_time)`` as reported by ``benchmark_in_ms``.
    """
    warmup, iters = 10, 100
    scales = x.sum(-1, keepdim=True)

    def scaled_fp_mm(lhs, rhs, s):
        return torch.mm(lhs, rhs) * s

    fp_time = benchmark_in_ms(warmup, iters, scaled_fp_mm, x, w, scales)

    x_int, w_int = (t.to(dtype=torch.int8) for t in (x, w))
    int_scaled_mm_time = benchmark_in_ms(
        warmup, iters, int_scaled_matmul, x_int, w_int, scales
    )
    return fp_time, int_scaled_mm_time
62+
63+
64+
def run_benchmarks(shapes):
    """Run every benchmark runner on every shape and print the results as CSV.

    A header row ``fn,m,k,n,fp_time,int_mm_time,ratio`` is printed first,
    followed by one row per (runner, shape) pair. ``ratio`` is
    ``fp_time / int_mm_time``.

    Args:
        shapes: iterable of ``(m, k, n)`` triples of ints.
    """
    print("fn,m,k,n,fp_time,int_mm_time,ratio")
    dtype = torch.bfloat16
    device = "cuda"
    runners = [run_int_mm_benchmark, run_int_scaled_mm_benchmark]
    for fn, (m, k, n) in itertools.product(runners, shapes):
        x = torch.randn(m, k, dtype=dtype, device=device)
        # Created as (n, k) and transposed, so w is a (k, n) view.
        w = torch.randn(n, k, dtype=dtype, device=device).t()
        b = torch.randn(m, n, dtype=dtype, device=device)

        fp_time, int_mm_time = fn(x, w, b)
        ratio = fp_time / int_mm_time
        # Emit the runner's name: str(fn) would put a function repr with a
        # memory address into the "fn" column, which differs between runs.
        row = [fn.__name__, m, k, n, fp_time, int_mm_time, ratio]
        print(",".join(map(str, row)))
80+
81+
82+
def load_shapes(file_path):
    """Read ``(m, k, n)`` shape triples from a CSV file.

    The first row is treated as a header and skipped. Blank rows are ignored.

    Args:
        file_path: path to the CSV file.

    Returns:
        A list of tuples of ints, one tuple per data row.

    Raises:
        ValueError: if a field cannot be parsed as an integer.
    """
    # The context manager closes the file; newline="" is what the csv module
    # asks for when it is handed a file object.
    with open(file_path, "r", newline="") as shapes_file:
        rows = list(csv.reader(shapes_file))
    # rows[0] is the header; empty lists come from blank lines in the file.
    return [tuple(map(int, row)) for row in rows[1:] if row]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="integer matmul benchmarks")
    parser.add_argument("file_path", type=str, help="Path to csv file with shapes")
    args = parser.parse_args()
    file_path = pathlib.Path(args.file_path)
    if not file_path.is_file():
        # Explicit check instead of assert, which is stripped under `python -O`.
        parser.error(f"shapes file not found: {file_path}")

    # Format is (m, k, n)
    run_benchmarks(load_shapes(file_path))

benchmarks/intmm_shapes.csv

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
m,k,n
2+
1024,1024,2304
3+
1024,1024,4608
4+
1024,8192,2304
5+
1024,8192,4608
6+
1152,1024,2048
7+
1152,2048,16384
8+
1152,2048,2048
9+
1152,3072,2048
10+
1152,4096,2048
11+
1152,8192,2048
12+
1,2048,1024
13+
1,2048,2048
14+
1,2048,4096
15+
144,2048,16384
16+
144,2048,2048
17+
144,4096,2048
18+
144,8192,2048
19+
1472,1024,154
20+
1472,1024,308
21+
1472,2048,154
22+
1472,2048,308
23+
1472,512,154
24+
1472,512,308
25+
1,512,2048
26+
154,1472,1024
27+
154,1472,2048
28+
154,1472,512
29+
18432,1024,512
30+
18432,1536,512
31+
18432,2048,512
32+
18432,512,4096
33+
18432,512,512
34+
2048,1024,1
35+
2048,1024,2
36+
2048,16384,1152
37+
2048,16384,144
38+
2048,16384,288
39+
2048,16384,576
40+
2048,2048,1
41+
2048,2048,1152
42+
2048,2048,144
43+
2048,2048,2
44+
2048,2048,288
45+
2048,2048,576
46+
2048,4096,1
47+
2048,4096,2
48+
2048,512,18432
49+
2048,512,9216
50+
2,2048,1024
51+
2,2048,2048
52+
2,2048,4096
53+
2304,1024,1024
54+
2304,1024,8192
55+
2304,1536,1024
56+
2304,2048,1024
57+
2304,3072,1024
58+
2304,4096,1024
59+
2304,512,1024
60+
231,4096,1024
61+
231,4096,2048
62+
231,4096,512
63+
231,768,1024
64+
231,768,2048
65+
231,768,512
66+
2,512,2048
67+
288,2048,16384
68+
288,2048,2048
69+
288,4096,2048
70+
288,8192,2048
71+
308,1472,1024
72+
308,1472,2048
73+
308,1472,512
74+
4096,1024,2304
75+
4096,1024,231
76+
4096,1024,4608
77+
4096,1024,462
78+
4096,2048,231
79+
4096,2048,462
80+
4096,512,231
81+
4096,512,462
82+
4608,1024,1024
83+
4608,1024,8192
84+
4608,1536,1024
85+
4608,2048,1024
86+
4608,3072,1024
87+
4608,4096,1024
88+
4608,512,1024
89+
462,4096,1024
90+
462,4096,2048
91+
462,4096,512
92+
462,768,1024
93+
462,768,2048
94+
462,768,512
95+
512,2048,1
96+
512,2048,2
97+
512,4096,18432
98+
512,4096,9216
99+
512,512,18432
100+
512,512,9216
101+
576,1024,2048
102+
576,2048,16384
103+
576,2048,2048
104+
576,3072,2048
105+
576,4096,2048
106+
576,8192,2048
107+
768,1024,231
108+
768,1024,462
109+
768,2048,231
110+
768,2048,462
111+
768,512,231
112+
768,512,462
113+
8192,2048,1152
114+
8192,2048,144
115+
8192,2048,288
116+
8192,2048,576
117+
9216,1024,512
118+
9216,1536,512
119+
9216,2048,512
120+
9216,512,4096
121+
9216,512,512
122+
32768,3072,768
123+
32768,768,2304
124+
32768,768,3072
125+
32768,768,768
126+
39200,768,2304
127+
39200,768,768

benchmarks/print_config_shapes.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import torchao
2+
3+
from torchao.kernel import autotuner
4+
5+
# Print, as CSV, the (m, k, n) taken from every key of the loaded config table.
configs = autotuner._load_best_configs()

print("m,k,n")
for key, _config in configs.items():
    # Entries 1 and 4 of the key are unpacked as the two operand shapes.
    (M, K0), (K1, N) = key[1], key[4]

    # The inner dimensions of the two operands must agree.
    assert K0 == K1

    print(f"{M},{K0},{N}")

benchmarks/sam_vit_b_shapes.csv

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
m,k,n
2+
32768,3072,768
3+
32768,768,2304
4+
32768,768,3072
5+
32768,768,768
6+
39200,768,2304
7+
39200,768,768

dev-requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
pytest
22
expecttest
3-
packaging
3+
parameterized
4+
packaging

requirements-lintrunner.txt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Lintrunner itself
2+
lintrunner==0.11.0
3+
lintrunner-adapters==0.11.0
4+
5+
# Flake8 and its dependencies
6+
flake8==6.0.0
7+
flake8-breakpoint==1.1.0
8+
flake8-bugbear==23.6.5
9+
flake8-comprehensions==3.12.0
10+
flake8-pyi==23.5.0
11+
mccabe==0.7.0
12+
pycodestyle==2.10.0
13+
torchfix==0.1.1
14+
15+
# UFMT
16+
black==24.2.0
17+
ufmt==2.5.1
18+
usort==1.0.5
19+
20+
# Other linters
21+
clang-format==12.0.1
22+
cmakelint==1.4.1

setup.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,35 @@
55

66
import os
77
from datetime import datetime
8-
from setuptools import setup, find_packages
9-
current_date = datetime.now().strftime('%Y.%m.%d')
8+
9+
from setuptools import find_packages, setup
10+
11+
current_date = datetime.now().strftime("%Y.%m.%d")
1012

1113

1214
def read_requirements(file_path):
    """Return the lines of the file at ``file_path`` as a list of strings.

    Line terminators are stripped; blank lines are kept as empty strings.
    """
    with open(file_path, "r") as handle:
        contents = handle.read()
    return contents.splitlines()
1517

18+
1619
# Determine the package name based on the presence of an environment variable
package_name = "torchao-nightly" if os.environ.get("TORCHAO_NIGHTLY") else "torchao"

# Version is year.month.date if using nightlies
version = current_date if package_name == "torchao-nightly" else "0.0.3"

# Read the README through a context manager so the file handle is closed
# instead of being left for the garbage collector.
with open("README.md") as readme_file:
    long_description = readme_file.read()

setup(
    name=package_name,
    version=version,
    packages=find_packages(),
    include_package_data=True,
    # Ship the pickled files found in torchao.kernel.configs with the package.
    package_data={
        "torchao.kernel.configs": ["*.pkl"],
    },
    install_requires=read_requirements("requirements.txt"),
    description="Package for applying ao techniques to GPU models",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/pytorch-labs/ao",
)

0 commit comments

Comments
 (0)