|
| 1 | +import torch |
| 2 | + |
| 3 | +from triton import Config, autotune, cdiv, heuristics, jit |
| 4 | +from triton import language as tl |
| 5 | +from .matmul_perf_model import early_config_prune, estimate_matmul_time |
| 6 | + |
| 7 | +_ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32] |
| 8 | + |
| 9 | + |
| 10 | +def upcast_if_fp8(a): |
| 11 | + if "fp8" in str(a): |
| 12 | + return torch.float16 |
| 13 | + return a |
| 14 | + |
| 15 | + |
| 16 | +def get_higher_dtype(a, b): |
| 17 | + a = upcast_if_fp8(a) |
| 18 | + b = upcast_if_fp8(b) |
| 19 | + if a is b: |
| 20 | + return a |
| 21 | + |
| 22 | + assert a in _ordered_datatypes |
| 23 | + assert b in _ordered_datatypes |
| 24 | + |
| 25 | + for d in _ordered_datatypes: |
| 26 | + if a is d: |
| 27 | + return b |
| 28 | + if b is d: |
| 29 | + return a |
| 30 | + |
| 31 | + |
| 32 | +def init_to_zero(name): |
| 33 | + return lambda nargs: nargs[name].zero_() |
| 34 | + |
| 35 | + |
| 36 | +def get_configs_io_bound(): |
| 37 | + configs = [] |
| 38 | + for num_stages in [2, 3, 4, 5, 6]: |
| 39 | + for block_m in [16, 32]: |
| 40 | + for block_k in [32, 64]: |
| 41 | + for block_n in [32, 64, 128, 256]: |
| 42 | + num_warps = 2 if block_n <= 64 else 4 |
| 43 | + configs.append( |
| 44 | + Config( |
| 45 | + { |
| 46 | + "BLOCK_M": block_m, |
| 47 | + "BLOCK_N": block_n, |
| 48 | + "BLOCK_K": block_k, |
| 49 | + "SPLIT_K": 1, |
| 50 | + }, |
| 51 | + num_stages=num_stages, |
| 52 | + num_warps=num_warps, |
| 53 | + ) |
| 54 | + ) |
| 55 | + # split_k |
| 56 | + for split_k in [2, 4, 8, 16]: |
| 57 | + configs.append( |
| 58 | + Config( |
| 59 | + { |
| 60 | + "BLOCK_M": block_m, |
| 61 | + "BLOCK_N": block_n, |
| 62 | + "BLOCK_K": block_k, |
| 63 | + "SPLIT_K": split_k, |
| 64 | + }, |
| 65 | + num_stages=num_stages, |
| 66 | + num_warps=num_warps, |
| 67 | + pre_hook=init_to_zero("C"), |
| 68 | + ) |
| 69 | + ) |
| 70 | + return configs |
| 71 | + |
| 72 | + |
| 73 | +@autotune( |
| 74 | + configs=[ |
| 75 | + # basic configs for compute-bound matmuls |
| 76 | + Config( |
| 77 | + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, |
| 78 | + num_stages=3, |
| 79 | + num_warps=8, |
| 80 | + ), |
| 81 | + Config( |
| 82 | + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, |
| 83 | + num_stages=3, |
| 84 | + num_warps=8, |
| 85 | + ), |
| 86 | + Config( |
| 87 | + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, |
| 88 | + num_stages=4, |
| 89 | + num_warps=4, |
| 90 | + ), |
| 91 | + Config( |
| 92 | + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, |
| 93 | + num_stages=4, |
| 94 | + num_warps=4, |
| 95 | + ), |
| 96 | + Config( |
| 97 | + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, |
| 98 | + num_stages=4, |
| 99 | + num_warps=4, |
| 100 | + ), |
| 101 | + Config( |
| 102 | + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, |
| 103 | + num_stages=4, |
| 104 | + num_warps=4, |
| 105 | + ), |
| 106 | + Config( |
| 107 | + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, |
| 108 | + num_stages=4, |
| 109 | + num_warps=4, |
| 110 | + ), |
| 111 | + Config( |
| 112 | + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, |
| 113 | + num_stages=4, |
| 114 | + num_warps=4, |
| 115 | + ), |
| 116 | + Config( |
| 117 | + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, |
| 118 | + num_stages=5, |
| 119 | + num_warps=2, |
| 120 | + ), |
| 121 | + # good for int8 |
| 122 | + Config( |
| 123 | + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, |
| 124 | + num_stages=3, |
| 125 | + num_warps=8, |
| 126 | + ), |
| 127 | + Config( |
| 128 | + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, |
| 129 | + num_stages=3, |
| 130 | + num_warps=8, |
| 131 | + ), |
| 132 | + Config( |
| 133 | + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, |
| 134 | + num_stages=4, |
| 135 | + num_warps=4, |
| 136 | + ), |
| 137 | + Config( |
| 138 | + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, |
| 139 | + num_stages=4, |
| 140 | + num_warps=4, |
| 141 | + ), |
| 142 | + Config( |
| 143 | + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, |
| 144 | + num_stages=4, |
| 145 | + num_warps=4, |
| 146 | + ), |
| 147 | + Config( |
| 148 | + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, |
| 149 | + num_stages=4, |
| 150 | + num_warps=4, |
| 151 | + ), |
| 152 | + Config( |
| 153 | + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, |
| 154 | + num_stages=4, |
| 155 | + num_warps=4, |
| 156 | + ), |
| 157 | + Config( |
| 158 | + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, |
| 159 | + num_stages=4, |
| 160 | + num_warps=4, |
| 161 | + ), |
| 162 | + Config( |
| 163 | + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, |
| 164 | + num_stages=5, |
| 165 | + num_warps=2, |
| 166 | + ), |
| 167 | + ] |
| 168 | + + get_configs_io_bound(), |
| 169 | + key=["M", "N", "K"], |
| 170 | + prune_configs_by={ |
| 171 | + "early_config_prune": early_config_prune, |
| 172 | + "perf_model": estimate_matmul_time, |
| 173 | + "top_k": 10, |
| 174 | + }, |
| 175 | +) |
| 176 | +@heuristics( |
| 177 | + { |
| 178 | + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, |
| 179 | + } |
| 180 | +) |
| 181 | +@jit |
| 182 | +def _kernel( |
| 183 | + A, |
| 184 | + B, |
| 185 | + C, |
| 186 | + M, |
| 187 | + N, |
| 188 | + K, # |
| 189 | + stride_am, |
| 190 | + stride_ak, # |
| 191 | + stride_bk, |
| 192 | + stride_bn, # |
| 193 | + stride_cm, |
| 194 | + stride_cn, # |
| 195 | + acc_dtype: tl.constexpr, # |
| 196 | + input_precision: tl.constexpr, # |
| 197 | + fp8_fast_accum: tl.constexpr, # |
| 198 | + BLOCK_M: tl.constexpr, |
| 199 | + BLOCK_N: tl.constexpr, |
| 200 | + BLOCK_K: tl.constexpr, # |
| 201 | + GROUP_M: tl.constexpr, |
| 202 | + SPLIT_K: tl.constexpr, |
| 203 | + EVEN_K: tl.constexpr, |
| 204 | + AB_DTYPE: tl.constexpr, # |
| 205 | +): |
| 206 | + # matrix multiplication |
| 207 | + pid = tl.program_id(0) |
| 208 | + pid_z = tl.program_id(1) |
| 209 | + grid_m = tl.cdiv(M, BLOCK_M) |
| 210 | + grid_n = tl.cdiv(N, BLOCK_N) |
| 211 | + # re-order program ID for better L2 performance |
| 212 | + width = GROUP_M * grid_n |
| 213 | + group_id = pid // width |
| 214 | + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) |
| 215 | + pid_m = group_id * GROUP_M + (pid % group_size) |
| 216 | + pid_n = (pid % width) // (group_size) |
| 217 | + # do matrix multiplication |
| 218 | + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) |
| 219 | + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) |
| 220 | + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) |
| 221 | + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) |
| 222 | + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) |
| 223 | + # pointers |
| 224 | + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) |
| 225 | + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) |
| 226 | + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) |
| 227 | + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): |
| 228 | + if EVEN_K: |
| 229 | + a = tl.load(A) |
| 230 | + b = tl.load(B) |
| 231 | + else: |
| 232 | + k_remaining = K - k * (BLOCK_K * SPLIT_K) |
| 233 | + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) |
| 234 | + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) |
| 235 | + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) |
| 236 | + if AB_DTYPE is not None: |
| 237 | + a = a.to(AB_DTYPE) |
| 238 | + b = b.to(AB_DTYPE) |
| 239 | + if fp8_fast_accum: |
| 240 | + acc = tl.dot( |
| 241 | + a, b, acc, out_dtype=acc_dtype, input_precision=input_precision |
| 242 | + ) |
| 243 | + else: |
| 244 | + acc += tl.dot(a, b, out_dtype=acc_dtype, input_precision=input_precision) |
| 245 | + A += BLOCK_K * SPLIT_K * stride_ak |
| 246 | + B += BLOCK_K * SPLIT_K * stride_bk |
| 247 | + acc = acc.to(C.dtype.element_ty) |
| 248 | + # rematerialize rm and rn to save registers |
| 249 | + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) |
| 250 | + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) |
| 251 | + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) |
| 252 | + mask = (rm < M)[:, None] & (rn < N)[None, :] |
| 253 | + # handles write-back with reduction-splitting |
| 254 | + if SPLIT_K == 1: |
| 255 | + tl.store(C, acc, mask=mask) |
| 256 | + else: |
| 257 | + tl.atomic_add(C, acc, mask=mask) |
| 258 | + |
| 259 | + |
| 260 | +class _matmul(torch.autograd.Function): |
| 261 | + kernel = _kernel |
| 262 | + |
| 263 | + _locks = {} |
| 264 | + |
| 265 | + @staticmethod |
| 266 | + def _call(a, b, acc_dtype, input_precision, fp8_fast_accum, output_dtype): |
| 267 | + device = a.device |
| 268 | + # handle non-contiguous inputs if necessary |
| 269 | + if a.stride(0) > 1 and a.stride(1) > 1: |
| 270 | + a = a.contiguous() |
| 271 | + if b.stride(0) > 1 and b.stride(1) > 1: |
| 272 | + b = b.contiguous() |
| 273 | + # checks constraints |
| 274 | + assert ( |
| 275 | + a.shape[1] == b.shape[0] |
| 276 | + ), f"incompatible dimensions {a.shape} and {b.shape}" |
| 277 | + M, K = a.shape |
| 278 | + _, N = b.shape |
| 279 | + |
| 280 | + # common type between a and b |
| 281 | + ab_dtype = get_higher_dtype(a.dtype, b.dtype) |
| 282 | + |
| 283 | + # allocates output |
| 284 | + if output_dtype is None: |
| 285 | + output_dtype = ab_dtype |
| 286 | + |
| 287 | + c = torch.empty((M, N), device=device, dtype=output_dtype) |
| 288 | + |
| 289 | + # Allowed types for acc_type given the types of a and b. |
| 290 | + supported_acc_dtypes = { |
| 291 | + torch.float16: (torch.float32, torch.float16), |
| 292 | + torch.bfloat16: (torch.float32, torch.bfloat16), |
| 293 | + torch.float32: (torch.float32,), |
| 294 | + torch.int8: (torch.int32,), |
| 295 | + } |
| 296 | + |
| 297 | + if acc_dtype is None: |
| 298 | + acc_dtype = supported_acc_dtypes[ab_dtype][0] |
| 299 | + else: |
| 300 | + assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" |
| 301 | + assert ( |
| 302 | + acc_dtype in supported_acc_dtypes[a.dtype] |
| 303 | + ), "acc_dtype not compatible with the type of a" |
| 304 | + assert ( |
| 305 | + acc_dtype in supported_acc_dtypes[b.dtype] |
| 306 | + ), "acc_dtype not compatible with the type of b" |
| 307 | + |
| 308 | + def to_tl_type(ty): |
| 309 | + return getattr(tl, str(ty).split(".")[-1]) |
| 310 | + |
| 311 | + acc_dtype = to_tl_type(acc_dtype) |
| 312 | + ab_dtype = to_tl_type(ab_dtype) |
| 313 | + output_dtype = to_tl_type(output_dtype) |
| 314 | + |
| 315 | + # Tensor cores support input with mixed float8 types. |
| 316 | + if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [ |
| 317 | + tl.float8e4nv, |
| 318 | + tl.float8e5, |
| 319 | + ]: |
| 320 | + ab_dtype = None |
| 321 | + # launch kernel |
| 322 | + grid = lambda META: ( |
| 323 | + cdiv(M, META["BLOCK_M"]) * cdiv(N, META["BLOCK_N"]), |
| 324 | + META["SPLIT_K"], |
| 325 | + ) |
| 326 | + _kernel[grid]( |
| 327 | + a, |
| 328 | + b, |
| 329 | + c, |
| 330 | + M, |
| 331 | + N, |
| 332 | + K, # |
| 333 | + a.stride(0), |
| 334 | + a.stride(1), # |
| 335 | + b.stride(0), |
| 336 | + b.stride(1), # |
| 337 | + c.stride(0), |
| 338 | + c.stride(1), # |
| 339 | + acc_dtype=acc_dtype, # |
| 340 | + input_precision=input_precision, # |
| 341 | + fp8_fast_accum=fp8_fast_accum, # |
| 342 | + GROUP_M=8, |
| 343 | + AB_DTYPE=ab_dtype, |
| 344 | + ) |
| 345 | + return c |
| 346 | + |
| 347 | + @staticmethod |
| 348 | + def forward( |
| 349 | + ctx, |
| 350 | + a, |
| 351 | + b, |
| 352 | + acc_dtype=None, |
| 353 | + input_precision=None, |
| 354 | + fp8_fast_accum=True, |
| 355 | + output_dtype=None, |
| 356 | + ): |
| 357 | + return _matmul._call( |
| 358 | + a, |
| 359 | + b, |
| 360 | + acc_dtype=acc_dtype, |
| 361 | + input_precision=input_precision, |
| 362 | + fp8_fast_accum=fp8_fast_accum, |
| 363 | + output_dtype=output_dtype, |
| 364 | + ) |
| 365 | + |
| 366 | + |
| 367 | +matmul = _matmul.apply |
0 commit comments