
Commit f22a3d4

jwfromm authored and facebook-github-bot committed
DeepGemm Style Groupwise Group Gemm. (#4464)
Summary:
X-link: facebookresearch/FBGEMM#1523
Pull Request resolved: #4464

WIP: Basic foundation all working but CUTLASS requires an annoying layout for activation scales. There's no efficient way to generate the scales in this layout using the existing quantization routines we have, so I'll have to write a new one.

Reviewed By: jiawenliu64, jianyuh

Differential Revision: D77162544

fbshipit-source-id: a5152ffdaa31f78bffaccf11ea29ec116108b2ae
1 parent a6007f4 commit f22a3d4
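The "annoying layout" mentioned in the summary is the grouped, per-group M-major layout of activation scales that DeepGEMM-style CUTLASS grouped kernels expect. As a rough illustration of my reading of the new Triton kernel's store indexing (see the fp8_gemm.py hunks below), the eager PyTorch sketch here builds the same layout from ordinary K-major scales. The helper name and the unit-stride assumption along the k-group axis are mine, not part of this commit.

```python
import torch


def grouped_m_major_scales(s: torch.Tensor, m_sizes: torch.Tensor) -> torch.Tensor:
    """Rearrange K-major reciprocal scales s of shape [M, num_k_groups] so each
    group's [M_g, num_k_groups] block is stored transposed ([num_k_groups, M_g]):
    flat offset = cumsum(M_g) * num_k_groups + k_group * M_g + row_in_group."""
    chunks = torch.split(s, m_sizes.tolist(), dim=0)
    return torch.cat([c.t().contiguous().flatten() for c in chunks])


# Example: 3 groups with 2, 3, and 1 rows, and 2 k-groups per row.
s = torch.arange(12, dtype=torch.float32).reshape(6, 2)
m_sizes = torch.tensor([2, 3, 1])
print(grouped_m_major_scales(s, m_sizes))
```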

File tree

8 files changed: +804, -21 lines changed

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

Lines changed: 84 additions & 19 deletions
@@ -3090,7 +3090,7 @@ def triton_quantize_fp8_block(
     block_m: int = 256,
     block_k: int = 256,
     scale_ub: Optional[torch.Tensor] = None,
-    K_major: bool = True,
+    k_major: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize a tensor to fp8 with block-wise scalings.
@@ -3102,12 +3102,12 @@ def triton_quantize_fp8_block(
         block_m (int): Block size for M dimension of scale.
         block_k (int): Block size for K dimension of scale.
         scale_ub: Maximum allowed value for scale.
-        K_major (bool): Whether output scales should be K major (True) or MN major (False).
+        k_major (bool): Whether output scales should be K major (True) or MN major (False).
 
     Returns:
         torch.Tensor : [M, K] fp8 scaled tensor.
         torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block
-        if K_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_M)].
+        if k_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_M)].
     """
     assert x.device != torch.device(
         "cpu"
@@ -3119,10 +3119,10 @@ def triton_quantize_fp8_block(
     M, K = x.shape
     grid_m = triton.cdiv(M, block_m)
     grid_k = triton.cdiv(K, block_k)
-    if K_major:
+    if k_major:
         x_scale = torch.empty((grid_m, grid_k), device=x.device, dtype=torch.float32)
     else:
-        x_scale = torch.ones((grid_k, grid_m), device=x.device, dtype=torch.float32)
+        x_scale = torch.empty((grid_k, grid_m), device=x.device, dtype=torch.float32)
     x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)
 
     _kernel_quantize_fp8_block[(grid_m * grid_k,)](
@@ -3151,7 +3151,7 @@ def triton_quantize_fp8_block(
         # pyre-ignore[6]: Incompatible parameter type [6]
         BLOCK_K=block_k,
         # pyre-ignore[6]: Incompatible parameter type [6]
-        K_MAJOR=K_major,
+        K_MAJOR=k_major,
     )
 
     return x_fp8.view(x_shape), x_scale
@@ -3164,7 +3164,7 @@ def quantize_fp8_block(
     scale_ub: Optional[torch.Tensor] = None,
     use_triton: bool = True,
     output_device: Optional[torch.device] = None,
-    K_major: bool = True,
+    k_major: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize a tensor to fp8 with block-wise scalings and optionally move to output device.
@@ -3178,20 +3178,20 @@ def quantize_fp8_block(
         scale_ub: Maximum allowed value for scale.
         use_triton (bool): Whether to use triton kernel or pytorch.
         output_device (torch.device): Device to optionally move the scaled tensors to.
-        K_major (bool): Whether output scales should be K major (True) or MN major (False).
+        k_major (bool): Whether output scales should be K major (True) or MN major (False).
 
     Returns:
         torch.Tensor: [M, K] fp8 scaled tensor.
         torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block
-        if K_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_M)].
+        if k_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_M)].
     """
     x_shape = x.shape
     x = x.view(-1, x.size(-1))
     if x.device == torch.device("cpu"):
         logger.info("Triton does not support cpu, falling back to torch ops.")
         use_triton = False
     if use_triton:
-        xq, x_scale = triton_quantize_fp8_block(x, block_m, block_k, scale_ub, K_major)
+        xq, x_scale = triton_quantize_fp8_block(x, block_m, block_k, scale_ub, k_major)
         return xq.view(x_shape), x_scale
     # else use pytorch implementation.
     if not output_device:
@@ -3219,7 +3219,6 @@ def quantize_fp8_block(
     if scale_ub is not None:
         block_max = torch.clamp(block_max, min=eps, max=scale_ub.item())
     else:
-        # pyre-ignore[6]: Incompatible parameter type [6]
         block_max = torch.clamp(block_max, min=eps)
     x_scale = torch.empty((grid_m, grid_k), dtype=torch.float32, device=output_device)
     x_scale = max_fp8 / block_max.to(torch.float32) # pyre-ignore
@@ -3235,7 +3234,7 @@ def quantize_fp8_block(
     x_fp8 = x_fp8.to(device=output_device, dtype=pt_dtype)
     x_scale = x_scale.to(output_device) # pyre-ignore
     del x, x_padded
-    if not K_major:
+    if not k_major:
         x_scale = x_scale.t().contiguous()
     return x_fp8.view(x_shape), 1 / x_scale # pyre-ignore
 
@@ -3256,6 +3255,7 @@ def _kernel_quantize_fp8_group(
     A_scale,
     A_fp8,
     scale_ub,
+    m_sizes,
     M,
     K,
     stride_am,
@@ -3270,6 +3270,8 @@ def _kernel_quantize_fp8_group(
     CLAMP_MAX: tl.constexpr,
    USE_INT64: tl.constexpr,
     GROUP_SIZE: tl.constexpr,
+    USE_M_MAJOR: tl.constexpr,
+    G: tl.constexpr,
     GROUP_LOAD: tl.constexpr,
 ):
     """Quantize and scale each GROUP_SIZE chunk of each row.
@@ -3284,6 +3286,7 @@ def _kernel_quantize_fp8_group(
         A_scale (Tensor): [M, cdiv(K, GROUP_SIZE)] reciprocal scale tensor per group.
         A_fp8 (Tensor): [M, K] fp8 scaled tensor. A_fp8 = A * a
         scale_ub (Tensor): [1] Maximum allowed value for scale.
+        m_sizes (Optional[Tensor]): [G] Number of rows in each group.
         M (int): Number of rows.
         K (int): Number of columns.
         stride_am (int): Stride of m dimension of A.
@@ -3298,6 +3301,8 @@ def _kernel_quantize_fp8_group(
         CLAMP_MAX (bool): Whether to apply scale_ub.
         USE_INT64 (bool): Whether to index using int64, which may be needed for large tensors.
         GROUP_SIZE (int): Group size for K dimension of A_scale and kernel.
+        USE_M_MAJOR (bool): Whether to use grouped M-major layout for A_scale.
+        G (int): Number of groups in A_scale, only relevant when m_sizes is provided.
         GROUP_LOAD (int): Number of groups to load and process simultaneously.
     """
     pid = tl.program_id(0)
@@ -3311,6 +3316,26 @@ def _kernel_quantize_fp8_group(
     scale_k_offset = tl.arange(0, GROUP_LOAD)
     NUM_GROUPS: tl.constexpr = K // GROUP_SIZE
 
+    # When dealing with an M-major grouped gemm, we need to figure out
+    # which group this thread corresponds to and figure out the corresponding
+    # scale offset.
+    group_offset = 0
+    group_cumsum = 0
+    group_M = 0
+    stop = False
+    if USE_M_MAJOR and G > 0:
+        # Iterate over groups to both compute the cumulative sum and find which group we are in.
+        for i in range(G):
+            if not stop:
+                group_M = tl.cast(tl.load(m_sizes + i), pid.dtype)
+                if (group_cumsum + group_M) <= pid:
+                    group_cumsum += group_M
+                else:
+                    # Indicate we are finished computing cumsum.
+                    stop = True
+
+        group_offset = group_cumsum * NUM_GROUPS
+
     for k in range(0, tl.cdiv(K, (GROUP_LOAD * GROUP_SIZE))):
         # Load groups of the input.
         chunk_offset = k_offset + k * GROUP_LOAD * GROUP_SIZE
@@ -3330,11 +3355,31 @@ def _kernel_quantize_fp8_group(
         # Scale and quantize.
         a_scale = MAX_FP8 / group_max
         scale_chunk_offset = scale_k_offset + k * GROUP_LOAD
-        tl.store(
-            A_scale + scale_row_offset + scale_chunk_offset * stride_a_scale_k,
-            1.0 / a_scale,
-            mask=scale_chunk_offset < NUM_GROUPS,
-        )
+
+        if USE_M_MAJOR and G > 0:
+            tl.store(
+                A_scale
+                + group_offset
+                + (pid - group_cumsum) * stride_a_scale_k
+                + (scale_chunk_offset * group_M),
+                1.0 / a_scale,
+                mask=scale_chunk_offset < NUM_GROUPS,
+            )
+        else:
+            if USE_M_MAJOR:
+                tl.store(
+                    A_scale
+                    + pid * stride_a_scale_k
+                    + scale_chunk_offset * stride_a_scale_m,
+                    1.0 / a_scale,
+                    mask=scale_chunk_offset < NUM_GROUPS,
+                )
            else:
+                tl.store(
+                    A_scale + scale_row_offset + scale_chunk_offset * stride_a_scale_k,
+                    1.0 / a_scale,
+                    mask=scale_chunk_offset < NUM_GROUPS,
+                )
         # Apply scale to input.
         a_fp8 = a_grouped * a_scale[:, None]
         # Clamp to FP8 range to avoid overflow
@@ -3351,6 +3396,8 @@ def triton_quantize_fp8_group(
     x: torch.Tensor,
     group_size: int = 128,
     scale_ub: Optional[torch.Tensor] = None,
+    m_sizes: Optional[torch.Tensor] = None,
+    k_major: bool = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize a tensor to fp8 with group-wise scalings.
@@ -3361,6 +3408,8 @@ def triton_quantize_fp8_group(
         x (torch.Tensor): [M, K] higher precision input tensor.
         group_size (int): Group size for M dimension of scale.
         scale_ub: Maximum allowed value for scale.
+        m_sizes: Optional input for grouped gemm to specify the number of rows in each group.
+        k_major (bool): Whether output scales should be K major (True) or MN major (False).
 
     Returns:
         torch.Tensor: [M, K] fp8 scaled tensor.
@@ -3374,13 +3423,17 @@ def triton_quantize_fp8_group(
     pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
     M, K = x.shape
     k_groups = triton.cdiv(K, group_size)
-    x_scale = torch.empty((M, k_groups), device=x.device, dtype=torch.float32)
+    if k_major:
+        x_scale = torch.empty((M, k_groups), device=x.device, dtype=torch.float32)
+    else:
+        x_scale = torch.empty((k_groups, M), device=x.device, dtype=torch.float32)
     x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)
     _kernel_quantize_fp8_group[(M,)](
         x,
         x_scale,
         x_fp8,
         scale_ub,
+        m_sizes,
         M,
         K,
         x.stride(0),
@@ -3395,6 +3448,8 @@ def triton_quantize_fp8_group(
         CLAMP_MAX=scale_ub is not None,
         USE_INT64=x.numel() > (2**32 - 1),
         GROUP_SIZE=group_size,
+        USE_M_MAJOR=m_sizes is not None or k_major is False,
+        G=m_sizes.numel() if m_sizes is not None else 0,
     )
     return x_fp8.view(x_shape), x_scale
 
@@ -3403,6 +3458,8 @@ def quantize_fp8_group(
     x: torch.Tensor,
     group_size: int = 128,
     scale_ub: Optional[torch.Tensor] = None,
+    m_sizes: Optional[torch.Tensor] = None,
+    k_major: bool = True,
     use_triton: bool = True,
     output_device: Optional[torch.device] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -3415,6 +3472,9 @@ def quantize_fp8_group(
         x (Tensor): [M, K] higher precision input tensor.
         group_size (int): Group size for M dimension of scale.
         scale_ub: Maximum allowed value for scale.
+        m_sizes: Optional input for grouped gemm to specify the number of rows in each group.
+        k_major (bool): Whether output scales should be K major (True) or MN major (False).
+        This is needed because some kernels like cutlass require a special layout for scales.
         use_triton (bool): Whether to use triton kernel or pytorch.
         output_device (torch.device): Device to optionally move the scaled tensors to.
 
@@ -3428,7 +3488,9 @@ def quantize_fp8_group(
         logger.info("Triton does not support cpu, falling back to torch ops.")
         use_triton = False
     if use_triton:
-        xq, x_scale = triton_quantize_fp8_group(x, group_size, scale_ub)
+        xq, x_scale = triton_quantize_fp8_group(
+            x, group_size, scale_ub, m_sizes, k_major
+        )
         return xq.view(x_shape), x_scale
     # else use pytorch implementation.
     if not output_device:
@@ -3441,6 +3503,7 @@ def quantize_fp8_group(
     assert (
         K % group_size == 0
     ), "K must be divisible by group_size for cpu implementation."
+    assert m_sizes is None, "m_sizes is not supported for cpu implementation."
    k_groups = triton.cdiv(K, group_size)
     # View input as colleciton of groups for reduction.
     x_grouped = x.view(M, k_groups, group_size).to(torch.float32)
@@ -3461,6 +3524,8 @@ def quantize_fp8_group(
     # Cast and move data to output device (for cpu weight loading).
     x_fp8 = x_fp8.to(device=output_device, dtype=pt_dtype)
     x_scale = x_scale.to(output_device) # pyre-ignore
+    if not k_major:
+        x_scale = x_scale.t().contiguous()
     return x_fp8.view(x_shape), 1 / x_scale # pyre-ignore
 

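A minimal usage sketch of the updated quantization helpers, tying the new k_major and m_sizes arguments to the scale shapes documented in the hunks above. The import path assumes the module mirrors the file location, and all sizes are illustrative.

```python
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    quantize_fp8_block,
    quantize_fp8_group,
)

x = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)

# Block-wise quantization with MN-major (transposed) scales.
xq_b, scale_b = quantize_fp8_block(x, block_m=128, block_k=128, k_major=False)
# scale_b: [cdiv(K, 128), cdiv(M, 128)] = [8, 4] instead of the default [4, 8].

# Group-wise quantization of a stacked grouped input: m_sizes gives the rows
# per logical group, and the kernel switches to the grouped M-major scale
# layout whenever it is provided.
m_sizes = torch.tensor([128, 256, 128], dtype=torch.int64, device="cuda")
xq_g, scale_g = quantize_fp8_group(x, group_size=128, m_sizes=m_sizes)

# k_major=False allocates the scales as [cdiv(K, group_size), M] instead of
# [M, cdiv(K, group_size)], per the triton_quantize_fp8_group hunk above.
xq_g2, scale_g2 = quantize_fp8_group(x, group_size=128, k_major=False)
```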
fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 52 additions & 2 deletions
@@ -1243,6 +1243,56 @@ def cuda(self) -> bool:
         return True
 
 
+@register_quantize_op
+class FP8StackedGroupwiseGroupedGemm(QuantizeOpBase):
+    """
+    FP8 grouped matmul with groupwise scaling and stacked inputs.
+    """
+
+    def preprocess(self, x, w):
+        m_values = [i.shape[0] for i in x]
+        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
+        # Quantize weights.
+        wq, w_scale = zip(
+            *[quantize_fp8_block(i, block_m=128, block_k=128, k_major=False) for i in w]
+        )
+        # Group weights as single tensor.
+        wq = torch.stack(wq, dim=0).contiguous()
+        w_scale = torch.stack(w_scale, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, wq, w_scale, m_sizes
+
+    def quantize(self, x, wq, w_scale, m_sizes):
+        xq, x_scale = quantize_fp8_group(x, m_sizes=m_sizes)
+        return xq, wq, x_scale, w_scale, m_sizes
+
+    def compute(self, xq, wq, x_scale, w_scale, m_sizes):
+        return torch.ops.fbgemm.f8f8bf16_groupwise_grouped(
+            xq, wq, x_scale, w_scale, m_sizes
+        )
+
+    def quantize_and_compute(self, x, wq, w_scale, m_sizes):
+        xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, wq, w_scale, m_sizes)
+        return self.compute(xq, wq, x_scale, w_scale, m_sizes)
+
+    @property
+    def name(self) -> str:
+        if torch.version.cuda:
+            return "cutlass_groupwise_grouped"
+        else:
+            return "ck_groupwise_grouped"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class BF16GroupedGemm(QuantizeOpBase):
     """
@@ -1499,13 +1549,13 @@ class FP8CutlassGroupwiseGemm(QuantizeOpBase):
     def preprocess(self, x, w):
         # Quantize weights.
         # Scale is expected to be in [K, N] layout (N Major).
-        wq, w_scale = quantize_fp8_block(w, block_m=128, block_k=128, K_major=False)
+        wq, w_scale = quantize_fp8_block(w, block_m=128, block_k=128, k_major=False)
         # Return processed tensors.
         return x, wq, w_scale
 
     def quantize(self, x, wq, w_scale):
         # Scale is expected to be in [K, M] layout (M Major).
-        xq, x_scale = quantize_fp8_block(x, block_m=1, block_k=128, K_major=False)
+        xq, x_scale = quantize_fp8_group(x, k_major=False)
         # Pretranspose scales to deepgemm format.
         return xq, wq, x_scale, w_scale
 

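For context, this is roughly how the new benchmark op is driven end to end; a sketch that assumes [N, K] weight layout as used by the other FP8 ops in this file, made-up problem sizes, and a direct class instantiation in place of the benchmark registry.

```python
import torch
from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import (
    FP8StackedGroupwiseGroupedGemm,
)

op = FP8StackedGroupwiseGroupedGemm()

# Three logical GEMM problems sharing one stacked activation tensor.
x = [torch.randn(m, 2048, device="cuda", dtype=torch.bfloat16) for m in (128, 256, 512)]
w = [torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16) for _ in range(3)]

# preprocess: blockwise-quantize and stack the weights (MN-major scales),
# concatenate activations into one [sum(M_i), K] tensor, and build m_sizes.
x_cat, wq, w_scale, m_sizes = op.preprocess(x, w)

# quantize: groupwise FP8 activations; because m_sizes is passed, the scales
# are written in the grouped M-major layout the CUTLASS kernel expects.
xq, wq, x_scale, w_scale, m_sizes = op.quantize(x_cat, wq, w_scale, m_sizes)

# compute: dispatches to torch.ops.fbgemm.f8f8bf16_groupwise_grouped.
out = op.compute(xq, wq, x_scale, w_scale, m_sizes)
```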