@@ -74,6 +74,51 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
     return out_uint8, scales_and_zeros


+@triton.jit
+def _int4_to_bf16_fast(
+    packed_vals, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr
+):
+    # adapted from
+    # https://github.com/NVIDIA/cutlass/blob/...
+    # ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/numeric_conversion.h#L6486
+    cast_lower, cast_upper = tl.inline_asm_elementwise(
+        asm="""
+        {
+            .reg .s32 src_shifted;
+            .reg .b32 bias;
+
+            mov.b32 bias, 0x43084308;
+
+            shr.s32 src_shifted, $4, 4;
+
+            // interleaved ordering:
+            prmt.b32 $0, $4, src_shifted, 0xF1F0;
+            prmt.b32 $1, $4, src_shifted, 0xF3F2;
+            prmt.b32 $2, $4, src_shifted, 0xF5F4;
+            prmt.b32 $3, $4, src_shifted, 0xF7F6;
+
+            lop3.b32 $0, $0, 0x000F000F, bias, 0x6a;
+            lop3.b32 $1, $1, 0x000F000F, bias, 0x6a;
+            lop3.b32 $2, $2, 0x000F000F, bias, 0x6a;
+            lop3.b32 $3, $3, 0x000F000F, bias, 0x6a;
+
+            sub.bf16x2 $0, $0, bias;
+            sub.bf16x2 $1, $1, bias;
+            sub.bf16x2 $2, $2, bias;
+            sub.bf16x2 $3, $3, bias;
+        }
+        """,
+        constraints=("=r,=r,=r,=r," "r"),
+        args=[packed_vals],
+        dtype=(tl.bfloat16, tl.bfloat16),
+        is_pure=True,
+        pack=4,
+    )
+    vals = tl.join(cast_lower, cast_upper)
+    vals = tl.reshape(vals, (BLOCK_SIZE_N, BLOCK_SIZE_K))
+    return vals
+
+
 @triton.autotune(configs=AUTOTUNE_CONFIGS, key=["M", "N", "K"])
 @triton.jit
 def matmul_kernel(
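A note on the magic constants in this hunk: `0x4308` is the bf16 encoding of 136.0, whose mantissa has only bit 3 set, and the `lop3` immediate `0x6a` selects the boolean function `(a & b) ^ c`. Masking each 16-bit lane down to its 4-bit nibble `n` and XOR-ing it into the bias therefore yields the bf16 value `128 + (n ^ 8)`, so the final `sub.bf16x2` leaves exactly `(n ^ 8) - 8`, which is the signed int4 value, with no integer-to-float conversion instructions at all. A minimal CPU sketch checking that identity for all 16 nibble patterns (my own illustration, not code from this commit):

```python
import numpy as np

def bf16_bits_to_float(bits):
    # bf16 is the upper 16 bits of an IEEE float32
    return np.array([bits << 16], dtype=np.uint32).view(np.float32)[0]

BIAS = 0x4308  # bf16 encoding of 136.0 (exponent 2**7, mantissa bit 3 set)

for n in range(16):                        # every unsigned nibble pattern
    signed_int4 = n - 16 if n >= 8 else n  # two's-complement int4 value
    upcast = bf16_bits_to_float(BIAS ^ n) - bf16_bits_to_float(BIAS)
    assert upcast == signed_int4, (n, upcast, signed_int4)
print("XOR-into-bias then subtract recovers all 16 signed int4 values")
```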
@@ -105,6 +150,7 @@ def matmul_kernel(
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
+    FAST_UPCAST_ASM: tl.constexpr,
 ):
     """Kernel for computing the matmul C = A x B.
     A has shape (M, K), B has shape (K, N) and C has shape (M, N)
@@ -137,7 +183,7 @@ def matmul_kernel(
     offs_ak = tl.arange(0, BLOCK_SIZE_K)
     offs_bk = tl.arange(0, BLOCK_SIZE_K // 2)
     a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak)
-    b_ptrs = b_ptr + (offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    b_ptrs = b_ptr + (offs_bn[:, None] * stride_bn + offs_bk[None, :] * stride_bk)

     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -150,21 +196,23 @@ def matmul_kernel(
         b = tl.load(b_ptrs)
         tl.static_assert(b.dtype == tl.int8)

-        # Unpack `b` into an fp16 matrix, taking care to sign-extend b_lo. Use
-        # _4_i8 because the literal "4" is considered an i32, which causes the
-        # shift operands to be widened to i32.
-        _4_i8 = tl.full((1,), 4, dtype=tl.int8)
-        b_lo = (b << _4_i8) >> _4_i8
-        b_hi = b >> _4_i8
-        # Workaround: Convert before the join() so that Triton can load the data
-        # after the join using ldmatrix.
-        b_f16 = (
-            tl.join(b_lo.to(tl.bfloat16), b_hi.to(tl.bfloat16))
-            .permute(0, 2, 1)
-            .reshape(BLOCK_SIZE_K, BLOCK_SIZE_N)
-        )
+        if FAST_UPCAST_ASM:
+            # Perform the unpack and upcast using PTX asm
+            b_f16 = _int4_to_bf16_fast(b, BLOCK_SIZE_N, BLOCK_SIZE_K)
+        else:
+            # Unpack `b` into an fp16 matrix, taking care to sign-extend b_lo. Use
+            # _4_i8 because the literal "4" is considered an i32, which causes the
+            # shift operands to be widened to i32.
+            _4_i8 = tl.full((1,), 4, dtype=tl.int8)
+            b_lo = (b << _4_i8) >> _4_i8
+            b_hi = b >> _4_i8
+            # Workaround: Convert before the join() so that Triton can load the data
+            # after the join using ldmatrix.
+            b_f16 = tl.join(b_lo.to(tl.bfloat16), b_hi.to(tl.bfloat16)).reshape(
+                BLOCK_SIZE_N, BLOCK_SIZE_K
+            )

-        accumulator += tl.dot(a, b_f16)
+        accumulator += tl.dot(a, b_f16.T)
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk // 2

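The fallback branch is also reworked in this hunk: the tile is now built in (BLOCK_SIZE_N, BLOCK_SIZE_K) orientation and transposed at the `tl.dot`, which removes the old `permute(0, 2, 1)` workaround, since after the `join` the low/high nibbles are already adjacent along K. A standalone PyTorch sketch of what the shift-based branch computes for two packed bytes (my illustration; the values are arbitrary, and eager PyTorch preserves the int8 dtype under shifts, so the `_4_i8` widening workaround from the kernel isn't needed here):

```python
import torch

# Each int8 byte packs two signed int4 values, low nibble first along K.
packed = torch.tensor([0x21, -8], dtype=torch.int8)  # nibbles (1, 2) and (8, 15)

lo = (packed << 4) >> 4  # shift up, then arithmetic shift down: sign-extends the low nibble
hi = packed >> 4         # arithmetic shift sign-extends the high nibble directly

# stack + reshape interleaves lo/hi along K, like tl.join(...).reshape(...):
unpacked = torch.stack([lo, hi], dim=-1).reshape(-1).to(torch.bfloat16)
print(unpacked)  # tensor([ 1.,  2., -8., -1.], dtype=torch.bfloat16)
```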
@@ -185,6 +233,8 @@ def matmul(a, b):
     M, K = a.shape
     _, N = b.shape

+    fast_upcast_asm = b.is_cuda and b.stride(0) == 1
+
     c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
     grid = lambda META: (
         triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
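The guard reads as: the asm path is only taken when `b` lives on CUDA and its packed-K dimension is the fastest-moving one (`stride(0) == 1`), so that, presumably, the four packed bytes each asm invocation consumes (`pack=4`) are contiguous in memory. One way a caller can produce that layout, sketched under my own assumptions rather than taken from this commit:

```python
import torch

K, N = 512, 256  # hypothetical sizes; each int8 byte holds two int4 values
b_nk = torch.randint(-128, 128, (N, K // 2), dtype=torch.int8, device="cuda")

# The transpose of a contiguous (N, K // 2) buffer has shape (K // 2, N)
# and strides (1, K // 2), so matmul() will pass FAST_UPCAST_ASM=True.
b = b_nk.t()
assert b.is_cuda and b.stride(0) == 1
```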
@@ -202,6 +252,7 @@ def matmul(a, b):
         b.stride(1),
         c.stride(0),
         c.stride(1),
+        FAST_UPCAST_ASM=fast_upcast_asm,
     )
     return c

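Since both branches must agree, one quick consistency check is to feed the same packed weight through both code paths by changing only its memory layout. A sketch, assuming `a` is a bf16 CUDA tensor as the kernel's accumulator suggests (this test is mine, not part of the commit):

```python
import torch

M, K, N = 64, 512, 256
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b_slow = torch.randint(-128, 128, (K // 2, N), dtype=torch.int8, device="cuda")
b_fast = b_slow.t().contiguous().t()  # same values, but stride(0) == 1

# b_slow takes the shift-based branch, b_fast the inline-asm branch.
torch.testing.assert_close(matmul(a, b_slow), matmul(a, b_fast))
```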