
Commit 3b10b81

Add int4 > bf16 PTX asm support
1 parent a130026 commit 3b10b81

File tree

1 file changed: +66 -15 lines

tritonbench/operators/int4_gemm/kernel.py

Lines changed: 66 additions & 15 deletions
@@ -74,6 +74,51 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
     return out_uint8, scales_and_zeros


+@triton.jit
+def _int4_to_bf16_fast(
+    packed_vals, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr
+):
+    # adapted from
+    # https://github.com/NVIDIA/cutlass/blob/...
+    # ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/numeric_conversion.h#L6486
+    cast_lower, cast_upper = tl.inline_asm_elementwise(
+        asm="""
+        {
+            .reg .s32 src_shifted;
+            .reg .b32 bias;
+
+            mov.b32 bias, 0x43084308;
+
+            shr.s32 src_shifted, $4, 4;
+
+            // interleaved ordering:
+            prmt.b32 $0, $4, src_shifted, 0xF1F0;
+            prmt.b32 $1, $4, src_shifted, 0xF3F2;
+            prmt.b32 $2, $4, src_shifted, 0xF5F4;
+            prmt.b32 $3, $4, src_shifted, 0xF7F6;
+
+            lop3.b32 $0, $0, 0x000F000F, bias, 0x6a;
+            lop3.b32 $1, $1, 0x000F000F, bias, 0x6a;
+            lop3.b32 $2, $2, 0x000F000F, bias, 0x6a;
+            lop3.b32 $3, $3, 0x000F000F, bias, 0x6a;
+
+            sub.bf16x2 $0, $0, bias;
+            sub.bf16x2 $1, $1, bias;
+            sub.bf16x2 $2, $2, bias;
+            sub.bf16x2 $3, $3, bias;
+        }
+        """,
+        constraints=("=r,=r,=r,=r," "r"),
+        args=[packed_vals],
+        dtype=(tl.bfloat16, tl.bfloat16),
+        is_pure=True,
+        pack=4,
+    )
+    vals = tl.join(cast_lower, cast_upper)
+    vals = tl.reshape(vals, (BLOCK_SIZE_N, BLOCK_SIZE_K))
+    return vals
+
+
 @triton.autotune(configs=AUTOTUNE_CONFIGS, key=["M", "N", "K"])
 @triton.jit
 def matmul_kernel(
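A note on the asm above: it relies on a magic-bias trick. Each 4-bit value is isolated with the 0x000F000F mask and merged (lop3 with immLut 0x6a computes (a & b) ^ c) into the low mantissa bits of the bf16 constant 0x4308, which encodes 136.0 = 128 + 8. The XOR with the low 8 bit maps a two's-complement nibble n in [-8, 7] to n + 8, so the resulting bits decode to 136 + n, and the final sub.bf16x2 recovers n exactly. A standalone sketch of just that step (the helper name below is illustrative, not part of the commit):

import torch

def upcast_nibble_reference(n: int) -> float:
    # Two's-complement nibble of n, with n in [-8, 7].
    nibble = n & 0xF
    # (value & 0x000F) ^ 0x4308 -- same effect as the lop3 with immLut 0x6a above.
    bits = 0x4308 ^ nibble
    # Reinterpret the 16 bits as bf16, then subtract the 136.0 bias (sub.bf16x2).
    val = torch.tensor([bits], dtype=torch.int16).view(torch.bfloat16)
    return (val - 136.0).item()

assert all(upcast_nibble_reference(n) == float(n) for n in range(-8, 8))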
@@ -105,6 +150,7 @@ def matmul_kernel(
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
+    FAST_UPCAST_ASM: tl.constexpr,
 ):
     """Kernel for computing the matmul C = A x B.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
@@ -137,7 +183,7 @@
     offs_ak = tl.arange(0, BLOCK_SIZE_K)
     offs_bk = tl.arange(0, BLOCK_SIZE_K // 2)
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak)
-    b_ptrs = b_ptr + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    b_ptrs = b_ptr + (offs_bn[:, None] * stride_bn + offs_bk[None, :] * stride_bk)

     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -150,21 +196,23 @@
         b = tl.load(b_ptrs)
         tl.static_assert(b.dtype == tl.int8)

-        # Unpack `b` into an fp16 matrix, taking care to sign-extend b_lo. Use
-        # _4_i8 because the literal "4" is considered an i32, which causes the
-        # shift operands to be widened to i32.
-        _4_i8 = tl.full((1,), 4, dtype=tl.int8)
-        b_lo = (b << _4_i8) >> _4_i8
-        b_hi = b >> _4_i8
-        # Workaround: Convert before the join() so that Triton can load the data
-        # after the join using ldmatrix.
-        b_f16 = (
-            tl.join(b_lo.to(tl.bfloat16), b_hi.to(tl.bfloat16))
-            .permute(0, 2, 1)
-            .reshape(BLOCK_SIZE_K, BLOCK_SIZE_N)
-        )
+        if FAST_UPCAST_ASM:
+            # Perform the unpack and upcast using PTX asm
+            b_f16 = _int4_to_bf16_fast(b, BLOCK_SIZE_N, BLOCK_SIZE_K)
+        else:
+            # Unpack `b` into an fp16 matrix, taking care to sign-extend b_lo. Use
+            # _4_i8 because the literal "4" is considered an i32, which causes the
+            # shift operands to be widened to i32.
+            _4_i8 = tl.full((1,), 4, dtype=tl.int8)
+            b_lo = (b << _4_i8) >> _4_i8
+            b_hi = b >> _4_i8
+            # Workaround: Convert before the join() so that Triton can load the data
+            # after the join using ldmatrix.
+            b_f16 = tl.join(b_lo.to(tl.bfloat16), b_hi.to(tl.bfloat16)).reshape(
+                BLOCK_SIZE_N, BLOCK_SIZE_K
+            )

-        accumulator += tl.dot(a, b_f16)
+        accumulator += tl.dot(a, b_f16.T)
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk // 2
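Note that, together with the b_ptrs change above, the B tile is now loaded as (BLOCK_SIZE_N, BLOCK_SIZE_K // 2), unpacked into a (BLOCK_SIZE_N, BLOCK_SIZE_K) tile, and transposed inside tl.dot. As far as I can tell this yields the same lo/hi interleaving along K as the old permute-based path; a small standalone torch sketch of that layout equivalence (illustrative only, not from the commit):

import torch

# Old path: stack lo/hi on a (K//2, N) tile, permute, reshape to (K, N).
# New path: stack on the transposed (N, K//2) tile, reshape to (N, K), transpose.
K, N = 8, 4
lo = torch.arange(K // 2 * N).reshape(K // 2, N)  # stand-in for b_lo on the old tile
hi = lo + 100                                     # stand-in for b_hi

old = torch.stack((lo, hi), dim=-1).permute(0, 2, 1).reshape(K, N)
new = torch.stack((lo.T, hi.T), dim=-1).reshape(N, K).T

assert torch.equal(old, new)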

@@ -185,6 +233,8 @@ def matmul(a, b):
     M, K = a.shape
     _, N = b.shape

+    fast_upcast_asm = b.is_cuda and b.stride(0) == 1
+
     c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
     grid = lambda META: (
         triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
@@ -202,6 +252,7 @@ def matmul(a, b):
         b.stride(1),
         c.stride(0),
         c.stride(1),
+        FAST_UPCAST_ASM=fast_upcast_asm,
     )
     return c
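The dispatch heuristic above only enables the asm path when the packed weight tensor is on CUDA and its leading dimension is unit-stride. For illustration, a layout that satisfies the check (shapes and dtype here are assumptions, not taken from the commit):

import torch

K, N = 64, 32
# A (K // 2, N) packed-int4 weight whose dim 0 is contiguous, e.g. allocated N-major
# and then transposed; matmul(a, b) would then pass FAST_UPCAST_ASM=True.
b = torch.zeros((N, K // 2), dtype=torch.int8, device="cuda").T
assert b.is_cuda and b.stride(0) == 1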
