Add int4 > bf16 PTX asm support #224

81 changes: 66 additions & 15 deletions tritonbench/operators/int4_gemm/kernel.py
@@ -74,6 +74,51 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
return out_uint8, scales_and_zeros


@triton.jit
def _int4_to_bf16_fast(
packed_vals, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr
):
# adapted from
# https://github.com/NVIDIA/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/numeric_conversion.h#L6486
cast_lower, cast_upper = tl.inline_asm_elementwise(
asm="""
{
.reg .s32 src_shifted;
.reg .b32 bias;

mov.b32 bias, 0x43084308;
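// 0x4308 is the bf16 bit pattern of 136.0, replicated into both 16-bit lanes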

shr.s32 src_shifted, $4, 4;
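// arithmetic shift by 4 moves each byte's high nibble down into the low-nibble position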

// interleaved ordering:
prmt.b32 $0, $4, src_shifted, 0xF1F0;
prmt.b32 $1, $4, src_shifted, 0xF3F2;
prmt.b32 $2, $4, src_shifted, 0xF5F4;
prmt.b32 $3, $4, src_shifted, 0xF7F6;

lop3.b32 $0, $0, 0x000F000F, bias, 0x6a;
lop3.b32 $1, $1, 0x000F000F, bias, 0x6a;
lop3.b32 $2, $2, 0x000F000F, bias, 0x6a;
lop3.b32 $3, $3, 0x000F000F, bias, 0x6a;
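// immLut 0x6a computes (a & b) ^ c: keep each lane's low nibble, flip its
// sign bit, and merge in the bias, giving the bf16 bit pattern of 136 + v
// for a signed int4 value v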

sub.bf16x2 $0, $0, bias;
sub.bf16x2 $1, $1, bias;
sub.bf16x2 $2, $2, bias;
sub.bf16x2 $3, $3, bias;
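// subtracting the 136.0 bias recovers the signed int4 value exactly in bf16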
}
""",
constraints=("=r,=r,=r,=r," "r"),
args=[packed_vals],
dtype=(tl.bfloat16, tl.bfloat16),
is_pure=True,
pack=4,
)
vals = tl.join(cast_lower, cast_upper)
vals = tl.reshape(vals, (BLOCK_SIZE_N, BLOCK_SIZE_K))
return vals


@triton.autotune(configs=AUTOTUNE_CONFIGS, key=["M", "N", "K"])
@triton.jit
def matmul_kernel(
@@ -105,6 +150,7 @@ def matmul_kernel(
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
FAST_UPCAST_ASM: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
@@ -137,7 +183,7 @@ def matmul_kernel(
offs_ak = tl.arange(0, BLOCK_SIZE_K)
offs_bk = tl.arange(0, BLOCK_SIZE_K // 2)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
b_ptrs = b_ptr + (offs_bn[:, None] * stride_bn + offs_bk[None, :] * stride_bk)

# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
@@ -150,21 +196,23 @@
b = tl.load(b_ptrs)
tl.static_assert(b.dtype == tl.int8)

# Unpack `b` into an fp16 matrix, taking care to sign-extend b_lo. Use
# _4_i8 because the literal "4" is considered an i32, which causes the
# shift operands to be widened to i32.
_4_i8 = tl.full((1,), 4, dtype=tl.int8)
b_lo = (b << _4_i8) >> _4_i8
b_hi = b >> _4_i8
# Workaround: Convert before the join() so that Triton can load the data
# after the join using ldmatrix.
b_f16 = (
tl.join(b_lo.to(tl.bfloat16), b_hi.to(tl.bfloat16))
.permute(0, 2, 1)
.reshape(BLOCK_SIZE_K, BLOCK_SIZE_N)
)
if FAST_UPCAST_ASM:
# Perform the unpack and upcast using PTX asm
b_f16 = _int4_to_bf16_fast(b, BLOCK_SIZE_N, BLOCK_SIZE_K)
else:
# Unpack `b` into a bf16 matrix, taking care to sign-extend b_lo. Use
# _4_i8 because the literal "4" is considered an i32, which causes the
# shift operands to be widened to i32.
_4_i8 = tl.full((1,), 4, dtype=tl.int8)
b_lo = (b << _4_i8) >> _4_i8
b_hi = b >> _4_i8
# Workaround: Convert before the join() so that Triton can load the data
# after the join using ldmatrix.
b_f16 = tl.join(b_lo.to(tl.bfloat16), b_hi.to(tl.bfloat16)).reshape(
BLOCK_SIZE_N, BLOCK_SIZE_K
)

accumulator += tl.dot(a, b_f16)
accumulator += tl.dot(a, b_f16.T)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk // 2

@@ -185,6 +233,8 @@ def matmul(a, b):
M, K = a.shape
_, N = b.shape

fast_upcast_asm = b.is_cuda and b.stride(0) == 1

c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
grid = lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
@@ -202,6 +252,7 @@ def matmul(a, b):
b.stride(1),
c.stride(0),
c.stride(1),
FAST_UPCAST_ASM=fast_upcast_asm,
)
return c
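
As a side note for reviewers: the mask/XOR/subtract trick in _int4_to_bf16_fast can be checked outside the kernel. The sketch below is illustrative only and not part of this diff (the helper names are made up); it emulates the per-nibble conversion in plain Python and compares it against ordinary sign extension for all sixteen int4 bit patterns.

import struct

def bf16_bits_to_float(bits):
    # a bf16 bit pattern is the upper 16 bits of an fp32 bit pattern
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

def upcast_nibble(nibble):
    # (x & 0xF) ^ 0x4308 yields the bf16 bits of 136 + v, where v is the
    # signed int4 value stored in the nibble; subtracting 136.0 recovers v
    biased = (nibble & 0xF) ^ 0x4308
    return bf16_bits_to_float(biased) - bf16_bits_to_float(0x4308)

for nibble in range(16):
    v = nibble - 16 if nibble >= 8 else nibble  # reference: sign extension
    assert upcast_nibble(nibble) == float(v)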
