@@ -74,6 +74,53 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
74
74
return out_uint8 , scales_and_zeros
75
75
76
76
77
@triton.jit
def _int4_to_bf16_fast(packed_vals,
                       BLOCK_SIZE_N: tl.constexpr,
                       BLOCK_SIZE_K: tl.constexpr):
    # Unpack int4 values (packed several to a machine word in `packed_vals`)
    # into a (BLOCK_SIZE_N, BLOCK_SIZE_K) bf16 tile using inline PTX.
    #
    # Adapted from
    # https://github.com/NVIDIA/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/numeric_conversion.h#L6486
    #
    # NOTE(review): requires an NVIDIA target (PTX `prmt`/`lop3`/`sub.bf16x2`);
    # callers should gate use of this path on CUDA availability — confirm at the
    # call site.
    cast_lower, cast_upper = tl.inline_asm_elementwise(
        asm="""
        {
            .reg .s32 src_shifted;
            .reg .b32 bias;

            mov.b32 bias, 0x43084308;

            // Arithmetic shift keeps the sign of the high nibbles available.
            shr.s32 src_shifted, $4, 4;

            // interleaved ordering: byte-permute low/high nibble pairs into
            // the four output registers.
            prmt.b32 $0, $4, src_shifted, 0xF1F0;
            prmt.b32 $1, $4, src_shifted, 0xF3F2;
            prmt.b32 $2, $4, src_shifted, 0xF5F4;
            prmt.b32 $3, $4, src_shifted, 0xF7F6;

            // Mask each 4-bit payload and OR in the exponent bias pattern
            // (lop3 immediate 0x6a encodes (a & b) | c).
            lop3.b32 $0, $0, 0x000F000F, bias, 0x6a;
            lop3.b32 $1, $1, 0x000F000F, bias, 0x6a;
            lop3.b32 $2, $2, 0x000F000F, bias, 0x6a;
            lop3.b32 $3, $3, 0x000F000F, bias, 0x6a;

            // Subtract the bias to recover the signed integer values as bf16.
            sub.bf16x2 $0, $0, bias;
            sub.bf16x2 $1, $1, bias;
            sub.bf16x2 $2, $2, bias;
            sub.bf16x2 $3, $3, bias;
        }
        """,
        constraints=(
            # Four 32-bit outputs ($0-$3), one 32-bit input ($4).
            "=r,=r,=r,=r,"
            "r"),
        args=[packed_vals],
        dtype=(tl.bfloat16, tl.bfloat16),
        is_pure=True,
        pack=4,
    )
    # Join the two bf16 element streams and reshape to the tile layout the
    # matmul kernel consumes.
    vals = tl.join(cast_lower, cast_upper)
    vals = tl.reshape(vals, (BLOCK_SIZE_N, BLOCK_SIZE_K))
    return vals
122
+
123
+
77
124
@triton .autotune (configs = AUTOTUNE_CONFIGS , key = ["M" , "N" , "K" ])
78
125
@triton .jit
79
126
def matmul_kernel (
@@ -105,6 +152,7 @@ def matmul_kernel(
105
152
BLOCK_SIZE_N : tl .constexpr ,
106
153
BLOCK_SIZE_K : tl .constexpr ,
107
154
GROUP_SIZE_M : tl .constexpr ,
155
+ FAST_UPCAST_ASM : tl .constexpr ,
108
156
):
109
157
"""Kernel for computing the matmul C = A x B.
110
158
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
@@ -137,7 +185,7 @@ def matmul_kernel(
137
185
offs_ak = tl .arange (0 , BLOCK_SIZE_K )
138
186
offs_bk = tl .arange (0 , BLOCK_SIZE_K // 2 )
139
187
a_ptrs = a_ptr + (offs_am [:, None ] * stride_am + offs_ak [None , :] * stride_ak )
140
- b_ptrs = b_ptr + (offs_bk [:, None ] * stride_bk + offs_bn [None , :] * stride_bn )
188
+ b_ptrs = b_ptr + (offs_bn [:, None ] * stride_bn + offs_bk [None , :] * stride_bk )
141
189
142
190
# -----------------------------------------------------------
143
191
# Iterate to compute a block of the C matrix.
@@ -150,21 +198,24 @@ def matmul_kernel(
150
198
b = tl .load (b_ptrs )
151
199
tl .static_assert (b .dtype == tl .int8 )
152
200
153
- # Unpack `b` into an fp16 matrix, taking care to sign-extend b_lo. Use
154
- # _4_i8 because the literal "4" is considered an i32, which causes the
155
- # shift operands to be widened to i32.
156
- _4_i8 = tl .full ((1 ,), 4 , dtype = tl .int8 )
157
- b_lo = (b << _4_i8 ) >> _4_i8
158
- b_hi = b >> _4_i8
159
- # Workaround: Convert before the join() so that Triton can load the data
160
- # after the join using ldmatrix.
161
- b_f16 = (
162
- tl .join (b_lo .to (tl .bfloat16 ), b_hi .to (tl .bfloat16 ))
163
- .permute (0 , 2 , 1 )
164
- .reshape (BLOCK_SIZE_K , BLOCK_SIZE_N )
165
- )
201
+ if FAST_UPCAST_ASM :
202
+ # Perform the unpack and upcast using PTX asm
203
+ b_f16 = _int4_to_bf16_fast (b , BLOCK_SIZE_N , BLOCK_SIZE_K )
204
+ else :
205
+ # Unpack `b` into an fp16 matrix, taking care to sign-extend b_lo. Use
206
+ # _4_i8 because the literal "4" is considered an i32, which causes the
207
+ # shift operands to be widened to i32.
208
+ _4_i8 = tl .full ((1 ,), 4 , dtype = tl .int8 )
209
+ b_lo = (b << _4_i8 ) >> _4_i8
210
+ b_hi = b >> _4_i8
211
+ # Workaround: Convert before the join() so that Triton can load the data
212
+ # after the join using ldmatrix.
213
+ b_f16 = (
214
+ tl .join (b_lo .to (tl .bfloat16 ), b_hi .to (tl .bfloat16 ))
215
+ .reshape (BLOCK_SIZE_N , BLOCK_SIZE_K )
216
+ )
166
217
167
- accumulator += tl .dot (a , b_f16 )
218
+ accumulator += tl .dot (a , b_f16 . T )
168
219
a_ptrs += BLOCK_SIZE_K * stride_ak
169
220
b_ptrs += BLOCK_SIZE_K * stride_bk // 2
170
221
@@ -185,6 +236,8 @@ def matmul(a, b):
185
236
M , K = a .shape
186
237
_ , N = b .shape
187
238
239
+ fast_upcast_asm = (b .is_cuda and b .stride (0 ) == 1 )
240
+
188
241
c = torch .empty ((M , N ), device = a .device , dtype = torch .bfloat16 )
189
242
grid = lambda META : (
190
243
triton .cdiv (M , META ["BLOCK_SIZE_M" ]) * triton .cdiv (N , META ["BLOCK_SIZE_N" ]),
@@ -202,6 +255,7 @@ def matmul(a, b):
202
255
b .stride (1 ),
203
256
c .stride (0 ),
204
257
c .stride (1 ),
258
+ FAST_UPCAST_ASM = fast_upcast_asm ,
205
259
)
206
260
return c
207
261
0 commit comments