@@ -3090,7 +3090,7 @@ def triton_quantize_fp8_block(
    block_m: int = 256,
    block_k: int = 256,
    scale_ub: Optional[torch.Tensor] = None,
-     K_major: bool = True,
+     k_major: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to fp8 with block-wise scalings.
@@ -3102,12 +3102,12 @@ def triton_quantize_fp8_block(
        block_m (int): Block size for M dimension of scale.
        block_k (int): Block size for K dimension of scale.
        scale_ub: Maximum allowed value for scale.
-         K_major (bool): Whether output scales should be K major (True) or MN major (False).
+         k_major (bool): Whether output scales should be K major (True) or MN major (False).

    Returns:
        torch.Tensor: [M, K] fp8 scaled tensor.
        torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block
-             if K_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_M)].
+             if k_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_M)].
    """
    assert x.device != torch.device(
        "cpu"
@@ -3119,10 +3119,10 @@ def triton_quantize_fp8_block(
    M, K = x.shape
    grid_m = triton.cdiv(M, block_m)
    grid_k = triton.cdiv(K, block_k)
-     if K_major:
+     if k_major:
        x_scale = torch.empty((grid_m, grid_k), device=x.device, dtype=torch.float32)
    else:
-         x_scale = torch.ones((grid_k, grid_m), device=x.device, dtype=torch.float32)
+         x_scale = torch.empty((grid_k, grid_m), device=x.device, dtype=torch.float32)
    x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)

    _kernel_quantize_fp8_block[(grid_m * grid_k,)](
@@ -3151,7 +3151,7 @@ def triton_quantize_fp8_block(
        # pyre-ignore[6]: Incompatible parameter type [6]
        BLOCK_K=block_k,
        # pyre-ignore[6]: Incompatible parameter type [6]
-         K_MAJOR=K_major,
+         K_MAJOR=k_major,
    )

    return x_fp8.view(x_shape), x_scale
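For reviewers, a quick usage sketch of the renamed argument. This is illustrative only, not part of the patch; it assumes a CUDA device and a hypothetical import path for the module being edited:

import torch
from fp8_gemm import triton_quantize_fp8_block  # hypothetical import path

x = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)

# K-major scales: shape (cdiv(M, block_m), cdiv(K, block_k)) = (2, 4).
xq, s_k = triton_quantize_fp8_block(x, block_m=256, block_k=256, k_major=True)

# MN-major scales: transposed layout (4, 2). Every slot is written by the kernel,
# which is presumably why torch.empty now replaces torch.ones for this allocation.
xq, s_mn = triton_quantize_fp8_block(x, block_m=256, block_k=256, k_major=False)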
@@ -3164,7 +3164,7 @@ def quantize_fp8_block(
    scale_ub: Optional[torch.Tensor] = None,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
-     K_major: bool = True,
+     k_major: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to fp8 with block-wise scalings and optionally move to output device.
@@ -3178,20 +3178,20 @@ def quantize_fp8_block(
        scale_ub: Maximum allowed value for scale.
        use_triton (bool): Whether to use triton kernel or pytorch.
        output_device (torch.device): Device to optionally move the scaled tensors to.
-         K_major (bool): Whether output scales should be K major (True) or MN major (False).
+         k_major (bool): Whether output scales should be K major (True) or MN major (False).

    Returns:
        torch.Tensor: [M, K] fp8 scaled tensor.
        torch.Tensor: [cdiv(M, block_m), cdiv(K, block_k)] reciprocal scale tensor per block
-             if K_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_M)].
+             if k_major is True, otherwise [cdiv(K, block_k), cdiv(M, block_M)].
    """
    x_shape = x.shape
    x = x.view(-1, x.size(-1))
    if x.device == torch.device("cpu"):
        logger.info("Triton does not support cpu, falling back to torch ops.")
        use_triton = False
    if use_triton:
-         xq, x_scale = triton_quantize_fp8_block(x, block_m, block_k, scale_ub, K_major)
+         xq, x_scale = triton_quantize_fp8_block(x, block_m, block_k, scale_ub, k_major)
        return xq.view(x_shape), x_scale
    # else use pytorch implementation.
    if not output_device:
@@ -3219,7 +3219,6 @@ def quantize_fp8_block(
    if scale_ub is not None:
        block_max = torch.clamp(block_max, min=eps, max=scale_ub.item())
    else:
-         # pyre-ignore[6]: Incompatible parameter type [6]
        block_max = torch.clamp(block_max, min=eps)
    x_scale = torch.empty((grid_m, grid_k), dtype=torch.float32, device=output_device)
    x_scale = max_fp8 / block_max.to(torch.float32)  # pyre-ignore
@@ -3235,7 +3234,7 @@ def quantize_fp8_block(
    x_fp8 = x_fp8.to(device=output_device, dtype=pt_dtype)
    x_scale = x_scale.to(output_device)  # pyre-ignore
    del x, x_padded
-     if not K_major:
+     if not k_major:
        x_scale = x_scale.t().contiguous()
    return x_fp8.view(x_shape), 1 / x_scale  # pyre-ignore
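For reference, a self-contained sketch of what the torch fallback above computes per (block_m, block_k) tile, including the transpose that k_major=False now triggers. Illustrative only: the helper name is made up and max_fp8 is hard-coded to the e4m3 limit.

import torch

def blockwise_recip_scales(x, block_m=256, block_k=256, max_fp8=448.0, eps=1e-12, k_major=True):
    # Pad M and K up to multiples of the block sizes so every tile is full.
    M, K = x.shape
    grid_m, grid_k = -(-M // block_m), -(-K // block_k)
    padded = torch.zeros(grid_m * block_m, grid_k * block_k, dtype=torch.float32)
    padded[:M, :K] = x.abs().to(torch.float32)
    # Reduce each (block_m, block_k) tile to its max magnitude.
    block_max = padded.view(grid_m, block_m, grid_k, block_k).amax(dim=(1, 3)).clamp(min=eps)
    scale = max_fp8 / block_max          # multiply x by this before casting to fp8
    recip = 1.0 / scale                  # what the quantizers return alongside x_fp8
    return recip if k_major else recip.t().contiguous()

print(blockwise_recip_scales(torch.randn(300, 500)).shape)  # torch.Size([2, 2])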
@@ -3256,6 +3255,7 @@ def _kernel_quantize_fp8_group(
    A_scale,
    A_fp8,
    scale_ub,
+     m_sizes,
    M,
    K,
    stride_am,
@@ -3270,6 +3270,8 @@ def _kernel_quantize_fp8_group(
    CLAMP_MAX: tl.constexpr,
    USE_INT64: tl.constexpr,
    GROUP_SIZE: tl.constexpr,
+     USE_M_MAJOR: tl.constexpr,
+     G: tl.constexpr,
    GROUP_LOAD: tl.constexpr,
):
    """Quantize and scale each GROUP_SIZE chunk of each row.
@@ -3284,6 +3286,7 @@ def _kernel_quantize_fp8_group(
        A_scale (Tensor): [M, cdiv(K, GROUP_SIZE)] reciprocal scale tensor per group.
        A_fp8 (Tensor): [M, K] fp8 scaled tensor. A_fp8 = A * a
        scale_ub (Tensor): [1] Maximum allowed value for scale.
+         m_sizes (Optional[Tensor]): [G] Number of rows in each group.
        M (int): Number of rows.
        K (int): Number of columns.
        stride_am (int): Stride of m dimension of A.
@@ -3298,6 +3301,8 @@ def _kernel_quantize_fp8_group(
        CLAMP_MAX (bool): Whether to apply scale_ub.
        USE_INT64 (bool): Whether to index using int64, which may be needed for large tensors.
        GROUP_SIZE (int): Group size for K dimension of A_scale and kernel.
+         USE_M_MAJOR (bool): Whether to use grouped M-major layout for A_scale.
+         G (int): Number of groups in A_scale, only relevant when m_sizes is provided.
        GROUP_LOAD (int): Number of groups to load and process simultaneously.
    """
    pid = tl.program_id(0)
@@ -3311,6 +3316,26 @@ def _kernel_quantize_fp8_group(
    scale_k_offset = tl.arange(0, GROUP_LOAD)
    NUM_GROUPS: tl.constexpr = K // GROUP_SIZE

+     # When dealing with an M-major grouped gemm, we need to figure out
+     # which group this thread corresponds to and figure out the corresponding
+     # scale offset.
+     group_offset = 0
+     group_cumsum = 0
+     group_M = 0
+     stop = False
+     if USE_M_MAJOR and G > 0:
+         # Iterate over groups to both compute the cumulative sum and find which group we are in.
+         for i in range(G):
+             if not stop:
+                 group_M = tl.cast(tl.load(m_sizes + i), pid.dtype)
+                 if (group_cumsum + group_M) <= pid:
+                     group_cumsum += group_M
+                 else:
+                     # Indicate we are finished computing cumsum.
+                     stop = True
+ 
+         group_offset = group_cumsum * NUM_GROUPS
+ 
    for k in range(0, tl.cdiv(K, (GROUP_LOAD * GROUP_SIZE))):
        # Load groups of the input.
        chunk_offset = k_offset + k * GROUP_LOAD * GROUP_SIZE
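The added block above is the interesting part of this hunk: each program owns one row (pid) and walks m_sizes to find the group that row falls into, using a stop flag since the loop cannot break early. The same lookup in plain Python with made-up sizes, purely for intuition:

m_sizes = [3, 5, 2]   # hypothetical rows per group
NUM_GROUPS = 4        # K // GROUP_SIZE

def locate(pid):
    group_cumsum = 0
    group_M = 0
    for group_M in m_sizes:
        if group_cumsum + group_M <= pid:
            group_cumsum += group_M   # pid lives in a later group, keep summing
        else:
            break                     # found the group containing pid
    # Rows of all earlier groups each own NUM_GROUPS scale slots.
    return group_cumsum * NUM_GROUPS, pid - group_cumsum, group_M

print(locate(0))   # (0, 0, 3)  -> first row of group 0
print(locate(4))   # (12, 1, 5) -> second row of group 1
print(locate(9))   # (32, 1, 2) -> last row of group 2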
@@ -3330,11 +3355,31 @@ def _kernel_quantize_fp8_group(
        # Scale and quantize.
        a_scale = MAX_FP8 / group_max
        scale_chunk_offset = scale_k_offset + k * GROUP_LOAD
-         tl.store(
-             A_scale + scale_row_offset + scale_chunk_offset * stride_a_scale_k,
-             1.0 / a_scale,
-             mask=scale_chunk_offset < NUM_GROUPS,
-         )
+ 
+         if USE_M_MAJOR and G > 0:
+             tl.store(
+                 A_scale
+                 + group_offset
+                 + (pid - group_cumsum) * stride_a_scale_k
+                 + (scale_chunk_offset * group_M),
+                 1.0 / a_scale,
+                 mask=scale_chunk_offset < NUM_GROUPS,
+             )
+         else:
+             if USE_M_MAJOR:
+                 tl.store(
+                     A_scale
+                     + pid * stride_a_scale_k
+                     + scale_chunk_offset * stride_a_scale_m,
+                     1.0 / a_scale,
+                     mask=scale_chunk_offset < NUM_GROUPS,
+                 )
+             else:
+                 tl.store(
+                     A_scale + scale_row_offset + scale_chunk_offset * stride_a_scale_k,
+                     1.0 / a_scale,
+                     mask=scale_chunk_offset < NUM_GROUPS,
+                 )
        # Apply scale to input.
        a_fp8 = a_grouped * a_scale[:, None]
        # Clamp to FP8 range to avoid overflow
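Putting the three store branches side by side, the flat offset a row writes its j-th per-group scale to depends only on the layout. A plain-index sketch with a hypothetical helper, assuming the wrapper passes x_scale.stride(0)/stride(1) as stride_a_scale_m/stride_a_scale_k so that stride_a_scale_k == 1 for both allocations shown below:

def scale_index(pid, j, layout, *, num_groups, M=None,
                group_offset=0, group_cumsum=0, group_M=None):
    # Flat offset into A_scale for row pid and K-group j (illustrative only).
    if layout == "grouped_m_major":          # USE_M_MAJOR and G > 0
        # Each group is a [group_M, num_groups] block, stored row-index-fastest.
        return group_offset + (pid - group_cumsum) + j * group_M
    if layout == "m_major":                  # USE_M_MAJOR, no m_sizes: A_scale is [num_groups, M]
        return j * M + pid
    return pid * num_groups + j              # default K-major: A_scale is [M, num_groups]

# Row 4 of the earlier example (group 1, local row 1) writing its j=2 scale:
# scale_index(4, 2, "grouped_m_major", num_groups=4,
#             group_offset=12, group_cumsum=3, group_M=5) == 23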
@@ -3351,6 +3396,8 @@ def triton_quantize_fp8_group(
    x: torch.Tensor,
    group_size: int = 128,
    scale_ub: Optional[torch.Tensor] = None,
+     m_sizes: Optional[torch.Tensor] = None,
+     k_major: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a tensor to fp8 with group-wise scalings.
@@ -3361,6 +3408,8 @@ def triton_quantize_fp8_group(
        x (torch.Tensor): [M, K] higher precision input tensor.
        group_size (int): Group size for M dimension of scale.
        scale_ub: Maximum allowed value for scale.
+         m_sizes: Optional input for grouped gemm to specify the number of rows in each group.
+         k_major (bool): Whether output scales should be K major (True) or MN major (False).

    Returns:
        torch.Tensor: [M, K] fp8 scaled tensor.
@@ -3374,13 +3423,17 @@ def triton_quantize_fp8_group(
    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
    M, K = x.shape
    k_groups = triton.cdiv(K, group_size)
-     x_scale = torch.empty((M, k_groups), device=x.device, dtype=torch.float32)
+     if k_major:
+         x_scale = torch.empty((M, k_groups), device=x.device, dtype=torch.float32)
+     else:
+         x_scale = torch.empty((k_groups, M), device=x.device, dtype=torch.float32)
    x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)
    _kernel_quantize_fp8_group[(M,)](
        x,
        x_scale,
        x_fp8,
        scale_ub,
+         m_sizes,
        M,
        K,
        x.stride(0),
@@ -3395,6 +3448,8 @@ def triton_quantize_fp8_group(
        CLAMP_MAX=scale_ub is not None,
        USE_INT64=x.numel() > (2**32 - 1),
        GROUP_SIZE=group_size,
+         USE_M_MAJOR=m_sizes is not None or k_major is False,
+         G=m_sizes.numel() if m_sizes is not None else 0,
    )
    return x_fp8.view(x_shape), x_scale
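A usage sketch of the new grouped path (not part of the patch; assumes a CUDA device and the same hypothetical import path as above):

import torch
from fp8_gemm import triton_quantize_fp8_group  # hypothetical import path

# Three logical groups stacked along M: 16 + 32 + 16 rows, K = 512, group_size = 128.
x = torch.randn(64, 512, device="cuda", dtype=torch.bfloat16)
m_sizes = torch.tensor([16, 32, 16], device="cuda", dtype=torch.int32)

xq, scales = triton_quantize_fp8_group(x, group_size=128, m_sizes=m_sizes)
# scales holds one reciprocal scale per row per 128-wide K group (64 * 4 values),
# written group by group in M-major order because m_sizes was provided.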
@@ -3403,6 +3458,8 @@ def quantize_fp8_group(
    x: torch.Tensor,
    group_size: int = 128,
    scale_ub: Optional[torch.Tensor] = None,
+     m_sizes: Optional[torch.Tensor] = None,
+     k_major: bool = True,
    use_triton: bool = True,
    output_device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -3415,6 +3472,9 @@ def quantize_fp8_group(
        x (Tensor): [M, K] higher precision input tensor.
        group_size (int): Group size for M dimension of scale.
        scale_ub: Maximum allowed value for scale.
+         m_sizes: Optional input for grouped gemm to specify the number of rows in each group.
+         k_major (bool): Whether output scales should be K major (True) or MN major (False).
+             This is needed because some kernels like cutlass require a special layout for scales.
        use_triton (bool): Whether to use triton kernel or pytorch.
        output_device (torch.device): Device to optionally move the scaled tensors to.
@@ -3428,7 +3488,9 @@ def quantize_fp8_group(
        logger.info("Triton does not support cpu, falling back to torch ops.")
        use_triton = False
    if use_triton:
-         xq, x_scale = triton_quantize_fp8_group(x, group_size, scale_ub)
+         xq, x_scale = triton_quantize_fp8_group(
+             x, group_size, scale_ub, m_sizes, k_major
+         )
        return xq.view(x_shape), x_scale
    # else use pytorch implementation.
    if not output_device:
@@ -3441,6 +3503,7 @@ def quantize_fp8_group(
    assert (
        K % group_size == 0
    ), "K must be divisible by group_size for cpu implementation."
+     assert m_sizes is None, "m_sizes is not supported for cpu implementation."
    k_groups = triton.cdiv(K, group_size)
    # View input as colleciton of groups for reduction.
    x_grouped = x.view(M, k_groups, group_size).to(torch.float32)
@@ -3461,6 +3524,8 @@ def quantize_fp8_group(
    # Cast and move data to output device (for cpu weight loading).
    x_fp8 = x_fp8.to(device=output_device, dtype=pt_dtype)
    x_scale = x_scale.to(output_device)  # pyre-ignore
+     if not k_major:
+         x_scale = x_scale.t().contiguous()
    return x_fp8.view(x_shape), 1 / x_scale  # pyre-ignore
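Finally, the wrapper's CPU weight-loading path with the new k_major flag (illustrative; m_sizes must stay None here because of the new assertion, and the import path is hypothetical):

import torch
from fp8_gemm import quantize_fp8_group  # hypothetical import path

w = torch.randn(256, 1024)  # CPU tensor -> falls back to the torch implementation
wq, w_scale = quantize_fp8_group(w, group_size=128, k_major=False)
print(wq.shape, w_scale.shape)  # torch.Size([256, 1024]) torch.Size([8, 256])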