
Commit 8c8bcea

Add int4 PTX ASM support
1 parent a130026 commit 8c8bcea

1 file changed (+69 −15 lines)

tritonbench/operators/int4_gemm/kernel.py

Lines changed: 69 additions & 15 deletions
@@ -74,6 +74,53 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
     return out_uint8, scales_and_zeros
 
 
+@triton.jit
+def _int4_to_bf16_fast(packed_vals,
+                       BLOCK_SIZE_N: tl.constexpr,
+                       BLOCK_SIZE_K: tl.constexpr):
+    # adapted from
+    # https://github.com/NVIDIA/cutlass/blob/...
+    # ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/numeric_conversion.h#L6486
+    cast_lower, cast_upper = tl.inline_asm_elementwise(
+        asm="""
+        {
+            .reg .s32 src_shifted;
+            .reg .b32 bias;
+
+            mov.b32 bias, 0x43084308;
+
+            shr.s32 src_shifted, $4, 4;
+
+            // interleaved ordering:
+            prmt.b32 $0, $4, src_shifted, 0xF1F0;
+            prmt.b32 $1, $4, src_shifted, 0xF3F2;
+            prmt.b32 $2, $4, src_shifted, 0xF5F4;
+            prmt.b32 $3, $4, src_shifted, 0xF7F6;
+
+            lop3.b32 $0, $0, 0x000F000F, bias, 0x6a;
+            lop3.b32 $1, $1, 0x000F000F, bias, 0x6a;
+            lop3.b32 $2, $2, 0x000F000F, bias, 0x6a;
+            lop3.b32 $3, $3, 0x000F000F, bias, 0x6a;
+
+            sub.bf16x2 $0, $0, bias;
+            sub.bf16x2 $1, $1, bias;
+            sub.bf16x2 $2, $2, bias;
+            sub.bf16x2 $3, $3, bias;
+        }
+        """,
+        constraints=(
+            "=r,=r,=r,=r,"
+            "r"),
+        args=[packed_vals],
+        dtype=(tl.bfloat16, tl.bfloat16),
+        is_pure=True,
+        pack=4,
+    )
+    vals = tl.join(cast_lower, cast_upper)
+    vals = tl.reshape(vals, (BLOCK_SIZE_N, BLOCK_SIZE_K))
+    return vals
+
+
 @triton.autotune(configs=AUTOTUNE_CONFIGS, key=["M", "N", "K"])
 @triton.jit
 def matmul_kernel(
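
The PTX above leans on the CUTLASS int4-to-bf16 bias trick: each two's-complement nibble is masked into the low mantissa bits of the bf16 pattern 0x4308 (which encodes 136.0), the XOR built into the lop3 (immediate 0x6a computes (a & b) ^ c) flips the nibble's sign bit, and the final sub.bf16x2 subtracts the bias to leave the signed value. A minimal per-nibble sketch of that arithmetic on the host, assuming PyTorch is available (the helper name below is illustrative, not part of the kernel):

import torch

def upcast_nibble_reference(v: int) -> float:
    # Per 16-bit lane, lop3 ... 0x6a computes (x & 0x000F) ^ 0x4308: the nibble
    # lands in the low mantissa bits and its sign bit is flipped by the XOR.
    bits = (v & 0xF) ^ 0x4308
    as_bf16 = torch.tensor([bits], dtype=torch.int16).view(torch.bfloat16)
    # sub.bf16x2 then removes the 0x4308 bias, i.e. 136.0 in bf16.
    return (as_bf16 - 136.0).item()

# Every int4 value round-trips exactly: 0..7 stay put, 8..15 map to -8..-1.
for v in range(16):
    assert upcast_nibble_reference(v) == float(v - 16 if v >= 8 else v)
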
@@ -105,6 +152,7 @@ def matmul_kernel(
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
+    FAST_UPCAST_ASM: tl.constexpr,
 ):
     """Kernel for computing the matmul C = A x B.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
@@ -137,7 +185,7 @@ def matmul_kernel(
     offs_ak = tl.arange(0, BLOCK_SIZE_K)
     offs_bk = tl.arange(0, BLOCK_SIZE_K // 2)
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak)
-    b_ptrs = b_ptr + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    b_ptrs = b_ptr + (offs_bn[:, None] * stride_bn + offs_bk[None, :] * stride_bk)
 
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -150,21 +198,24 @@ def matmul_kernel(
         b = tl.load(b_ptrs)
         tl.static_assert(b.dtype == tl.int8)
 
-        # Unpack `b` into an fp16 matrix, taking care to sign-extend b_lo. Use
-        # _4_i8 because the literal "4" is considered an i32, which causes the
-        # shift operands to be widened to i32.
-        _4_i8 = tl.full((1,), 4, dtype=tl.int8)
-        b_lo = (b << _4_i8) >> _4_i8
-        b_hi = b >> _4_i8
-        # Workaround: Convert before the join() so that Triton can load the data
-        # after the join using ldmatrix.
-        b_f16 = (
-            tl.join(b_lo.to(tl.bfloat16), b_hi.to(tl.bfloat16))
-            .permute(0, 2, 1)
-            .reshape(BLOCK_SIZE_K, BLOCK_SIZE_N)
-        )
+        if FAST_UPCAST_ASM:
+            # Perform the unpack and upcast using PTX asm
+            b_f16 = _int4_to_bf16_fast(b, BLOCK_SIZE_N, BLOCK_SIZE_K)
+        else:
+            # Unpack `b` into an fp16 matrix, taking care to sign-extend b_lo. Use
+            # _4_i8 because the literal "4" is considered an i32, which causes the
+            # shift operands to be widened to i32.
+            _4_i8 = tl.full((1,), 4, dtype=tl.int8)
+            b_lo = (b << _4_i8) >> _4_i8
+            b_hi = b >> _4_i8
+            # Workaround: Convert before the join() so that Triton can load the data
+            # after the join using ldmatrix.
+            b_f16 = (
+                tl.join(b_lo.to(tl.bfloat16), b_hi.to(tl.bfloat16))
+                .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K)
+            )
 
-        accumulator += tl.dot(a, b_f16)
+        accumulator += tl.dot(a, b_f16.T)
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk // 2
 
@@ -185,6 +236,8 @@ def matmul(a, b):
     M, K = a.shape
     _, N = b.shape
 
+    fast_upcast_asm = (b.is_cuda and b.stride(0) == 1)
+
     c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
     grid = lambda META: (
         triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
@@ -202,6 +255,7 @@ def matmul(a, b):
         b.stride(1),
         c.stride(0),
         c.stride(1),
+        FAST_UPCAST_ASM=fast_upcast_asm,
     )
     return c
 
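For spot-checking either upcast path from the host, the unpacking done inside the k-loop can be mirrored in plain PyTorch. A sketch, assuming packed is an int8 tensor of shape (N, K // 2) with the low nibble ordered first, matching the interleaving both branches produce (the function name is hypothetical):

import torch

def unpack_int4_to_bf16_reference(packed: torch.Tensor) -> torch.Tensor:
    # Mirrors the non-asm branch: shift left then arithmetic-shift right to
    # sign-extend the low nibble; arithmetic shift right for the high nibble.
    lo = (packed << 4) >> 4
    hi = packed >> 4
    # Interleave lo/hi along K, like tl.join(...).reshape(BLOCK_SIZE_N, BLOCK_SIZE_K).
    return torch.stack((lo, hi), dim=-1).reshape(packed.shape[0], -1).to(torch.bfloat16)

# One packed byte 0x2F holds the nibbles 0xF (low, i.e. -1) and 0x2 (high, i.e. 2).
packed = torch.tensor([[0x2F]], dtype=torch.int8)
assert unpack_int4_to_bf16_reference(packed).tolist() == [[-1.0, 2.0]]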