Add small block shapes to warpspec matmul configs #299

Open · wants to merge 2 commits into base: main
101 changes: 86 additions & 15 deletions tritonbench/operators/gemm/warp_spec_persistent_matmul.py
@@ -3,6 +3,9 @@
on blackwell with/without warpspec.
"""

import functools
import logging
import os
from typing import Optional

import torch
@@ -13,6 +16,27 @@
# TODO: Add proton support


def torch_dtype_to_triton_dtype(dtype):
if dtype == torch.float16:
return tl.float16
elif dtype == torch.float32:
return tl.float32
elif dtype == torch.float8_e4m3fn:
return tl.float8e4nv
elif dtype == torch.bfloat16:
return tl.bfloat16
else:
raise ValueError(f"Unsupported dtype: {dtype}")


def check_tma_alignment(strides, elem_bytes):
for stride in strides[:-1]:
if (stride * elem_bytes) % 16 != 0:
raise RuntimeError("strides must be 16-byte aligned")
if strides[-1] != 1:
raise RuntimeError("Last dimension must be contiguous")


def _matmul_launch_metadata(grid, kernel, args):
ret = {}
M, N, K, WS = args["M"], args["N"], args["K"], args.get("WARP_SPECIALIZE", False)
@@ -21,12 +45,30 @@ def _matmul_launch_metadata(grid, kernel, args):
if "c_ptr" in args:
bytes_per_elem = args["c_ptr"].element_size()
else:
bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
# ceil division to capture the correct number of bytes
bytes_per_elem = (args["DTYPE"].int_bitwidth + 7) // 8
ret[f"flops{bytes_per_elem * 8}"] = 2.0 * M * N * K
ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N)
return ret
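
A worked example of the metadata arithmetic above, with assumed (not measured) sizes: for M = N = K = 1024 and a bf16 output, the element size is 2 bytes, so the FLOP count lands under the key "flops16" and the byte estimate is 2 * (M*K + N*K + M*N):

    # Assumed sizes, for illustration only.
    M = N = K = 1024
    bytes_per_elem = 2                                       # torch.bfloat16
    flops_key = f"flops{bytes_per_elem * 8}"                 # "flops16"
    flops = 2.0 * M * N * K                                  # ~2.15e9 FLOPs
    bytes_moved = bytes_per_elem * (M * K + N * K + M * N)   # 6,291,456 bytes (6 MiB)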


small_block_range = [32, 64, 128]
small_stage_range = [1, 2, 3, 4]
include_small_configs = os.environ.get("INCLUDE_SMALL_CONFIGS", "0") == "1"
if include_small_configs:
bm_range = small_block_range
bn_range = small_block_range + [256]
bk_range = small_block_range
default_s_range = small_stage_range
tma_persistent_s_range = small_stage_range
else:
bm_range = [128]
bn_range = [128, 256]
bk_range = [64, 128]
default_s_range = [3, 4]
tma_persistent_s_range = [2, 3, 4]


def matmul_get_configs(pre_hook=None):
return [
triton.Config(
@@ -40,10 +82,10 @@ def matmul_get_configs(pre_hook=None):
num_warps=w,
pre_hook=pre_hook,
)
for BM in [128]
for BN in [128, 256]
for BK in [64, 128]
for s in ([3, 4])
for BM in bm_range
for BN in bn_range
for BK in bk_range
for s in default_s_range
for w in [4, 8]
]
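
As a rough illustration of what the widened ranges buy (the counts are derived from the ranges defined above; the invocation line is a hypothetical example, not a command from this PR):

    # Autotuning search space for matmul_get_configs:
    #   default:                  1 (BM) * 2 (BN) * 2 (BK) * 2 (s) * 2 (w) = 16 configs
    #   INCLUDE_SMALL_CONFIGS=1:  3 (BM) * 4 (BN) * 3 (BK) * 4 (s) * 2 (w) = 288 configs
    # The flag is read once at import time, so it must be set before the module
    # is imported, e.g. (hypothetical invocation):
    #   INCLUDE_SMALL_CONFIGS=1 python run.py --op gemm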

@@ -77,10 +119,10 @@ def matmul_kernel_tma(
BLOCK_SIZE_N: tl.constexpr, #
BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
FP8_OUTPUT: tl.constexpr, #
WARP_SPECIALIZE: tl.constexpr, #
DTYPE: tl.constexpr,
):
dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
dtype = DTYPE

pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
@@ -112,8 +154,24 @@ def matmul_kernel_tma(
c_desc.store([offs_cm, offs_cn], c)


@functools.lru_cache
def warn_once(msg: str):
    """
    Wrapper around logging.warning that emits each distinct message only once,
    using functools.lru_cache to suppress repeats when the enclosing function
    is called repeatedly.
    """
    logging.warning(msg)
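
A small usage sketch (not part of the diff) of how the lru_cache decorator deduplicates warnings: repeated calls with the same message hit the cache and skip logging.warning entirely, while a new message logs once.

    warn_once("B is transposed")   # logs once
    warn_once("B is transposed")   # cached call, no second log line
    warn_once("another message")   # different argument, logs once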


def blackwell_matmul_tma(a, b, warp_specialize: bool):
# Check constraints.
if a.shape[1] != b.shape[1]:
warn_once(
"Incompatible dimensions, B is transposed. We are transposing B which may impact results"
)
b = b.T.contiguous()
assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed
assert a.dtype == b.dtype, "Incompatible dtypes"

@@ -141,8 +199,8 @@ def grid(META):
M,
N,
K, #
FP8_OUTPUT=dtype == torch.float8_e4m3fn, #
WARP_SPECIALIZE=warp_specialize, #
DTYPE=torch_dtype_to_triton_dtype(dtype), #
)
return c

@@ -171,10 +229,10 @@ def matmul_tma_persistent_get_configs(pre_hook=None):
num_warps=w,
pre_hook=pre_hook,
) #
for BM in [128] #
for BN in [128, 256] #
for BK in [64, 128] #
for s in ([2, 3, 4]) #
for BM in bm_range #
for BN in bn_range #
for BK in bk_range #
for s in tma_persistent_s_range #
for w in [4, 8] #
for SUBTILE in [True, False] #
]
@@ -196,12 +254,12 @@ def matmul_kernel_tma_persistent(
BLOCK_SIZE_N: tl.constexpr, #
BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
FP8_OUTPUT: tl.constexpr, #
EPILOGUE_SUBTILE: tl.constexpr, #
NUM_SMS: tl.constexpr, #
WARP_SPECIALIZE: tl.constexpr, #
DTYPE: tl.constexpr,
):
dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
dtype = DTYPE
start_pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
@@ -256,9 +314,17 @@ def matmul_kernel_tma_persistent(

def blackwell_matmul_tma_persistent(a, b, warp_specialize: bool):
# Check constraints.
if a.shape[1] != b.shape[1]:
warn_once(
"Incompatible dimensions, B is transposed. We are transposing B which may impact results"
)
b = b.T.contiguous()
assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed
assert a.dtype == b.dtype, "Incompatible dtypes"

check_tma_alignment(a.stride(), (torch.finfo(a.dtype).bits + 7) // 8)
check_tma_alignment(b.stride(), (torch.finfo(b.dtype).bits + 7) // 8)

M, K = a.shape
N, K = b.shape
dtype = a.dtype
@@ -291,9 +357,9 @@ def grid(META):
M,
N,
K, #
FP8_OUTPUT=dtype == torch.float8_e4m3fn, #
NUM_SMS=NUM_SMS, #
WARP_SPECIALIZE=warp_specialize, #
DTYPE=torch_dtype_to_triton_dtype(dtype), #
)
return c

@@ -403,6 +469,11 @@ def matmul_kernel_descriptor_persistent(

def blackwell_matmul_descriptor_persistent(a, b, warp_specialize: bool):
# Check constraints.
if a.shape[1] != b.shape[1]:
warn_once(
"Incompatible dimensions, B is transposed. We are transposing B which may impact results"
)
b = b.T.contiguous()
assert a.shape[1] == b.shape[1], "Incompatible dimensions" # b is transposed
assert a.dtype == b.dtype, "Incompatible dtypes"
