@@ -278,7 +278,7 @@ def test_f8f8bf16(self, kernel: str, use_fast_accum: bool) -> None:
         ),
         QType=st.sampled_from([fp8_e4m3, fp8_e5m2]),
         Bias=st.sampled_from([True, False]),
-        CudaGraph=st.sampled_from([False]),
+        CudaGraph=st.sampled_from([True, False]),
         UseTriton=st.sampled_from([False] + ([True] if torch.version.cuda else [])),
         UseFastAccum=st.booleans(),
         InputMultiDim=st.booleans(),
@@ -337,78 +337,62 @@ def test_quantize_fp8_matmul(
         )

         if Mode == "tensorwise":
-            if CudaGraph:
-                g = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(g):
-                    xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
-                    wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-                    zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
-                    if bias is not None:
-                        zq += bias
-                g.replay()
-            else:
+
+            def f(
+                x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor]
+            ) -> torch.Tensor:
                 xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
                 wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
                 zq = torch.ops.fbgemm.f8f8bf16(xq, wq, x_scale * w_scale)
                 if bias is not None:
                     zq += bias
-        elif Mode == "tensorwise_broadcast":
-            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
-            wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-            x_scale = x_scale.item()
-            w_scale = w_scale.item()
+                return zq
+
             if CudaGraph:
+                # Warm-up to avoid capture issues
+                f(x, w, bias)
+
                 g = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(g):
-                    zq = torch.ops.fbgemm.f8f8bf16_tensorwise(
-                        xq, wq, x_scale * w_scale, use_fast_accum=UseFastAccum
-                    )
-                    if bias is not None:
-                        zq += bias
+                    zq = f(x, w, bias)
                 g.replay()
             else:
+                zq = f(x, w, bias)
+        elif Mode == "tensorwise_broadcast":
+
+            def f(
+                xq: torch.Tensor,
+                wq: torch.Tensor,
+                scale: float,
+                bias: Optional[torch.Tensor],
+            ) -> torch.Tensor:
                 zq = torch.ops.fbgemm.f8f8bf16_tensorwise(
-                    xq, wq, x_scale * w_scale, use_fast_accum=UseFastAccum
+                    xq, wq, scale, use_fast_accum=UseFastAccum
                 )
                 if bias is not None:
                     zq += bias
-        elif Mode == "rowwise":
+                return zq
+
+            xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
+            wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
+            x_scale = x_scale.item()
+            w_scale = w_scale.item()
+
             if CudaGraph:
-                # Warm up triton functions before cuda graph.
-                xq, x_scale = quantize_fp8_row(x)
-                wq, w_scale = quantize_fp8_row(w)
-                if UseTriton and torch.version.cuda:
-                    zq = matmul_fp8_row(
-                        xq, wq, x_scale, w_scale, fp8_fast_accum=UseFastAccum
-                    )
+                # Warm-up to avoid capture issues
+                f(xq, wq, x_scale * w_scale, bias)
+
                 g = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(g):
-                    if torch.version.cuda:
-                        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-                            x, output_dtype=QType
-                        )
-                        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_row(w)
-                    else:
-                        xq, x_scale = quantize_fp8_row(x)
-                        wq, w_scale = quantize_fp8_row(w)
-                    if UseTriton and torch.version.cuda:
-                        zq = matmul_fp8_row(xq, wq, x_scale, w_scale)
-                        if bias is not None:
-                            zq += bias
-                    else:
-                        zq = torch.ops.fbgemm.f8f8bf16_rowwise(
-                            xq,
-                            wq,
-                            x_scale,
-                            w_scale,
-                            bias=bias if torch.version.cuda else None,
-                            use_fast_accum=UseFastAccum,
-                        )
-                        # Bias fusion not yet supported on AMD.
-                        if bias is not None and torch.version.hip:
-                            zq += bias
+                    zq = f(xq, wq, x_scale * w_scale, bias)
                 g.replay()
             else:
+                zq = f(xq, wq, x_scale * w_scale, bias)
+        elif Mode == "rowwise":
+
+            def f(
+                x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor]
+            ) -> torch.Tensor:
                 if torch.version.cuda:
                     xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
                         x, output_dtype=QType
@@ -418,9 +402,7 @@ def test_quantize_fp8_matmul(
                     xq, x_scale = quantize_fp8_row(x)
                     wq, w_scale = quantize_fp8_row(w)
                 if UseTriton and torch.version.cuda:
-                    zq = matmul_fp8_row(
-                        xq, wq, x_scale, w_scale, fp8_fast_accum=UseFastAccum
-                    )
+                    zq = matmul_fp8_row(xq, wq, x_scale, w_scale)
                     if bias is not None:
                         zq += bias
                 else:
@@ -435,14 +417,27 @@ def test_quantize_fp8_matmul(
                 # Bias fusion not yet supported on AMD.
                 if bias is not None and torch.version.hip:
                     zq += bias
-        elif Mode == "blockwise":
-            block_m = block_n = block_k = 128
-            output_device = torch.device(self.device)
+
+                return zq
+
             if CudaGraph:
-                # Need a warmup to compile the Triton kernel before cuda graph
+                # Warm-up to avoid capture issues
+                f(x, w, bias)
+
+                g = torch.cuda.CUDAGraph()
+                with torch.cuda.graph(g):
+                    zq = f(x, w, bias)
+                g.replay()
+            else:
+                zq = f(x, w, bias)
+        elif Mode == "blockwise":

+            def f(
+                x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor]
+            ) -> torch.Tensor:
+                block_m = block_n = block_k = 128
                 wq, w_scale = quantize_fp8_block(
-                    w, block_n, block_k, output_device=output_device
+                    w, block_n, block_k, output_device=torch.device(self.device)
                 )
                 xq, x_scale = quantize_fp8_block(x, block_m, block_k)
                 if UseTriton:
@@ -463,52 +458,18 @@ def test_quantize_fp8_matmul(
                     if bias is not None:
                         zq += bias

+                return zq
+
+            if CudaGraph:
+                # Warm-up to avoid capture issues
+                f(x, w, bias)
+
                 g = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(g):
-                    wq, w_scale = quantize_fp8_block(
-                        w, block_n, block_k, output_device=output_device
-                    )
-                    xq, x_scale = quantize_fp8_block(x, block_m, block_k)
-                    if UseTriton:
-                        zq = matmul_fp8_block(
-                            xq,
-                            wq,
-                            x_scale,
-                            w_scale,
-                            block_m,
-                            block_n,
-                            block_k,
-                            fp8_fast_accum=UseFastAccum,
-                        )
-                    else:
-                        zq = torch.ops.fbgemm.f8f8bf16_blockwise(
-                            xq, wq, x_scale, w_scale, block_m, block_n, block_k
-                        )
-                    if bias is not None:
-                        zq += bias
+                    zq = f(x, w, bias)
                 g.replay()
             else:
-                wq, w_scale = quantize_fp8_block(
-                    w, block_n, block_k, output_device=output_device
-                )
-                xq, x_scale = quantize_fp8_block(x, block_m, block_k)
-                if UseTriton:
-                    zq = matmul_fp8_block(
-                        xq,
-                        wq,
-                        x_scale,
-                        w_scale,
-                        block_m,
-                        block_n,
-                        block_k,
-                        fp8_fast_accum=UseFastAccum,
-                    )
-                else:
-                    zq = torch.ops.fbgemm.f8f8bf16_blockwise(
-                        xq, wq, x_scale, w_scale, block_m, block_n, block_k
-                    )
-                if bias is not None:
-                    zq += bias
+                zq = f(x, w, bias)
         else:
             raise ValueError(f"Invalid mode {Mode}")
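Note for reviewers: every mode now routes both the graph-capture and eager paths through a local f, and calls it once before capture. A minimal sketch of that warm-up / capture / replay pattern is below; the shapes and the plain torch.mm payload are placeholders standing in for the quantize-and-GEMM body, not part of this PR.

import torch

def payload(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Placeholder for the per-mode quantize + FP8 GEMM body that the test wraps in f.
    return torch.mm(x, w)

x = torch.randn(128, 256, device="cuda")
w = torch.randn(256, 64, device="cuda")

# Warm-up run outside the graph, so one-time work (kernel compilation,
# workspace allocation) does not happen during capture.
z = payload(x, w)

# Capture the payload into a CUDA graph, then replay the recorded kernels.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    z = payload(x, w)
g.replay()

print(z.shape)  # static output buffer reused by every replay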