
Commit 5938645

jwfromm authored and facebook-github-bot committed

Groupwise quantization kernel (#4439)

Summary:
Pull Request resolved: #4439
X-link: facebookresearch/FBGEMM#1503

When doing groupwise quantization, we previously used the blockwise kernel `quantize_fp8_block(w, block_m=1, block_k=128)`. However, this is quite inefficient because the blockwise kernel needs to use a 2D grid. We can be much faster using one thread per row and iterating over groups within that row. This diff introduces a bespoke groupwise quantization kernel that is dramatically faster than the flattened block approach.

Reviewed By: jiawenliu64

Differential Revision: D77689544

fbshipit-source-id: 8059e73d2794f70d9c1f995c908ca036f4cb1680
1 parent 8feae04 commit 5938645
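
For context, the call-site change (see the benchmark diff at the bottom) boils down to swapping the flattened blockwise call for the new groupwise entry point. A minimal sketch of the two paths, assuming a CUDA device and that both functions are importable from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm; the tensor shape is an arbitrary example, not from the commit:

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    quantize_fp8_block,
    quantize_fp8_group,
)

w = torch.randn(4096, 7168, dtype=torch.bfloat16, device="cuda")  # example shape

# Before: groupwise scales emulated with 1 x 128 blocks (one program per block).
wq_old, ws_old = quantize_fp8_block(w, block_m=1, block_k=128)

# After: dedicated groupwise kernel (one program per row, looping over groups).
wq_new, ws_new = quantize_fp8_group(w, group_size=128)

# Both return an fp8 tensor with w's shape and a [4096, 7168 // 128] = [4096, 56]
# float32 reciprocal scale, one entry per 128-element group of each row.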

File tree: 3 files changed (+259, -4 lines)


fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py

Lines changed: 32 additions & 0 deletions
@@ -21,6 +21,7 @@
     matmul_fp8_block,
     matmul_fp8_row,
     quantize_fp8_block,
+    quantize_fp8_group,
     # packed_row unpacks the values, packed_row_raw returns just the packed tensor
     quantize_fp8_packed_row,
     quantize_fp8_packed_row_raw,
@@ -517,6 +518,37 @@ def _quantize_matmul_fp8(
             (3, 4, 5), torch.device("cuda"), use_bias=False
         )
 
+    def test_quantize_fp8_group(self) -> None:
+        def _test_quantize_fp8_group(
+            shape: Tuple[int, int],
+            group_size: int,
+            use_scale_ub: bool = False,
+        ) -> None:
+            M, K = shape
+            a = torch.randn(M, K, dtype=torch.float, device="cuda")
+
+            scale_ub = (
+                torch.tensor([1200], dtype=torch.float, device="cuda")
+                if use_scale_ub
+                else None
+            )
+
+            a_fp8, a_scale = quantize_fp8_group(a, group_size, scale_ub=scale_ub)
+
+            a_torch = a_fp8.to(torch.float)
+
+            # Undo scaling.
+            a_torch = a_torch.view(-1, K // group_size, group_size) * a_scale.unsqueeze(
+                -1
+            )
+            a_torch = a_torch.view(M, K)
+
+            self.assertTrue(torch.allclose(a, a_torch, atol=2e-1, rtol=5e-2))
+
+        _test_quantize_fp8_group((128, 128), 128)
+        _test_quantize_fp8_group((1, 256), 64)
+        _test_quantize_fp8_group((2, 384), 128, use_scale_ub=True)
+
     def test_quantize_fp8_block(self) -> None:
         def _test_quantize_fp8_block(
             shape: Tuple[int, int],
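
As a quick shape check mirroring the (2, 384) test case above (a sketch, not part of the commit): the returned reciprocal scale has one entry per 128-wide group of each row, and dequantization follows the same view/unsqueeze pattern the test uses.

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_group

M, K, group_size = 2, 384, 128          # K // group_size = 3 groups per row
a = torch.randn(M, K, dtype=torch.float, device="cuda")

a_fp8, a_scale = quantize_fp8_group(a, group_size)
assert a_fp8.shape == (M, K)                   # fp8 payload, same shape as input
assert a_scale.shape == (M, K // group_size)   # one reciprocal scale per group

# Undo the per-group scaling to approximately recover the input.
a_dq = (
    a_fp8.to(torch.float).view(M, K // group_size, group_size)
    * a_scale.unsqueeze(-1)
).view(M, K)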

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

Lines changed: 225 additions & 1 deletion
@@ -3111,7 +3111,7 @@ def triton_quantize_fp8_block(
     M, K = x.shape
     grid_m = triton.cdiv(M, block_m)
     grid_k = triton.cdiv(K, block_k)
-    x_scale = torch.ones((grid_m, grid_k), device=x.device, dtype=torch.float32)
+    x_scale = torch.empty((grid_m, grid_k), device=x.device, dtype=torch.float32)
     x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)
 
     _kernel_quantize_fp8_block[(grid_m * grid_k,)](
@@ -3222,6 +3222,230 @@ def quantize_fp8_block(
     return x_fp8.view(x_shape), 1 / x_scale  # pyre-ignore
 
 
+@triton.autotune(
+    configs=[
+        Config({"GROUP_LOAD": 2}),
+        Config({"GROUP_LOAD": 4}),
+        Config({"GROUP_LOAD": 8}),
+        Config({"GROUP_LOAD": 16}),
+        Config({"GROUP_LOAD": 32}),
+    ],
+    key=["K"],
+)
+@triton.jit
+def _kernel_quantize_fp8_group(
+    A,
+    A_scale,
+    A_fp8,
+    scale_ub,
+    M,
+    K,
+    stride_am,
+    stride_ak,
+    stride_om,
+    stride_ok,
+    stride_a_scale_m,
+    stride_a_scale_k,
+    TL_FP8_DTYPE: tl.constexpr,
+    MAX_FP8: tl.constexpr,
+    EPS: tl.constexpr,
+    CLAMP_MAX: tl.constexpr,
+    USE_INT64: tl.constexpr,
+    GROUP_SIZE: tl.constexpr,
+    GROUP_LOAD: tl.constexpr,
+):
+    """Quantize and scale each GROUP_SIZE chunk of each row.
+
+    Scale per group i is computed as 1 / (MAX_FP8 / max(abs(A[i:i+GROUP_SIZE]))).
+
+    Each kernel thread is responsible for one row and loads and processes a tunable
+    number of groups at once.
+
+    Args:
+        A (Tensor): [M, K] higher precision input tensor.
+        A_scale (Tensor): [M, cdiv(K, GROUP_SIZE)] reciprocal scale tensor per group.
+        A_fp8 (Tensor): [M, K] fp8 scaled tensor. A_fp8 = A * a_scale.
+        scale_ub (Tensor): [1] Maximum allowed value for scale.
+        M (int): Number of rows.
+        K (int): Number of columns.
+        stride_am (int): Stride of m dimension of A.
+        stride_ak (int): Stride of k dimension of A.
+        stride_om (int): Stride of m dimension of output.
+        stride_ok (int): Stride of k dimension of output.
+        stride_a_scale_m (int): Stride of m dimension of A_scale.
+        stride_a_scale_k (int): Stride of k dimension of A_scale.
+        TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
+        MAX_FP8 (float): Maximum expressible value for FP8.
+        EPS (float): Epsilon value for numerical stability.
+        CLAMP_MAX (bool): Whether to apply scale_ub.
+        USE_INT64 (bool): Whether to index using int64, which may be needed for large tensors.
+        GROUP_SIZE (int): Group size for K dimension of A_scale and kernel.
+        GROUP_LOAD (int): Number of groups to load and process simultaneously.
+    """
+    pid = tl.program_id(0)
+    if USE_INT64:
+        pid = pid.to(tl.int64)
+    # We load GROUP_LOAD * GROUP_SIZE elements at a time.
+    row_offset = pid * stride_am
+    out_offset = pid * stride_om
+    scale_row_offset = pid * stride_a_scale_m
+    k_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE)
+    scale_k_offset = tl.arange(0, GROUP_LOAD)
+    NUM_GROUPS: tl.constexpr = K // GROUP_SIZE
+
+    for k in range(0, tl.cdiv(K, (GROUP_LOAD * GROUP_SIZE))):
+        # Load groups of the input.
+        chunk_offset = k_offset + k * GROUP_LOAD * GROUP_SIZE
+        a = tl.load(
+            A + row_offset + chunk_offset * stride_ak, mask=chunk_offset < K, other=0.0
+        )
+        # View loaded chunk as a set of groups.
+        a_grouped = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE])
+        # Reduce over groups.
+        group_max = tl.max(tl.abs(a_grouped), axis=1)
+        # Apply clamping if specified.
+        if CLAMP_MAX:
+            ub = tl.load(scale_ub)
+            group_max = tl.clamp(group_max, EPS, ub)
+        else:
+            group_max = tl.maximum(group_max, EPS)
+        # Scale and quantize.
+        a_scale = MAX_FP8 / group_max
+        scale_chunk_offset = scale_k_offset + k * GROUP_LOAD
+        tl.store(
+            A_scale + scale_row_offset + scale_chunk_offset * stride_a_scale_k,
+            1.0 / a_scale,
+            mask=scale_chunk_offset < NUM_GROUPS,
+        )
+        # Apply scale to input.
+        a_fp8 = a_grouped * a_scale[:, None]
+        # Clamp to FP8 range to avoid overflow.
+        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
+        # Write to output.
+        tl.store(
+            A_fp8 + out_offset + chunk_offset * stride_ok,
+            tl.ravel(a_fp8),
+            mask=chunk_offset < K,
+        )
+
+
+def triton_quantize_fp8_group(
+    x: torch.Tensor,
+    group_size: int = 128,
+    scale_ub: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize a tensor to fp8 with group-wise scalings.
+
+    Scale per group i is computed as 1 / (MAX_FP8 / max(abs(x[i:i+group_size]))).
+
+    Args:
+        x (torch.Tensor): [M, K] higher precision input tensor.
+        group_size (int): Group size for K dimension of scale.
+        scale_ub: Maximum allowed value for scale.
+
+    Returns:
+        torch.Tensor: [M, K] fp8 scaled tensor.
+        torch.Tensor: [M, cdiv(K, group_size)] reciprocal scale tensor per group.
+    """
+    assert x.device != torch.device(
+        "cpu"
+    ), "Triton groupwise quantization not supported on cpu."
+    x_shape = x.shape
+    x = x.view(-1, x.size(-1))
+    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
+    M, K = x.shape
+    k_groups = triton.cdiv(K, group_size)
+    x_scale = torch.empty((M, k_groups), device=x.device, dtype=torch.float32)
+    x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)
+    _kernel_quantize_fp8_group[(M,)](
+        x,
+        x_scale,
+        x_fp8,
+        scale_ub,
+        M,
+        K,
+        x.stride(0),
+        x.stride(1),
+        x_fp8.stride(0),
+        x_fp8.stride(1),
+        x_scale.stride(0),
+        x_scale.stride(1),
+        TL_FP8_DTYPE=tl_dtype,
+        MAX_FP8=max_fp8,
+        EPS=eps,
+        CLAMP_MAX=scale_ub is not None,
+        USE_INT64=x.numel() > (2**32 - 1),
+        GROUP_SIZE=group_size,
+    )
+    return x_fp8.view(x_shape), x_scale
+
+
+def quantize_fp8_group(
+    x: torch.Tensor,
+    group_size: int = 128,
+    scale_ub: Optional[torch.Tensor] = None,
+    use_triton: bool = True,
+    output_device: Optional[torch.device] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize a tensor to fp8 with group-wise scalings and optionally move to output device.
+
+    Scale per group i is computed as 1 / (MAX_FP8 / max(abs(x[i:i+group_size]))).
+
+    Args:
+        x (Tensor): [M, K] higher precision input tensor.
+        group_size (int): Group size for K dimension of scale.
+        scale_ub: Maximum allowed value for scale.
+        use_triton (bool): Whether to use triton kernel or pytorch.
+        output_device (torch.device): Device to optionally move the scaled tensors to.
+
+    Returns:
+        torch.Tensor: [M, K] fp8 scaled tensor.
+        torch.Tensor: [M, cdiv(K, group_size)] reciprocal scale tensor per group.
+    """
+    x_shape = x.shape
+    x = x.view(-1, x.size(-1))
+    if x.device == torch.device("cpu"):
+        logger.info("Triton does not support cpu, falling back to torch ops.")
+        use_triton = False
+    if use_triton:
+        xq, x_scale = triton_quantize_fp8_group(x, group_size, scale_ub)
+        return xq.view(x_shape), x_scale
+    # else use pytorch implementation.
+    if not output_device:
+        output_device = x.device
+
+    # Get constants.
+    pt_dtype, _, max_fp8, eps = get_fp8_constants()
+
+    M, K = x.shape
+    assert (
+        K % group_size == 0
+    ), "K must be divisible by group_size for cpu implementation."
+    k_groups = triton.cdiv(K, group_size)
+    # View input as a collection of groups for reduction.
+    x_grouped = x.view(M, k_groups, group_size).to(torch.float32)
+    # Reduce over groups.
+    group_max = x_grouped.abs().amax(dim=2)
+    # Apply clamping.
+    group_max = (
+        torch.clamp(group_max, min=eps, max=scale_ub.item())
+        if scale_ub
+        else torch.clamp(group_max, min=eps)
+    )
+    x_scale = torch.empty((M, k_groups), dtype=torch.float32, device=output_device)
+    x_scale = max_fp8 / group_max  # pyre-ignore
+    # pyre-ignore[16]: Undefined attribute [16]
+    x_scale[x_scale == float("inf")] = 1.0
+    # pyre-ignore[16]: Undefined attribute [16]
+    x_fp8 = x.view(-1, k_groups, group_size) * x_scale.unsqueeze(2)
+    # Cast and move data to output device (for cpu weight loading).
+    x_fp8 = x_fp8.to(device=output_device, dtype=pt_dtype)
+    x_scale = x_scale.to(output_device)  # pyre-ignore
+    return x_fp8.view(x_shape), 1 / x_scale  # pyre-ignore
+
+
 def need_split_k(SIZE_M, SIZE_N, SIZE_K):
     return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
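
Where the speedup claimed in the summary comes from, in rough numbers: the old path launches one small program per 1 x 128 block (a grid of grid_m * grid_k programs, as in the _kernel_quantize_fp8_block launch above), while _kernel_quantize_fp8_group launches one program per row and loops over GROUP_LOAD groups per iteration. A back-of-the-envelope sketch with an illustrative shape, not from the commit:

import math

M, K, group_size, GROUP_LOAD = 8192, 7168, 128, 8  # illustrative values

# Old path: quantize_fp8_block(x, block_m=1, block_k=128) launches
# grid_m * grid_k = M * ceil(K / 128) programs, one per tiny block.
blockwise_programs = math.ceil(M / 1) * math.ceil(K / group_size)  # 458,752

# New path: _kernel_quantize_fp8_group launches M programs, one per row;
# each loops ceil(K / (GROUP_LOAD * group_size)) times over wide chunks.
groupwise_programs = M                                             # 8,192
iters_per_program = math.ceil(K / (GROUP_LOAD * group_size))       # 7

print(blockwise_programs, groupwise_programs, iters_per_program)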

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 2 additions & 3 deletions
@@ -26,6 +26,7 @@
     matmul_fp8_block,
     matmul_fp8_row,
     quantize_fp8_block,
+    quantize_fp8_group,
     quantize_fp8_row,
     scale_fp8_row,
     triton_quantize_fp8_row,
@@ -1119,9 +1120,7 @@ def preprocess(self, x, w):
         return x, wq, w_scale, out
 
     def quantize(self, x, wq, w_scale, out):
-        xq, x_scale = quantize_fp8_block(x, block_m=1, block_k=128)
-        # Pretranspose scales to deepgemm format.
-        x_scale = get_col_major_tma_aligned_tensor(x_scale)
+        xq, x_scale = quantize_fp8_group(x, group_size=128)
         return xq, wq, x_scale, w_scale, out
 
     def compute(self, xq, wq, x_scale, w_scale, out):
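
A minimal micro-benchmark sketch (not part of the commit) for comparing the before/after quantize paths shown in this diff, assuming triton.testing.do_bench is available and both functions are importable; the shape is illustrative and timings are in milliseconds:

import torch
from triton.testing import do_bench
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    quantize_fp8_block,
    quantize_fp8_group,
)

x = torch.randn(8192, 7168, dtype=torch.bfloat16, device="cuda")

# Old path: groupwise scales emulated via 1 x 128 blockwise quantization.
t_block = do_bench(lambda: quantize_fp8_block(x, block_m=1, block_k=128))
# New path: dedicated groupwise kernel, one program per row.
t_group = do_bench(lambda: quantize_fp8_group(x, group_size=128))
print(f"blockwise {t_block:.3f} ms vs groupwise {t_group:.3f} ms")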
