
Fix Warpspec Matmul to be compatible with OmniFm Shapes #298


Open

wants to merge 1 commit into base: main
67 changes: 60 additions & 7 deletions tritonbench/operators/gemm/warp_spec_persistent_matmul.py
@@ -3,6 +3,8 @@
on blackwell with/without warpspec.
"""

import functools
import logging
from typing import Optional

import torch
@@ -13,6 +15,27 @@
# TODO: Add proton support


def torch_dtype_to_triton_dtype(dtype):
if dtype == torch.float16:
return tl.float16
elif dtype == torch.float32:
return tl.float32
elif dtype == torch.float8_e4m3fn:
return tl.float8e4nv
elif dtype == torch.bfloat16:
return tl.bfloat16
else:
raise ValueError(f"Unsupported dtype: {dtype}")
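A quick, hedged spot-check of the mapping above; the import path is this file's module as shown in the diff header, and the checks rely on the helper returning the module-level `tl` dtype objects.

```python
import torch
import triton.language as tl

# Hypothetical usage sketch, not part of the PR.
from tritonbench.operators.gemm.warp_spec_persistent_matmul import (
    torch_dtype_to_triton_dtype,
)

assert torch_dtype_to_triton_dtype(torch.bfloat16) == tl.bfloat16
assert torch_dtype_to_triton_dtype(torch.float8_e4m3fn) == tl.float8e4nv
```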


def check_tma_alignment(strides, elem_bytes):
for stride in strides[:-1]:
if (stride * elem_bytes) % 16 != 0:
raise RuntimeError("strides must be 16-byte aligned")
if strides[-1] != 1:
raise RuntimeError("Last dimension must be contiguous")
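As a hedged illustration of the rule enforced above (TMA requires every non-contiguous stride to be a multiple of 16 bytes and the last dimension to be contiguous); the shapes below are made up for demonstration.

```python
import torch

from tritonbench.operators.gemm.warp_spec_persistent_matmul import check_tma_alignment

# 128 x 64 fp16 -> strides (64, 1); leading stride is 64 * 2 = 128 bytes: passes.
ok = torch.randn(128, 64, dtype=torch.float16)
check_tma_alignment(ok.stride(), ok.element_size())

# 128 x 60 fp16 -> strides (60, 1); leading stride is 60 * 2 = 120 bytes,
# not a multiple of 16: raises RuntimeError.
bad = torch.randn(128, 60, dtype=torch.float16)
try:
    check_tma_alignment(bad.stride(), bad.element_size())
except RuntimeError as err:
    print(err)  # strides must be 16-byte aligned
```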


def _matmul_launch_metadata(grid, kernel, args):
ret = {}
M, N, K, WS = args["M"], args["N"], args["K"], args.get("WARP_SPECIALIZE", False)
@@ -21,7 +44,8 @@ def _matmul_launch_metadata(grid, kernel, args):
if "c_ptr" in args:
bytes_per_elem = args["c_ptr"].element_size()
else:
bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
# ceil division to capture the correct number of bytes
bytes_per_elem = (args["DTYPE"].int_bitwidth + 7) // 8
ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K
ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N)
return ret
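For concreteness, a hedged worked example of what this metadata evaluates to for a hypothetical 8-bit GEMM with M = N = K = 4096:

```python
# bytes_per_elem uses ceil division so odd or sub-byte bit widths still
# count as at least one full byte; for an 8-bit dtype it is exactly 1.
M = N = K = 4096
bytes_per_elem = (8 + 7) // 8                            # -> 1
flops = 2.0 * M * N * K                                  # ~1.37e11, reported as "flops8"
bytes_moved = bytes_per_elem * (M * K + N * K + M * N)   # ~5.03e7 bytes
print(flops, bytes_moved)
```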
@@ -77,10 +101,10 @@ def matmul_kernel_tma(
BLOCK_SIZE_N: tl.constexpr, #
BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
FP8_OUTPUT: tl.constexpr, #
WARP_SPECIALIZE: tl.constexpr, #
DTYPE: tl.constexpr,
):
dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
dtype = DTYPE

pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -112,8 +136,24 @@ def matmul_kernel_tma(
c_desc.store([offs_cm, offs_cn], c)


@functools.lru_cache
def warn_once(msg: str):
    """
    Wrapper around logging.warning that uses functools.lru_cache so a given
    message is only logged once, even when the calling function runs repeatedly.
    """
    logging.warning(msg)
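Because `functools.lru_cache` keys on the message string, repeated calls with the same message log only once per process. A standalone sketch of the pattern (the demo name is hypothetical):

```python
import functools
import logging

@functools.lru_cache
def _warn_once_demo(msg: str) -> None:
    # The first call with a given msg logs; identical later calls hit the cache.
    logging.warning(msg)

logging.basicConfig()
for _ in range(3):
    _warn_once_demo("B is transposed; results may be affected")  # logged exactly once
```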


def blackwell_matmul_tma(a, b, warp_specialize: bool):
# Check constraints.
if a.shape[1] != b.shape[1]:
warn_once(
"Incompatible dimensions, B is transposed. We are transposing B which may impact results"
)
b = b.T.contiguous()
assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed
assert a.dtype == b.dtype, "Incompatible dtypes"
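A hedged sketch of the shape convention these checks enforce: the kernels take `a` as (M, K) and `b` as (N, K), so a `b` that arrives in the usual (K, N) matmul layout is transposed and made contiguous before launch. The sizes below are illustrative only.

```python
import torch

M, K, N = 128, 256, 512
a = torch.randn(M, K, dtype=torch.bfloat16)
b = torch.randn(K, N, dtype=torch.bfloat16)   # (K, N), as torch.matmul would consume it

if a.shape[1] != b.shape[1]:                  # 256 != 512 -> take the transpose path
    b = b.T.contiguous()                      # now (N, K) with a contiguous last dim

assert a.shape[1] == b.shape[1]
assert b.stride(-1) == 1
```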

@@ -141,8 +181,8 @@ def grid(META):
M,
N,
K, #
FP8_OUTPUT=dtype == torch.float8_e4m3fn, #
WARP_SPECIALIZE=warp_specialize, #
DTYPE=torch_dtype_to_triton_dtype(dtype), #
)
return c

@@ -196,12 +236,12 @@ def matmul_kernel_tma_persistent(
BLOCK_SIZE_N: tl.constexpr, #
BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
FP8_OUTPUT: tl.constexpr, #
EPILOGUE_SUBTILE: tl.constexpr, #
NUM_SMS: tl.constexpr, #
WARP_SPECIALIZE: tl.constexpr, #
DTYPE: tl.constexpr,
):
dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
dtype = DTYPE
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
@@ -256,9 +296,17 @@ def matmul_kernel_tma_persistent(

def blackwell_matmul_tma_persistent(a, b, warp_specialize: bool):
# Check constraints.
if a.shape[1] != b.shape[1]:
warn_once(
"Incompatible dimensions, B is transposed. We are transposing B which may impact results"
)
b = b.T.contiguous()
assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed
assert a.dtype == b.dtype, "Incompatible dtypes"

check_tma_alignment(a.stride(), (torch.finfo(a.dtype).bits + 7) // 8)
check_tma_alignment(b.stride(), (torch.finfo(b.dtype).bits + 7) // 8)

M, K = a.shape
N, K = b.shape
dtype = a.dtype
@@ -291,9 +339,9 @@ def grid(META):
M,
N,
K, #
FP8_OUTPUT=dtype == torch.float8_e4m3fn, #
NUM_SMS=NUM_SMS, #
WARP_SPECIALIZE=warp_specialize, #
DTYPE=torch_dtype_to_triton_dtype(dtype), #
)
return c

@@ -403,6 +451,11 @@ def matmul_kernel_descriptor_persistent(

def blackwell_matmul_descriptor_persistent(a, b, warp_specialize: bool):
# Check constraints.
if a.shape[1] != b.shape[1]:
warn_once(
"Incompatible dimensions, B is transposed. We are transposing B which may impact results"
)
b = b.T.contiguous()
assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed
assert a.dtype == b.dtype, "Incompatible dtypes"

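Putting the pieces together, a hedged end-to-end sketch of the persistent TMA path with a `b` operand that triggers the transpose handling. This assumes a Blackwell GPU and a Triton build with TMA and warp-specialization support; the import path is this file's module.

```python
import torch

from tritonbench.operators.gemm.warp_spec_persistent_matmul import (
    blackwell_matmul_tma_persistent,
)

a = torch.randn(1024, 4096, dtype=torch.bfloat16, device="cuda")
b = torch.randn(4096, 2048, dtype=torch.bfloat16, device="cuda")  # (K, N): incompatible as-is

c = blackwell_matmul_tma_persistent(a, b, warp_specialize=True)   # warns once, transposes b
print(c.shape)  # torch.Size([1024, 2048])
```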