Skip to content

Commit 01dc7da

Browse files
authored
replace triton.ops dependencies in pytorch/ao
Differential Revision: D65678605 Pull Request resolved: #1250
1 parent f96e5ec commit 01dc7da

File tree

7 files changed

+650
-7
lines changed

7 files changed

+650
-7
lines changed

torchao/prototype/common/triton/__init__.py

Whitespace-only changes.
Lines changed: 367 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,367 @@
1+
import torch
2+
3+
from triton import Config, autotune, cdiv, heuristics, jit
4+
from triton import language as tl
5+
from .matmul_perf_model import early_config_prune, estimate_matmul_time
6+
7+
_ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32]
8+
9+
10+
def upcast_if_fp8(a):
11+
if "fp8" in str(a):
12+
return torch.float16
13+
return a
14+
15+
16+
def get_higher_dtype(a, b):
17+
a = upcast_if_fp8(a)
18+
b = upcast_if_fp8(b)
19+
if a is b:
20+
return a
21+
22+
assert a in _ordered_datatypes
23+
assert b in _ordered_datatypes
24+
25+
for d in _ordered_datatypes:
26+
if a is d:
27+
return b
28+
if b is d:
29+
return a
30+
31+
32+
def init_to_zero(name):
33+
return lambda nargs: nargs[name].zero_()
34+
35+
36+
def get_configs_io_bound():
37+
configs = []
38+
for num_stages in [2, 3, 4, 5, 6]:
39+
for block_m in [16, 32]:
40+
for block_k in [32, 64]:
41+
for block_n in [32, 64, 128, 256]:
42+
num_warps = 2 if block_n <= 64 else 4
43+
configs.append(
44+
Config(
45+
{
46+
"BLOCK_M": block_m,
47+
"BLOCK_N": block_n,
48+
"BLOCK_K": block_k,
49+
"SPLIT_K": 1,
50+
},
51+
num_stages=num_stages,
52+
num_warps=num_warps,
53+
)
54+
)
55+
# split_k
56+
for split_k in [2, 4, 8, 16]:
57+
configs.append(
58+
Config(
59+
{
60+
"BLOCK_M": block_m,
61+
"BLOCK_N": block_n,
62+
"BLOCK_K": block_k,
63+
"SPLIT_K": split_k,
64+
},
65+
num_stages=num_stages,
66+
num_warps=num_warps,
67+
pre_hook=init_to_zero("C"),
68+
)
69+
)
70+
return configs
71+
72+
73+
@autotune(
74+
configs=[
75+
# basic configs for compute-bound matmuls
76+
Config(
77+
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
78+
num_stages=3,
79+
num_warps=8,
80+
),
81+
Config(
82+
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
83+
num_stages=3,
84+
num_warps=8,
85+
),
86+
Config(
87+
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
88+
num_stages=4,
89+
num_warps=4,
90+
),
91+
Config(
92+
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1},
93+
num_stages=4,
94+
num_warps=4,
95+
),
96+
Config(
97+
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
98+
num_stages=4,
99+
num_warps=4,
100+
),
101+
Config(
102+
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1},
103+
num_stages=4,
104+
num_warps=4,
105+
),
106+
Config(
107+
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1},
108+
num_stages=4,
109+
num_warps=4,
110+
),
111+
Config(
112+
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
113+
num_stages=4,
114+
num_warps=4,
115+
),
116+
Config(
117+
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1},
118+
num_stages=5,
119+
num_warps=2,
120+
),
121+
# good for int8
122+
Config(
123+
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
124+
num_stages=3,
125+
num_warps=8,
126+
),
127+
Config(
128+
{"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
129+
num_stages=3,
130+
num_warps=8,
131+
),
132+
Config(
133+
{"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1},
134+
num_stages=4,
135+
num_warps=4,
136+
),
137+
Config(
138+
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1},
139+
num_stages=4,
140+
num_warps=4,
141+
),
142+
Config(
143+
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1},
144+
num_stages=4,
145+
num_warps=4,
146+
),
147+
Config(
148+
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1},
149+
num_stages=4,
150+
num_warps=4,
151+
),
152+
Config(
153+
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1},
154+
num_stages=4,
155+
num_warps=4,
156+
),
157+
Config(
158+
{"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
159+
num_stages=4,
160+
num_warps=4,
161+
),
162+
Config(
163+
{"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1},
164+
num_stages=5,
165+
num_warps=2,
166+
),
167+
]
168+
+ get_configs_io_bound(),
169+
key=["M", "N", "K"],
170+
prune_configs_by={
171+
"early_config_prune": early_config_prune,
172+
"perf_model": estimate_matmul_time,
173+
"top_k": 10,
174+
},
175+
)
176+
@heuristics(
177+
{
178+
"EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0,
179+
}
180+
)
181+
@jit
182+
def _kernel(
183+
A,
184+
B,
185+
C,
186+
M,
187+
N,
188+
K, #
189+
stride_am,
190+
stride_ak, #
191+
stride_bk,
192+
stride_bn, #
193+
stride_cm,
194+
stride_cn, #
195+
acc_dtype: tl.constexpr, #
196+
input_precision: tl.constexpr, #
197+
fp8_fast_accum: tl.constexpr, #
198+
BLOCK_M: tl.constexpr,
199+
BLOCK_N: tl.constexpr,
200+
BLOCK_K: tl.constexpr, #
201+
GROUP_M: tl.constexpr,
202+
SPLIT_K: tl.constexpr,
203+
EVEN_K: tl.constexpr,
204+
AB_DTYPE: tl.constexpr, #
205+
):
206+
# matrix multiplication
207+
pid = tl.program_id(0)
208+
pid_z = tl.program_id(1)
209+
grid_m = tl.cdiv(M, BLOCK_M)
210+
grid_n = tl.cdiv(N, BLOCK_N)
211+
# re-order program ID for better L2 performance
212+
width = GROUP_M * grid_n
213+
group_id = pid // width
214+
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
215+
pid_m = group_id * GROUP_M + (pid % group_size)
216+
pid_n = (pid % width) // (group_size)
217+
# do matrix multiplication
218+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
219+
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
220+
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
221+
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
222+
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
223+
# pointers
224+
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
225+
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
226+
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
227+
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
228+
if EVEN_K:
229+
a = tl.load(A)
230+
b = tl.load(B)
231+
else:
232+
k_remaining = K - k * (BLOCK_K * SPLIT_K)
233+
_0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
234+
a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
235+
b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
236+
if AB_DTYPE is not None:
237+
a = a.to(AB_DTYPE)
238+
b = b.to(AB_DTYPE)
239+
if fp8_fast_accum:
240+
acc = tl.dot(
241+
a, b, acc, out_dtype=acc_dtype, input_precision=input_precision
242+
)
243+
else:
244+
acc += tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision)
245+
A += BLOCK_K * SPLIT_K * stride_ak
246+
B += BLOCK_K * SPLIT_K * stride_bk
247+
acc = acc.to(C.dtype.element_ty)
248+
# rematerialize rm and rn to save registers
249+
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
250+
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
251+
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
252+
mask = (rm < M)[:, None] & (rn < N)[None, :]
253+
# handles write-back with reduction-splitting
254+
if SPLIT_K == 1:
255+
tl.store(C, acc, mask=mask)
256+
else:
257+
tl.atomic_add(C, acc, mask=mask)
258+
259+
260+
class _matmul(torch.autograd.Function):
261+
kernel = _kernel
262+
263+
_locks = {}
264+
265+
@staticmethod
266+
def _call(a, b, acc_dtype, input_precision, fp8_fast_accum, output_dtype):
267+
device = a.device
268+
# handle non-contiguous inputs if necessary
269+
if a.stride(0) > 1 and a.stride(1) > 1:
270+
a = a.contiguous()
271+
if b.stride(0) > 1 and b.stride(1) > 1:
272+
b = b.contiguous()
273+
# checks constraints
274+
assert (
275+
a.shape[1] == b.shape[0]
276+
), f"incompatible dimensions {a.shape} and {b.shape}"
277+
M, K = a.shape
278+
_, N = b.shape
279+
280+
# common type between a and b
281+
ab_dtype = get_higher_dtype(a.dtype, b.dtype)
282+
283+
# allocates output
284+
if output_dtype is None:
285+
output_dtype = ab_dtype
286+
287+
c = torch.empty((M, N), device=device, dtype=output_dtype)
288+
289+
# Allowed types for acc_type given the types of a and b.
290+
supported_acc_dtypes = {
291+
torch.float16: (torch.float32, torch.float16),
292+
torch.bfloat16: (torch.float32, torch.bfloat16),
293+
torch.float32: (torch.float32,),
294+
torch.int8: (torch.int32,),
295+
}
296+
297+
if acc_dtype is None:
298+
acc_dtype = supported_acc_dtypes[ab_dtype][0]
299+
else:
300+
assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype"
301+
assert (
302+
acc_dtype in supported_acc_dtypes[a.dtype]
303+
), "acc_dtype not compatible with the type of a"
304+
assert (
305+
acc_dtype in supported_acc_dtypes[b.dtype]
306+
), "acc_dtype not compatible with the type of b"
307+
308+
def to_tl_type(ty):
309+
return getattr(tl, str(ty).split(".")[-1])
310+
311+
acc_dtype = to_tl_type(acc_dtype)
312+
ab_dtype = to_tl_type(ab_dtype)
313+
output_dtype = to_tl_type(output_dtype)
314+
315+
# Tensor cores support input with mixed float8 types.
316+
if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [
317+
tl.float8e4nv,
318+
tl.float8e5,
319+
]:
320+
ab_dtype = None
321+
# launch kernel
322+
grid = lambda META: (
323+
cdiv(M, META["BLOCK_M"]) * cdiv(N, META["BLOCK_N"]),
324+
META["SPLIT_K"],
325+
)
326+
_kernel[grid](
327+
a,
328+
b,
329+
c,
330+
M,
331+
N,
332+
K, #
333+
a.stride(0),
334+
a.stride(1), #
335+
b.stride(0),
336+
b.stride(1), #
337+
c.stride(0),
338+
c.stride(1), #
339+
acc_dtype=acc_dtype, #
340+
input_precision=input_precision, #
341+
fp8_fast_accum=fp8_fast_accum, #
342+
GROUP_M=8,
343+
AB_DTYPE=ab_dtype,
344+
)
345+
return c
346+
347+
@staticmethod
348+
def forward(
349+
ctx,
350+
a,
351+
b,
352+
acc_dtype=None,
353+
input_precision=None,
354+
fp8_fast_accum=True,
355+
output_dtype=None,
356+
):
357+
return _matmul._call(
358+
a,
359+
b,
360+
acc_dtype=acc_dtype,
361+
input_precision=input_precision,
362+
fp8_fast_accum=fp8_fast_accum,
363+
output_dtype=output_dtype,
364+
)
365+
366+
367+
matmul = _matmul.apply

0 commit comments

Comments
 (0)