
Commit 6bdbc78

jwfromm authored and facebook-github-bot committed
New DeepGemm Style Groupwise Kernel (#4365)
Summary:
X-link: facebookresearch/FBGEMM#1433
Pull Request resolved: #4365

Initial enablement of CUTLASS' new groupwise scaling API for FP8 GEMM. This diff adds all the needed scaffolding, and we confirm that the kernel runs and produces correct outputs, but I do not yet include the tuning that would yield better performance.

Interestingly, CUTLASS wants group/block scales in MN-major format, while every other groupwise implementation I've seen uses K-major. I add an option to our triton blockwise quantization kernels to support this layout. In benchmarking those quantization kernels, I see that triton blockwise quantization in general (with or without K-major output) is quite slow. We may need to iterate on that if this becomes a commonly used kernel.

One other interesting consideration is that we may actually see performance benefits from using smaller tiles along N. Right now, we are forced to use tiles that are at least 128 along N. Going down to scale blocks of size [64, 128] would let us use tiles of size 64, opening up more kernel configuration options.

Reviewed By: jiawenliu64

Differential Revision: D76830629

fbshipit-source-id: cbaf199f54b0b627ff63eba6b8af90d94d448863
1 parent 474f66c
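For readers unfamiliar with the two layouts discussed above, here is a minimal sketch of how K-major block scales relate to the MN-major layout that CUTLASS expects. The sizes are illustrative only and the snippet is not part of this diff; it simply mirrors the transpose that the PyTorch fallback path in the diff below performs.

import torch

# Illustrative sizes: a [512, 1024] activation quantized with [128, 128] scale blocks.
M, K, BLOCK = 512, 1024, 128

# K-major scales: one entry per block, indexed as [m_block, k_block].
scales_k_major = torch.rand(M // BLOCK, K // BLOCK)  # shape (4, 8)

# MN-major scales: the same values with the block axes swapped, indexed as [k_block, m_block].
# This mirrors the `x_scale.t().contiguous()` fallback in the PyTorch path of the diff below.
scales_mn_major = scales_k_major.t().contiguous()  # shape (8, 4)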

File tree

9 files changed, +564 -7 lines changed

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

Lines changed: 25 additions & 7 deletions
@@ -3018,6 +3018,7 @@ def _kernel_quantize_fp8_block(
     CLAMP_MAX: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_K: tl.constexpr,
+    K_MAJOR: tl.constexpr,
 ) -> None:
     """Quantize and scale each [BLOCK_M, BLOCK_K] block.
 
@@ -3047,6 +3048,7 @@ def _kernel_quantize_fp8_block(
         CLAMP_MAX (bool): Whether to apply scale_ub.
         BLOCK_M (int): Block size for M dimension of A_scale and kernel.
         BLOCK_K (int): Block size for K dimension of A_scale and kernel.
+        K_MAJOR (bool): Whether output scales should be K major (True) or MN major (False).
     """
     pid = tl.program_id(0)
     grid_k = tl.cdiv(K, BLOCK_K)
@@ -3068,9 +3070,12 @@ def _kernel_quantize_fp8_block(
         block_max = tl.maximum(block_max, EPS)
     scale = MAX_FP8 / block_max
 
-    tl.store(
-        A_scale + block_m * stride_a_scale_m + block_k * stride_a_scale_k, 1.0 / scale
-    )
+    # Write in transposed order if specified.
+    if K_MAJOR:
+        scale_offset = block_m * stride_a_scale_m + block_k * stride_a_scale_k
+    else:
+        scale_offset = block_k * stride_a_scale_m + block_m * stride_a_scale_k
+    tl.store(A_scale + scale_offset, 1.0 / scale)
     a_fp8 = a_block * scale
     # Clamp A to fp8 range to make sure there's no overflow.
     # This is required for AMD. Nvidia's default saturation
@@ -3085,6 +3090,7 @@ def triton_quantize_fp8_block(
     block_m: int = 256,
     block_k: int = 256,
     scale_ub: Optional[torch.Tensor] = None,
+    K_major: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize a tensor to fp8 with block-wise scalings.
@@ -3096,10 +3102,12 @@ def triton_quantize_fp8_block(
         block_m (int): Block size for M dimension of scale.
         block_k (int): Block size for K dimension of scale.
         scale_ub: Maximum allowed value for scale.
+        K_major (bool): Whether output scales should be K major (True) or MN major (False).
 
     Returns:
         torch.Tensor : [M, K] fp8 scaled tensor.
-        torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block.
+        torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block
+            if K_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_M)].
     """
     assert x.device != torch.device(
         "cpu"
@@ -3111,7 +3119,10 @@ def triton_quantize_fp8_block(
     M, K = x.shape
     grid_m = triton.cdiv(M, block_m)
     grid_k = triton.cdiv(K, block_k)
-    x_scale = torch.empty((grid_m, grid_k), device=x.device, dtype=torch.float32)
+    if K_major:
+        x_scale = torch.empty((grid_m, grid_k), device=x.device, dtype=torch.float32)
+    else:
+        x_scale = torch.ones((grid_k, grid_m), device=x.device, dtype=torch.float32)
     x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)
 
     _kernel_quantize_fp8_block[(grid_m * grid_k,)](
@@ -3139,6 +3150,8 @@ def triton_quantize_fp8_block(
         BLOCK_M=block_m,
         # pyre-ignore[6]: Incompatible parameter type [6]
         BLOCK_K=block_k,
+        # pyre-ignore[6]: Incompatible parameter type [6]
+        K_MAJOR=K_major,
     )
 
     return x_fp8.view(x_shape), x_scale
@@ -3151,6 +3164,7 @@ def quantize_fp8_block(
     scale_ub: Optional[torch.Tensor] = None,
     use_triton: bool = True,
     output_device: Optional[torch.device] = None,
+    K_major: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize a tensor to fp8 with block-wise scalings and optionally move to output device.
@@ -3164,18 +3178,20 @@ def quantize_fp8_block(
         scale_ub: Maximum allowed value for scale.
         use_triton (bool): Whether to use triton kernel or pytorch.
         output_device (torch.device): Device to optionally move the scaled tensors to.
+        K_major (bool): Whether output scales should be K major (True) or MN major (False).
 
     Returns:
         torch.Tensor: [M, K] fp8 scaled tensor.
-        torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block.
+        torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block
+            if K_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_M)].
     """
     x_shape = x.shape
     x = x.view(-1, x.size(-1))
     if x.device == torch.device("cpu"):
         logger.info("Triton does not support cpu, falling back to torch ops.")
         use_triton = False
     if use_triton:
-        xq, x_scale = triton_quantize_fp8_block(x, block_m, block_k, scale_ub)
+        xq, x_scale = triton_quantize_fp8_block(x, block_m, block_k, scale_ub, K_major)
         return xq.view(x_shape), x_scale
     # else use pytorch implementation.
     if not output_device:
@@ -3219,6 +3235,8 @@ def quantize_fp8_block(
     x_fp8 = x_fp8.to(device=output_device, dtype=pt_dtype)
     x_scale = x_scale.to(output_device)  # pyre-ignore
     del x, x_padded
+    if not K_major:
+        x_scale = x_scale.t().contiguous()
     return x_fp8.view(x_shape), 1 / x_scale  # pyre-ignore
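To illustrate the new flag, a short usage sketch (illustrative tensor sizes, not part of the diff; the expected scale shapes follow the docstrings above):

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_block

x = torch.randn(512, 1024, device="cuda")

# Default K-major scales: [cdiv(M, block_m), cdiv(K, block_k)] = [4, 8].
xq, x_scale = quantize_fp8_block(x, block_m=128, block_k=128)
assert x_scale.shape == (4, 8)

# MN-major scales for the CUTLASS groupwise kernel: [cdiv(K, block_k), cdiv(M, block_m)] = [8, 4].
xq, x_scale = quantize_fp8_block(x, block_m=128, block_k=128, K_major=False)
assert x_scale.shape == (8, 4)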

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 42 additions & 0 deletions
@@ -1490,6 +1490,48 @@ def cuda(self) -> bool:
         return True
 
 
+@register_quantize_op
+class FP8CutlassGroupwiseGemm(QuantizeOpBase):
+    """
+    FP8 matmul with group / block scaling.
+    """
+
+    def preprocess(self, x, w):
+        # Quantize weights.
+        # Scale is expected to be in [K, N] layout (N Major).
+        wq, w_scale = quantize_fp8_block(w, block_m=128, block_k=128, K_major=False)
+        # Return processed tensors.
+        return x, wq, w_scale
+
+    def quantize(self, x, wq, w_scale):
+        # Scale is expected to be in [K, M] layout (M Major).
+        xq, x_scale = quantize_fp8_block(x, block_m=1, block_k=128, K_major=False)
+        # Pretranspose scales to deepgemm format.
+        return xq, wq, x_scale, w_scale
+
+    def compute(self, xq, wq, x_scale, w_scale):
+        return torch.ops.fbgemm.f8f8bf16_groupwise(xq, wq, x_scale, w_scale)
+
+    def quantize_and_compute(self, x, wq, w_scale):
+        xq, wq, x_scale, w_scale = self.quantize(x, wq, w_scale)
+        return self.compute(xq, wq, x_scale, w_scale)
+
+    @property
+    def name(self) -> str:
+        if torch.version.cuda:
+            return "cutlass_groupwise"
+        else:
+            return "ck_groupwise"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 ####################################################################################################
 # CUTLASS kernel v2
 ####################################################################################################
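For reference, a hedged end-to-end sketch of what this benchmark op does, assuming a CUDA build of fbgemm_gpu that registers the new torch.ops.fbgemm.f8f8bf16_groupwise op (tensor sizes are illustrative, not from the diff):

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_block

M, N, K = 1024, 2048, 4096
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# Weights: [128, 128] scale blocks, MN major, so w_scale has shape [K // 128, N // 128].
wq, w_scale = quantize_fp8_block(w, block_m=128, block_k=128, K_major=False)
# Activations: groups of 128 along K per row, MN major, so x_scale has shape [K // 128, M].
xq, x_scale = quantize_fp8_block(x, block_m=1, block_k=128, K_major=False)

# FP8 x FP8 -> BF16 groupwise GEMM; output has shape [M, N].
y = torch.ops.fbgemm.f8f8bf16_groupwise(xq, wq, x_scale, w_scale)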
Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+// clang-format on
+
+#include "f8f8bf16_groupwise/f8f8bf16_groupwise_manifest.cuh"
+#include "fbgemm_gpu/quantize/tuning_cache.hpp"
+#include "fbgemm_gpu/quantize/utils.h"
+
+namespace fbgemm_gpu {
+
+#if CUDART_VERSION >= 12000
+
+// FP8 Groupwise Cutlass kernel dispatch.
+Kernel_f8f8bf16_groupwise
+get_kernel_via_heuristic(int arch, int M, int N, int K) {
+  // Use shape heuristics to dispatch to optimized kernel configuration.
+  // Initial enablement includes only one schedule.
+  if (M <= 16) {
+    return f8f8bf16_groupwise_128_16_128_1_1_1_9_t;
+  } else {
+    return f8f8bf16_groupwise_128_128_128_1_2_1_9_f;
+  }
+}
+
+Kernel_f8f8bf16_groupwise get_kernel_via_tuning(
+    int arch,
+    int M,
+    int N,
+    int K,
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale) {
+  // One cache per kernel type
+  static TuningCache cache("f8f8bf16_groupwise");
+
+  // Reducing amount of auto tuning by rounding up M to next power of 2.
+  M = nextPowerOf2(M);
+  // Use (M, N, K) shape as the key.
+  const std::string shape_key =
+      std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(K);
+  const auto& kernels = get_f8f8bf16_groupwise_kernels(arch);
+  auto kernel = cache.findBestKernelMaybeAutotune(
+      shape_key, kernels, XQ, WQ, x_scale, w_scale);
+  return kernel;
+}
+
+// FP8 Rowwise Cutlass kernel dispatch.
+at::Tensor dispatch_fp8_groupwise_kernel(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale) {
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
+  int K = XQ.size(-1);
+
+  static int arch = -1;
+  // Avoid expensive cudaGetDeviceProperties call.
+  if (arch < 0) {
+    cudaDeviceProp prop;
+    cudaGetDeviceProperties(&prop, 0);
+    if (prop.major >= 10) {
+      arch = 10;
+      int runtimeVersion;
+      C10_CUDA_CHECK(cudaRuntimeGetVersion(&runtimeVersion));
+      TORCH_CHECK(
+          runtimeVersion >= 12080,
+          "FP8 GEMM on sm100a or above requires cuda >= 12.8");
+    } else {
+      arch = 9;
+    }
+  }
+
+  // Select kernel to run via heuristics or tuning.
+  auto kernel = [&]() {
+    if (std::getenv("FBGEMM_AUTOTUNE_ENABLE")) {
+      return get_kernel_via_tuning(arch, M, N, K, XQ, WQ, x_scale, w_scale);
+    } else {
+      return get_kernel_via_heuristic(arch, M, N, K);
+    }
+  }();
+  // Invoke kernel
+  return kernel(XQ, WQ, x_scale, w_scale);
+}
+
+at::Tensor f8f8bf16_groupwise(
+    at::Tensor XQ, // FP8
+    at::Tensor WQ, // FP8
+    at::Tensor x_scale,
+    at::Tensor w_scale) {
+  // Invoke and return rowwise kernel without output argument.
+  return dispatch_fp8_groupwise_kernel(XQ, WQ, x_scale, w_scale);
+}
+
+#else
+
+at::Tensor f8f8bf16_groupwise(
+    at::Tensor XQ, // FP8
+    at::Tensor WQ, // FP8
+    at::Tensor x_scale,
+    at::Tensor w_scale) {
+  throw std::runtime_error(
+      "CUDA version is older than 12.0"); // requires CUDA>=12
+}
+#endif
+
+} // namespace fbgemm_gpu
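The dispatch logic above consults the FBGEMM_AUTOTUNE_ENABLE environment variable on every call, so autotuning can be switched on from Python before invoking the op. A sketch under that assumption (illustrative shapes; quantization as in the benchmark op above):

import os
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_block

# Switch from the static shape heuristic to the shape-keyed tuning cache.
os.environ["FBGEMM_AUTOTUNE_ENABLE"] = "1"

M, N, K = 64, 2048, 4096
xq, x_scale = quantize_fp8_block(
    torch.randn(M, K, device="cuda", dtype=torch.bfloat16),
    block_m=1, block_k=128, K_major=False,
)
wq, w_scale = quantize_fp8_block(
    torch.randn(N, K, device="cuda", dtype=torch.bfloat16),
    block_m=128, block_k=128, K_major=False,
)

# The first call for a given (next_power_of_2(M), N, K) key benchmarks the
# registered groupwise kernels; later calls with the same key reuse the winner.
y = torch.ops.fbgemm.f8f8bf16_groupwise(xq, wq, x_scale, w_scale)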
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "f8f8bf16_groupwise_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor f8f8bf16_groupwise_128_128_128_1_2_1_9_f(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale) {
+  // Dispatch this kernel to the correct underlying implementation.
+  return f8f8bf16_groupwise_wrapper<128, 128, 128, 1, 2, 1, 9, false>(
+      XQ, WQ, x_scale, w_scale);
+}
+
+} // namespace fbgemm_gpu
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "f8f8bf16_groupwise_common.cuh"
+
+namespace fbgemm_gpu {
+
+at::Tensor f8f8bf16_groupwise_128_16_128_1_1_1_9_t(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale) {
+  // Dispatch this kernel to the correct underlying implementation.
+  return f8f8bf16_groupwise_wrapper<128, 16, 128, 1, 1, 1, 9, true>(
+      XQ, WQ, x_scale, w_scale);
+}
+
+} // namespace fbgemm_gpu
