
Commit f7698ea

njriasan authored and facebook-github-bot committed
Fix Warpspec Matmul to be compatible with OmniFm Shapes (#298)
Summary: Fixes a couple of tutorial assumptions, most notably:

1. The kernels only worked with fp8 and fp16. They now work with all of the dtypes needed for OmniFm.
2. The shapes were not compatible due to layout mismatches. Since every OmniFm shape has a layout mismatch, this adds an explicit transpose so a "best case" can be benchmarked, although that may not be fully representative.
3. Some shapes will never be compatible with TMA because their strides, in bytes, are not divisible by 16. An explicit check in the code surfaces this, and those shapes will be skipped (see the sketch below).

Reviewed By: PaulZhang12

Differential Revision: D77950060
1 parent 608f961 commit f7698ea
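To make point 3 concrete, here is a minimal sketch (not part of the commit) of the stride rule the new TMA check enforces; `is_tma_compatible` is a hypothetical helper, not a function in the diff.

```python
import torch

# Hypothetical helper illustrating the rule the new check enforces:
# every stride except the last, measured in bytes, must be a multiple of 16,
# and the innermost dimension must be contiguous.
def is_tma_compatible(t: torch.Tensor) -> bool:
    elem_bytes = (torch.finfo(t.dtype).bits + 7) // 8
    return t.stride()[-1] == 1 and all(
        (s * elem_bytes) % 16 == 0 for s in t.stride()[:-1]
    )

# bf16 rows of 100 elements are 200 bytes wide -> not 16-byte aligned, skipped.
print(is_tma_compatible(torch.empty(128, 100, dtype=torch.bfloat16)))  # False
# bf16 rows of 96 elements are 192 bytes wide -> aligned, benchmarked.
print(is_tma_compatible(torch.empty(128, 96, dtype=torch.bfloat16)))   # True
```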

File tree

1 file changed (+60, -7)

tritonbench/operators/gemm/warp_spec_persistent_matmul.py

Lines changed: 60 additions & 7 deletions
@@ -3,6 +3,8 @@
 on blackwell with/without warpspec.
 """
 
+import functools
+import logging
 from typing import Optional
 
 import torch
@@ -13,6 +15,27 @@
 # TODO: Add proton support
 
 
+def torch_dtype_to_triton_dtype(dtype):
+    if dtype == torch.float16:
+        return tl.float16
+    elif dtype == torch.float32:
+        return tl.float32
+    elif dtype == torch.float8_e4m3fn:
+        return tl.float8e4nv
+    elif dtype == torch.bfloat16:
+        return tl.bfloat16
+    else:
+        raise ValueError(f"Unsupported dtype: {dtype}")
+
+
+def check_tma_alignment(strides, elem_bytes):
+    for stride in strides[:-1]:
+        if (stride * elem_bytes) % 16 != 0:
+            raise RuntimeError("strides must be 16-byte aligned")
+    if strides[-1] != 1:
+        raise RuntimeError("Last dimension must be contiguous")
+
+
 def _matmul_launch_metadata(grid, kernel, args):
     ret = {}
     M, N, K, WS = args["M"], args["N"], args["K"], args.get("WARP_SPECIALIZE", False)
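As a rough usage sketch (assuming Triton is installed and the module is importable under the path shown in the file tree), the two new helpers behave as follows:

```python
import torch
import triton.language as tl

from tritonbench.operators.gemm.warp_spec_persistent_matmul import (
    check_tma_alignment,
    torch_dtype_to_triton_dtype,
)

# Map a torch dtype onto its Triton counterpart; unsupported dtypes raise ValueError.
assert torch_dtype_to_triton_dtype(torch.bfloat16) == tl.bfloat16

# bf16 rows of 128 elements are 256 bytes wide, so the check passes silently;
# a misaligned or non-contiguous layout would raise RuntimeError instead.
a = torch.empty(256, 128, dtype=torch.bfloat16)
check_tma_alignment(a.stride(), elem_bytes=2)
```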
@@ -21,7 +44,8 @@ def _matmul_launch_metadata(grid, kernel, args):
     if "c_ptr" in args:
         bytes_per_elem = args["c_ptr"].element_size()
     else:
-        bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
+        # ceil division to capture the correct number of bytes
+        bytes_per_elem = (args["DTYPE"].int_bitwidth + 7) // 8
     ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K
     ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N)
     return ret
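A quick arithmetic check of the ceil-division change (a sketch, not part of the diff): the old FP8_OUTPUT flag could only express 1 or 2 bytes per element, while dividing the dtype bitwidth also covers fp32.

```python
# bytes_per_elem = (bitwidth + 7) // 8, as _matmul_launch_metadata now computes it
for name, bits in [("fp8", 8), ("fp16/bf16", 16), ("fp32", 32)]:
    print(name, (bits + 7) // 8)  # -> 1, 2, 4
```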
@@ -77,10 +101,10 @@ def matmul_kernel_tma(
     BLOCK_SIZE_N: tl.constexpr,  #
     BLOCK_SIZE_K: tl.constexpr,  #
     GROUP_SIZE_M: tl.constexpr,  #
-    FP8_OUTPUT: tl.constexpr,  #
     WARP_SPECIALIZE: tl.constexpr,  #
+    DTYPE: tl.constexpr,
 ):
-    dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
+    dtype = DTYPE
 
     pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -112,8 +136,24 @@ def matmul_kernel_tma(
     c_desc.store([offs_cm, offs_cn], c)
 
 
+@functools.lru_cache
+def warn_once(msg: str):
+    """
+    Wrapper around logging.warning to try minimize the number of warnings when
+    a function is repeatedly called.
+    """
+    logging.warning(
+        "Incompatible dimensions, B is transposed. We are transposing B which may impact results"
+    )
+
+
 def blackwell_matmul_tma(a, b, warp_specialize: bool):
     # Check constraints.
+    if a.shape[1] != b.shape[1]:
+        warn_once(
+            "Incompatible dimensions, B is transposed. We are transposing B which may impact results"
+        )
+        b = b.T.contiguous()
     assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
     assert a.dtype == b.dtype, "Incompatible dtypes"
 
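A hedged sketch of the new transpose path in blackwell_matmul_tma (assumes a Blackwell GPU and the module path used above; the tensor shapes are made up): when a.shape[1] != b.shape[1], B is materialized as b.T.contiguous(), and because warn_once is wrapped in functools.lru_cache the warning is only logged once per distinct message.

```python
import torch

from tritonbench.operators.gemm.warp_spec_persistent_matmul import blackwell_matmul_tma

a = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
b = torch.randn(32, 48, dtype=torch.bfloat16, device="cuda")  # (K, N) layout, not (N, K)

# a.shape[1] (32) != b.shape[1] (48): the wrapper logs the warning once,
# replaces b with b.T.contiguous(), and runs the TMA matmul on the transposed copy.
c = blackwell_matmul_tma(a, b, warp_specialize=False)
print(c.shape)  # torch.Size([64, 48])
```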
@@ -141,8 +181,8 @@ def grid(META):
         M,
         N,
         K,  #
-        FP8_OUTPUT=dtype == torch.float8_e4m3fn,  #
         WARP_SPECIALIZE=warp_specialize,  #
+        DTYPE=torch_dtype_to_triton_dtype(dtype),  #
     )
     return c
 
@@ -196,12 +236,12 @@ def matmul_kernel_tma_persistent(
     BLOCK_SIZE_N: tl.constexpr,  #
     BLOCK_SIZE_K: tl.constexpr,  #
     GROUP_SIZE_M: tl.constexpr,  #
-    FP8_OUTPUT: tl.constexpr,  #
     EPILOGUE_SUBTILE: tl.constexpr,  #
     NUM_SMS: tl.constexpr,  #
     WARP_SPECIALIZE: tl.constexpr,  #
+    DTYPE: tl.constexpr,
 ):
-    dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
+    dtype = DTYPE
     start_pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
@@ -256,9 +296,17 @@ def matmul_kernel_tma_persistent(
 
 def blackwell_matmul_tma_persistent(a, b, warp_specialize: bool):
     # Check constraints.
+    if a.shape[1] != b.shape[1]:
+        warn_once(
+            "Incompatible dimensions, B is transposed. We are transposing B which may impact results"
+        )
+        b = b.T.contiguous()
     assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
     assert a.dtype == b.dtype, "Incompatible dtypes"
 
+    check_tma_alignment(a.stride(), (torch.finfo(a.dtype).bits + 7) // 8)
+    check_tma_alignment(b.stride(), (torch.finfo(b.dtype).bits + 7) // 8)
+
     M, K = a.shape
     N, K = b.shape
     dtype = a.dtype
@@ -291,9 +339,9 @@ def grid(META):
         M,
         N,
         K,  #
-        FP8_OUTPUT=dtype == torch.float8_e4m3fn,  #
         NUM_SMS=NUM_SMS,  #
         WARP_SPECIALIZE=warp_specialize,  #
+        DTYPE=torch_dtype_to_triton_dtype(dtype),  #
     )
     return c
 
@@ -403,6 +451,11 @@ def matmul_kernel_descriptor_persistent(
 
 def blackwell_matmul_descriptor_persistent(a, b, warp_specialize: bool):
     # Check constraints.
+    if a.shape[1] != b.shape[1]:
+        warn_once(
+            "Incompatible dimensions, B is transposed. We are transposing B which may impact results"
+        )
+        b = b.T.contiguous()
     assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
     assert a.dtype == b.dtype, "Incompatible dtypes"
 
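Finally, a sketch of why some OmniFm shapes end up skipped (again assuming a Blackwell GPU and the module path used above): the persistent TMA entry point now fails fast when the byte strides are not 16-byte aligned, instead of producing a misleading benchmark number.

```python
import torch

from tritonbench.operators.gemm.warp_spec_persistent_matmul import (
    blackwell_matmul_tma_persistent,
)

M, N, K = 128, 128, 100  # bf16 rows of 100 elements are 200 bytes: not 16-byte aligned
a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
b = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

try:
    blackwell_matmul_tma_persistent(a, b, warp_specialize=True)
except RuntimeError as err:
    print(err)  # "strides must be 16-byte aligned" -> this shape is skipped in the benchmark
```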