@@ -3111,7 +3111,7 @@ def triton_quantize_fp8_block(
     M, K = x.shape
     grid_m = triton.cdiv(M, block_m)
     grid_k = triton.cdiv(K, block_k)
-    x_scale = torch.ones((grid_m, grid_k), device=x.device, dtype=torch.float32)
+    x_scale = torch.empty((grid_m, grid_k), device=x.device, dtype=torch.float32)
     x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)

     _kernel_quantize_fp8_block[(grid_m * grid_k,)](
@@ -3222,6 +3222,230 @@ def quantize_fp8_block(
     return x_fp8.view(x_shape), 1 / x_scale  # pyre-ignore


+@triton.autotune(
+    configs=[
+        Config({"GROUP_LOAD": 2}),
+        Config({"GROUP_LOAD": 4}),
+        Config({"GROUP_LOAD": 8}),
+        Config({"GROUP_LOAD": 16}),
+        Config({"GROUP_LOAD": 32}),
+    ],
+    key=["K"],
+)
+@triton.jit
+def _kernel_quantize_fp8_group(
+    A,
+    A_scale,
+    A_fp8,
+    scale_ub,
+    M,
+    K,
+    stride_am,
+    stride_ak,
+    stride_om,
+    stride_ok,
+    stride_a_scale_m,
+    stride_a_scale_k,
+    TL_FP8_DTYPE: tl.constexpr,
+    MAX_FP8: tl.constexpr,
+    EPS: tl.constexpr,
+    CLAMP_MAX: tl.constexpr,
+    USE_INT64: tl.constexpr,
+    GROUP_SIZE: tl.constexpr,
+    GROUP_LOAD: tl.constexpr,
+):
+    """Quantize and scale each GROUP_SIZE chunk of each row.
+
+    Scale per group i is computed as 1 / (MAX_FP8 / max(abs(A[i:i+GROUP_SIZE])))
+
+    Each kernel thread is responsible for one row and loads and processes a tunable
+    number of groups at once.
+
+    Args:
+        A (Tensor): [M, K] higher precision input tensor.
+        A_scale (Tensor): [M, cdiv(K, GROUP_SIZE)] reciprocal scale tensor per group.
+        A_fp8 (Tensor): [M, K] fp8 scaled tensor. A_fp8 = A * a_scale.
+        scale_ub (Tensor): [1] Maximum allowed value for scale.
+        M (int): Number of rows.
+        K (int): Number of columns.
+        stride_am (int): Stride of m dimension of A.
+        stride_ak (int): Stride of k dimension of A.
+        stride_om (int): Stride of m dimension of output.
+        stride_ok (int): Stride of k dimension of output.
+        stride_a_scale_m (int): Stride of m dimension of A_scale.
+        stride_a_scale_k (int): Stride of k dimension of A_scale.
+        TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
+        MAX_FP8 (float): Maximum expressible value for FP8.
+        EPS (float): Epsilon value for numerical stability.
+        CLAMP_MAX (bool): Whether to apply scale_ub.
+        USE_INT64 (bool): Whether to index using int64, which may be needed for large tensors.
+        GROUP_SIZE (int): Group size for K dimension of A_scale and kernel.
+        GROUP_LOAD (int): Number of groups to load and process simultaneously.
+    """
+    pid = tl.program_id(0)
+    if USE_INT64:
+        pid = pid.to(tl.int64)
+    # We load group_size * group_load chunks at a time.
+    row_offset = pid * stride_am
+    out_offset = pid * stride_om
+    scale_row_offset = pid * stride_a_scale_m
+    k_offset = tl.arange(0, GROUP_LOAD * GROUP_SIZE)
+    scale_k_offset = tl.arange(0, GROUP_LOAD)
+    NUM_GROUPS: tl.constexpr = K // GROUP_SIZE
+
+    for k in range(0, tl.cdiv(K, (GROUP_LOAD * GROUP_SIZE))):
+        # Load groups of the input.
+        chunk_offset = k_offset + k * GROUP_LOAD * GROUP_SIZE
+        a = tl.load(
+            A + row_offset + chunk_offset * stride_ak, mask=chunk_offset < K, other=0.0
+        )
+        # View loaded chunk as a set of groups.
+        a_grouped = tl.reshape(a, [GROUP_LOAD, GROUP_SIZE])
+        # Reduce over groups.
+        group_max = tl.max(tl.abs(a_grouped), axis=1)
+        # Apply clamping if specified.
+        if CLAMP_MAX:
+            ub = tl.load(scale_ub)
+            group_max = tl.clamp(group_max, EPS, ub)
+        else:
+            group_max = tl.maximum(group_max, EPS)
+        # Scale and quantize.
+        a_scale = MAX_FP8 / group_max
+        scale_chunk_offset = scale_k_offset + k * GROUP_LOAD
+        tl.store(
+            A_scale + scale_row_offset + scale_chunk_offset * stride_a_scale_k,
+            1.0 / a_scale,
+            mask=scale_chunk_offset < NUM_GROUPS,
+        )
+        # Apply scale to input.
+        a_fp8 = a_grouped * a_scale[:, None]
+        # Clamp to FP8 range to avoid overflow.
+        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
+        # Write to output.
+        tl.store(
+            A_fp8 + out_offset + chunk_offset * stride_ok,
+            tl.ravel(a_fp8),
+            mask=chunk_offset < K,
+        )
+
+
+def triton_quantize_fp8_group(
+    x: torch.Tensor,
+    group_size: int = 128,
+    scale_ub: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize a tensor to fp8 with group-wise scalings.
+
+    Scale per group i is computed as 1 / (MAX_FP8 / max(abs(x[i:i+group_size])))
+
+    Args:
+        x (torch.Tensor): [M, K] higher precision input tensor.
+        group_size (int): Group size for K dimension of scale.
+        scale_ub: Maximum allowed value for scale.
+
+    Returns:
+        torch.Tensor: [M, K] fp8 scaled tensor.
+        torch.Tensor: [M, cdiv(K, group_size)] reciprocal scale tensor per group.
+    """
+    assert x.device != torch.device(
+        "cpu"
+    ), "Triton groupwise quantization not supported on cpu."
+    x_shape = x.shape
+    x = x.view(-1, x.size(-1))
+    pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
+    M, K = x.shape
+    k_groups = triton.cdiv(K, group_size)
+    x_scale = torch.empty((M, k_groups), device=x.device, dtype=torch.float32)
+    x_fp8 = torch.empty((M, K), device=x.device, dtype=pt_dtype)
+    _kernel_quantize_fp8_group[(M,)](
+        x,
+        x_scale,
+        x_fp8,
+        scale_ub,
+        M,
+        K,
+        x.stride(0),
+        x.stride(1),
+        x_fp8.stride(0),
+        x_fp8.stride(1),
+        x_scale.stride(0),
+        x_scale.stride(1),
+        TL_FP8_DTYPE=tl_dtype,
+        MAX_FP8=max_fp8,
+        EPS=eps,
+        CLAMP_MAX=scale_ub is not None,
+        USE_INT64=x.numel() > (2**32 - 1),
+        GROUP_SIZE=group_size,
+    )
+    return x_fp8.view(x_shape), x_scale
+
+
+def quantize_fp8_group(
+    x: torch.Tensor,
+    group_size: int = 128,
+    scale_ub: Optional[torch.Tensor] = None,
+    use_triton: bool = True,
+    output_device: Optional[torch.device] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize a tensor to fp8 with group-wise scalings and optionally move to output device.
+
+    Scale per group i is computed as 1 / (MAX_FP8 / max(abs(x[i:i+group_size])))
+
+    Args:
+        x (Tensor): [M, K] higher precision input tensor.
+        group_size (int): Group size for K dimension of scale.
+        scale_ub: Maximum allowed value for scale.
+        use_triton (bool): Whether to use triton kernel or pytorch.
+        output_device (torch.device): Device to optionally move the scaled tensors to.
+
+    Returns:
+        torch.Tensor: [M, K] fp8 scaled tensor.
+        torch.Tensor: [M, cdiv(K, group_size)] reciprocal scale tensor per group.
+    """
+    x_shape = x.shape
+    x = x.view(-1, x.size(-1))
+    if x.device == torch.device("cpu"):
+        logger.info("Triton does not support cpu, falling back to torch ops.")
+        use_triton = False
+    if use_triton:
+        xq, x_scale = triton_quantize_fp8_group(x, group_size, scale_ub)
+        return xq.view(x_shape), x_scale
+    # else use pytorch implementation.
+    if not output_device:
+        output_device = x.device
+
+    # Get constants.
+    pt_dtype, _, max_fp8, eps = get_fp8_constants()
+
+    M, K = x.shape
+    assert (
+        K % group_size == 0
+    ), "K must be divisible by group_size for cpu implementation."
+    k_groups = triton.cdiv(K, group_size)
+    # View input as collection of groups for reduction.
+    x_grouped = x.view(M, k_groups, group_size).to(torch.float32)
+    # Reduce over groups.
+    group_max = x_grouped.abs().amax(dim=2)
+    # Apply clamping.
+    group_max = (
+        torch.clamp(group_max, min=eps, max=scale_ub.item())
+        if scale_ub is not None
+        else torch.clamp(group_max, min=eps)
+    )
+    x_scale = torch.empty((M, k_groups), dtype=torch.float32, device=output_device)
+    x_scale = max_fp8 / group_max  # pyre-ignore
+    # pyre-ignore[16]: Undefined attribute [16]
+    x_scale[x_scale == float("inf")] = 1.0
+    # pyre-ignore[16]: Undefined attribute [16]
+    x_fp8 = x.view(-1, k_groups, group_size) * x_scale.unsqueeze(2)
+    # Cast and move data to output device (for cpu weight loading).
+    x_fp8 = x_fp8.to(device=output_device, dtype=pt_dtype)
+    x_scale = x_scale.to(output_device)  # pyre-ignore
+    return x_fp8.view(x_shape), 1 / x_scale  # pyre-ignore
+
+
 def need_split_k(SIZE_M, SIZE_N, SIZE_K):
     return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
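
As a sanity check on the new API, here is a minimal usage sketch. It assumes only what the diff above establishes: quantize_fp8_group returns an [M, K] fp8 tensor plus an [M, cdiv(K, group_size)] reciprocal scale, so dequantization multiplies the scale back in. The shapes, device, and bf16 input dtype are illustrative, not from the commit; with an e4m3 target, MAX_FP8 is 448, so a group whose abs-max is 3.5 gets kernel scale 448 / 3.5 = 128 and stored reciprocal 1/128.

    import torch

    # Illustrative input; quantize_fp8_group is the wrapper added in this diff.
    x = torch.randn(16, 512, dtype=torch.bfloat16, device="cuda")
    xq, x_scale = quantize_fp8_group(x, group_size=128)

    # xq: [16, 512] fp8; x_scale: [16, 4] holding 1 / (MAX_FP8 / group_max).
    assert xq.shape == (16, 512) and x_scale.shape == (16, 4)

    # Dequantize by multiplying each group by its stored reciprocal scale.
    x_dq = xq.to(torch.float32).view(16, 4, 128) * x_scale.unsqueeze(2)
    print((x.float() - x_dq.view(16, 512)).abs().max())  # small quantization error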
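
The Triton kernel and the PyTorch fallback in quantize_fp8_group implement the same formula (MAX_FP8 / max(abs(group)) with an EPS floor), so their reciprocal scales should agree closely. A hedged parity sketch, reusing x from the example above; the tolerances are guesses, and the cpu path additionally requires K divisible by group_size, which it asserts:

    # Compare reciprocal scales from the two code paths in this diff.
    xq_t, s_t = quantize_fp8_group(x, group_size=128, use_triton=True)
    xq_p, s_p = quantize_fp8_group(x.cpu(), group_size=128, use_triton=False)
    torch.testing.assert_close(s_t.cpu(), s_p, rtol=1e-3, atol=1e-5)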