diff --git a/tritonbench/operators/int4_gemm/kernel.py b/tritonbench/operators/int4_gemm/kernel.py
index 13d20fae..09b3b42b 100644
--- a/tritonbench/operators/int4_gemm/kernel.py
+++ b/tritonbench/operators/int4_gemm/kernel.py
@@ -74,6 +74,51 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
     return out_uint8, scales_and_zeros
 
 
+@triton.jit
+def _int4_to_bf16_fast(
+    packed_vals, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr
+):
+    # adapted from
+    # https://github.com/NVIDIA/cutlass/blob/...
+    # ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/numeric_conversion.h#L6486
+    cast_lower, cast_upper = tl.inline_asm_elementwise(
+        asm="""
+        {
+            .reg .s32 src_shifted;
+            .reg .b32 bias;
+
+            mov.b32 bias, 0x43084308;
+
+            shr.s32 src_shifted, $4, 4;
+
+            // interleaved ordering:
+            prmt.b32 $0, $4, src_shifted, 0xF1F0;
+            prmt.b32 $1, $4, src_shifted, 0xF3F2;
+            prmt.b32 $2, $4, src_shifted, 0xF5F4;
+            prmt.b32 $3, $4, src_shifted, 0xF7F6;
+
+            lop3.b32 $0, $0, 0x000F000F, bias, 0x6a;
+            lop3.b32 $1, $1, 0x000F000F, bias, 0x6a;
+            lop3.b32 $2, $2, 0x000F000F, bias, 0x6a;
+            lop3.b32 $3, $3, 0x000F000F, bias, 0x6a;
+
+            sub.bf16x2 $0, $0, bias;
+            sub.bf16x2 $1, $1, bias;
+            sub.bf16x2 $2, $2, bias;
+            sub.bf16x2 $3, $3, bias;
+        }
+        """,
+        constraints=("=r,=r,=r,=r," "r"),
+        args=[packed_vals],
+        dtype=(tl.bfloat16, tl.bfloat16),
+        is_pure=True,
+        pack=4,
+    )
+    vals = tl.join(cast_lower, cast_upper)
+    vals = tl.reshape(vals, (BLOCK_SIZE_N, BLOCK_SIZE_K))
+    return vals
+
+
 @triton.autotune(configs=AUTOTUNE_CONFIGS, key=["M", "N", "K"])
 @triton.jit
 def matmul_kernel(
@@ -105,6 +150,7 @@ def matmul_kernel(
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
+    FAST_UPCAST_ASM: tl.constexpr,
 ):
     """Kernel for computing the matmul C = A x B.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
@@ -137,7 +183,7 @@ def matmul_kernel(
     offs_ak = tl.arange(0, BLOCK_SIZE_K)
     offs_bk = tl.arange(0, BLOCK_SIZE_K // 2)
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak)
-    b_ptrs = b_ptr + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    b_ptrs = b_ptr + (offs_bn[:, None] * stride_bn + offs_bk[None, :] * stride_bk)
 
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -150,21 +196,23 @@ def matmul_kernel(
         b = tl.load(b_ptrs)
         tl.static_assert(b.dtype == tl.int8)
 
-        # Unpack `b` into an fp16 matrix, taking care to sign-extend b_lo. Use
-        # _4_i8 because the literal "4" is considered an i32, which causes the
-        # shift operands to be widened to i32.
-        _4_i8 = tl.full((1,), 4, dtype=tl.int8)
-        b_lo = (b << _4_i8) >> _4_i8
-        b_hi = b >> _4_i8
-        # Workaround: Convert before the join() so that Triton can load the data
-        # after the join using ldmatrix.
-        b_f16 = (
-            tl.join(b_lo.to(tl.bfloat16), b_hi.to(tl.bfloat16))
-            .permute(0, 2, 1)
-            .reshape(BLOCK_SIZE_K, BLOCK_SIZE_N)
-        )
+        if FAST_UPCAST_ASM:
+            # Perform the unpack and upcast using PTX asm
+            b_f16 = _int4_to_bf16_fast(b, BLOCK_SIZE_N, BLOCK_SIZE_K)
+        else:
+            # Unpack `b` into an fp16 matrix, taking care to sign-extend b_lo. Use
+            # _4_i8 because the literal "4" is considered an i32, which causes the
+            # shift operands to be widened to i32.
+            _4_i8 = tl.full((1,), 4, dtype=tl.int8)
+            b_lo = (b << _4_i8) >> _4_i8
+            b_hi = b >> _4_i8
+            # Workaround: Convert before the join() so that Triton can load the data
+            # after the join using ldmatrix.
+            b_f16 = tl.join(b_lo.to(tl.bfloat16), b_hi.to(tl.bfloat16)).reshape(
+                BLOCK_SIZE_N, BLOCK_SIZE_K
+            )
 
-        accumulator += tl.dot(a, b_f16)
+        accumulator += tl.dot(a, b_f16.T)
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk // 2
 
@@ -185,6 +233,8 @@ def matmul(a, b):
     M, K = a.shape
     _, N = b.shape
 
+    fast_upcast_asm = b.is_cuda and b.stride(0) == 1
+
     c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
     grid = lambda META: (
         triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
@@ -202,6 +252,7 @@ def matmul(a, b):
         b.stride(1),
         c.stride(0),
         c.stride(1),
+        FAST_UPCAST_ASM=fast_upcast_asm,
     )
     return c
 
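
Note (not part of the patch): the PTX in _int4_to_bf16_fast uses the exponent-bias trick for nibble-to-bf16 conversion. The lop3 with immLut 0x6a computes (x & 0x000F000F) ^ bias, splicing each 4-bit payload into the mantissa of the bf16 constant 0x4308 (136.0, packed into both halves of 0x43084308 so one instruction covers two lanes), and sub.bf16x2 then subtracts that same constant, which strips the bias and completes the sign extension. Below is a minimal host-side sketch that checks this identity with PyTorch; the variable names and the loop are illustrative and do not appear in the kernel.

import torch

# Illustrative check of the bias trick used by _int4_to_bf16_fast:
# per 16-bit lane the kernel computes (x & 0x000F) ^ 0x4308 and then
# subtracts the bf16 value with bit pattern 0x4308 (= 136.0).
BIAS_BITS = 0x4308
bias = torch.tensor([BIAS_BITS], dtype=torch.int16).view(torch.bfloat16)

for nibble in range(16):
    # XOR splices the 4-bit payload into the bias's low byte. Since the
    # bias's low nibble is 0x8, the payload's bit 3 is inverted here, and
    # subtracting 136.0 (= 128 + 8) removes both the exponent offset and
    # that flipped bit, yielding the sign-extended int4 value.
    patched_bits = (nibble & 0xF) ^ BIAS_BITS
    patched = torch.tensor([patched_bits], dtype=torch.int16).view(torch.bfloat16)
    upcast = (patched - bias).item()
    expected = nibble - 16 if nibble >= 8 else nibble  # sign-extended int4
    assert upcast == expected, (nibble, upcast, expected)

print("int4 -> bf16 bias trick holds for all 16 nibble values")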