Skip to content

Commit 0991ba9

Browse files
authored
added gpu benchmarking script (#192)
Add combined GPU sparsity benchmarking script.
1 parent ac8ce4c commit 0991ba9

File tree

2 files changed

+246
-1
lines changed

2 files changed

+246
-1
lines changed

benchmarks/benchmark_gpu_sparsity.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
import argparse
2+
import random
3+
4+
import pandas as pd
5+
import torch
6+
import torch.utils.benchmark as benchmark
7+
import torch.nn.functional as F
8+
from torch import nn
9+
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
10+
11+
from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm
12+
from torchao.utils import benchmark_model
13+
from torchao.sparsity.utils import create_semi_structured_tensor, create_block_sparse_tensor
14+
15+
torch.set_printoptions(
16+
precision=2,
17+
threshold=None,
18+
edgeitems=16,
19+
linewidth=480,
20+
profile=None,
21+
sci_mode=False,
22+
)
23+
24+
25+
def run_gpu_sparse_benchmark(m, k, n, args):
26+
dtype = getattr(torch, args.dtype)
27+
28+
x = torch.randn(n, k).to(dtype).cuda()
29+
b = torch.randn(m, dtype=dtype).cuda()
30+
31+
# handle sparsity types
32+
if args.sparsity == "semi-structured":
33+
SparseSemiStructuredTensor._FORCE_CUTLASS = args.backend == "cutlass"
34+
A = create_semi_structured_tensor(m, k, dtype)
35+
A_sparse = to_sparse_semi_structured(A)
36+
elif args.sparsity == "block-sparse":
37+
A = create_block_sparse_tensor(m, k, args.block_size, args.sparsity_level, dtype)
38+
A_sparse = A.to_sparse_bsr(blocksize=args.block_size)
39+
# BSR kernel tuning
40+
if args.bsr_autotune:
41+
print("Tuning kernel params")
42+
optimize_bsr_dense_addmm(m, k, n, args.block_size, args.block_size,
43+
dtype=dtype, sparsity=args.sparsity_level, verbose=True)
44+
else:
45+
raise ValueError(f"Unknown sparsity: {args.sparsity}")
46+
47+
if args.eval_fn == "linear":
48+
dense_output = F.linear(x, A, b)
49+
sparse_output = F.linear(x, A_sparse, b)
50+
# warmup
51+
benchmark_model(F.linear, 10, args=(x, A, b), device_type="cuda")
52+
dense_time = benchmark_model(F.linear, 100, args=(x, A, b), device_type="cuda")
53+
54+
benchmark_model(F.linear, 10, args=(x, A_sparse, b), device_type="cuda")
55+
sparse_time = benchmark_model(F.linear, 100, args=(x, A_sparse, b), device_type="cuda")
56+
elif args.eval_fn == "mm":
57+
dense_output = torch.mm(A, x.t())
58+
sparse_output = torch.mm(A_sparse, x.t())
59+
dense_time = benchmark_in_us(torch.mm, A, x.t())
60+
sparse_time = benchmark_in_us(torch.mm, A_sparse, x.t())
61+
else:
62+
raise ValueError(f"Unknown eval_fn: {args.eval_fn}")
63+
64+
65+
return {
66+
"test_function": args.eval_fn,
67+
"m": m,
68+
"k": k,
69+
"n": n,
70+
"dtype": args.dtype,
71+
"sparse_latency (ms)": sparse_time,
72+
"dense_latency (ms)": dense_time,
73+
"speedup (d/s)": dense_time / sparse_time,
74+
"contiguous": sparse_output.is_contiguous(),
75+
}
76+
77+
78+
if __name__ == "__main__":
79+
parser = argparse.ArgumentParser(description="GPU Sparsity Microbenchmarks")
80+
parser.add_argument(
81+
"--mode",
82+
type=str,
83+
choices=[
84+
"bert-large",
85+
"vit-mlp-shapes",
86+
"nvidia-fixed-k",
87+
"nvidia-fixed-mn",
88+
],
89+
)
90+
parser.add_argument(
91+
"--sparsity",
92+
type=str,
93+
choices=[
94+
"semi-structured",
95+
"block-sparse",
96+
],
97+
)
98+
parser.add_argument(
99+
"--sparsity-level",
100+
type=float,
101+
)
102+
parser.add_argument(
103+
"--block-size",
104+
type=int,
105+
choices=[
106+
16,
107+
32,
108+
64,
109+
]
110+
)
111+
parser.add_argument(
112+
"--dtype",
113+
type=str,
114+
choices=[
115+
"int8",
116+
"float16",
117+
"bfloat16",
118+
"float32",
119+
],
120+
default="bfloat16",
121+
)
122+
parser.add_argument(
123+
"--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
124+
)
125+
parser.add_argument("--eval-fn", type=str, choices=["linear", "mm"], default="linear")
126+
parser.add_argument("-contiguous", action="store_true")
127+
parser.add_argument("-save", action="store_true")
128+
parser.add_argument("-bsr-autotune", action="store_true", help="Tune BSR kernel parameters")
129+
args = parser.parse_args()
130+
131+
print(f"Started benchmark: {args}")
132+
133+
if args.mode == "bert-large-shapes":
134+
bert_shapes = [
135+
(3072, 1024, 16384),
136+
(4096, 1024, 16384),
137+
(1024, 1024, 16384),
138+
(1024, 4096, 16384),
139+
]
140+
results = (
141+
run_gpu_sparse_benchmark(m, k, n, args)
142+
for (m, k, n) in bert_shapes
143+
)
144+
elif args.mode == "vit-mlp-shapes":
145+
vit_shapes= [
146+
(768, 3072, 50432),
147+
(3072, 768, 50432),
148+
(1280, 5120, 65792),
149+
(5120, 1280, 65792),
150+
]
151+
results = (
152+
run_gpu_sparse_benchmark(m, k, n, args)
153+
for (m, k, n) in vit_shapes
154+
)
155+
elif args.mode == "nvidia-fixed-k":
156+
mn_vals = [
157+
3072,
158+
4096,
159+
5120,
160+
6144,
161+
7168,
162+
8192,
163+
9216,
164+
10240,
165+
11264,
166+
12288,
167+
13312,
168+
14336,
169+
15360,
170+
16384,
171+
17408,
172+
18432,
173+
19456,
174+
20480,
175+
]
176+
results = (
177+
run_gpu_sparse_benchmark(mn, 10240, mn, args)
178+
for mn in mn_vals
179+
)
180+
elif args.mode == "nvidia-fixed-mn":
181+
k_vals = [
182+
2560,
183+
3840,
184+
5120,
185+
6400,
186+
7680,
187+
8960,
188+
10240,
189+
11520,
190+
12800,
191+
14080,
192+
15360,
193+
16640,
194+
17920,
195+
19200,
196+
20480,
197+
]
198+
results = (
199+
run_gpu_sparse_benchmark(10240, k, 10240, args)
200+
for k in k_vals
201+
)
202+
203+
else:
204+
raise ValueError(f"Unknown mode: {args.mode}")
205+
206+
207+
df = pd.DataFrame.from_records(results)
208+
if args.save:
209+
save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
210+
df.to_csv(save_file)
211+
print(f"Finished benchmark: {args.mode} saved results to {save_file}")
212+
print(df)

torchao/sparsity/utils.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,40 @@
1+
import random
12
import torch
23
from torch.ao.quantization.observer import UniformQuantizationObserverBase
34

4-
__all__ = ["PerChannelNormObserver"]
5+
__all__ = [
6+
"create_block_sparse_tensor",
7+
"create_semi_structured_tensor",
8+
"PerChannelNormObserver",
9+
]
10+
11+
def create_block_sparse_tensor(M, N, blocksize, sparsity, dtype):
12+
assert sparsity <= 1.0 and sparsity >= 0.0, \
13+
"sparsity should be a value between 0 and 1"
14+
A = torch.bernoulli(torch.full((M//blocksize, N//blocksize),
15+
1 - sparsity, dtype=dtype))
16+
A = torch.repeat_interleave(A, blocksize, dim=0)
17+
A = torch.repeat_interleave(A, blocksize, dim=1)
18+
return A.to(dtype).contiguous().cuda()
19+
20+
def create_semi_structured_tensor(
21+
r, c, dtype
22+
):
23+
"""
24+
This function returns a 1:2 sparse matrix of size (r, c).
25+
Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
26+
"""
27+
28+
choices = [[0, 1], [1, 0]]
29+
mask_entries = [random.choice(choices) for i in range(r * c // 2)]
30+
31+
mask = (
32+
torch.tensor(mask_entries, dtype=dtype)
33+
.reshape(r, c)
34+
.contiguous()
35+
).cuda()
36+
sparse_weight = torch.rand(r, c).to(dtype).cuda() * mask
37+
return sparse_weight
538

639

740
# Observers

0 commit comments

Comments
 (0)