37
37
from torchao .utils import (
38
38
TORCH_VERSION_AT_LEAST_2_8 ,
39
39
is_sm_at_least_89 ,
40
+ is_sm_at_least_90 ,
40
41
is_sm_at_least_100 ,
41
42
)
42
43
@@ -459,10 +460,29 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
459
460
"mm_config" , [NVFP4MMConfig .DYNAMIC , NVFP4MMConfig .WEIGHT_ONLY ]
460
461
)
461
462
@pytest .mark .parametrize ("inpt_dtype" , [torch .bfloat16 , torch .float32 ])
463
+ @pytest .mark .parametrize ("use_triton_kernel" , [True , False ])
464
+ @pytest .mark .parametrize (
465
+ "shapes" ,
466
+ [
467
+ (128 , 64 , 256 ),
468
+ (256 , 128 , 512 ),
469
+ (145 , 64 , 256 ),
470
+ (128 , 96 , 256 ),
471
+ (128 , 160 , 256 ),
472
+ (64 , 64 , 256 ),
473
+ (200 , 192 , 256 ),
474
+ ],
475
+ ids = lambda s : f"{ s [0 ]} x{ s [1 ]} x{ s [2 ]} " ,
476
+ )
462
477
@torch .no_grad ()
463
478
@skip_if_rocm ("ROCm float4 gemm require gfx950" )
464
479
def test_inference_subclass_nvfp4 (
465
- bias : bool , compile : bool , mm_config : NVFP4MMConfig , inpt_dtype : torch .dtype
480
+ bias : bool ,
481
+ compile : bool ,
482
+ mm_config : NVFP4MMConfig ,
483
+ inpt_dtype : torch .dtype ,
484
+ use_triton_kernel : bool ,
485
+ shapes : tuple ,
466
486
):
467
487
"""
468
488
Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
@@ -477,16 +497,20 @@ def test_inference_subclass_nvfp4(
477
497
478
498
if mm_config == NVFP4MMConfig .WEIGHT_ONLY and compile :
479
499
pytest .skip ("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile" )
480
- m = nn .Linear (64 , 256 , bias = bias , dtype = inpt_dtype , device = "cuda" )
500
+ batch_size , in_features , out_features = shapes
501
+
502
+ m = nn .Linear (in_features , out_features , bias = bias , dtype = inpt_dtype , device = "cuda" )
481
503
m_mx = copy .deepcopy (m )
482
504
483
- config = NVFP4InferenceConfig (mm_config = mm_config )
505
+ config = NVFP4InferenceConfig (
506
+ mm_config = mm_config , use_triton_kernel = use_triton_kernel
507
+ )
484
508
quantize_ (m_mx , config = config )
485
509
486
510
if compile :
487
511
m_mx = torch .compile (m_mx , fullgraph = True , backend = "aot_eager" )
488
512
489
- x = torch .randn (128 , 64 , device = "cuda" , dtype = inpt_dtype )
513
+ x = torch .randn (batch_size , in_features , device = "cuda" , dtype = inpt_dtype )
490
514
y_ref = m (x )
491
515
y_mx = m_mx (x )
492
516
sqnr = compute_error (y_ref , y_mx )
@@ -513,14 +537,33 @@ def test_inference_subclass_nvfp4(
513
537
@pytest .mark .parametrize ("compile" , [False ])
514
538
@pytest .mark .parametrize ("bias" , [True , False ])
515
539
@pytest .mark .parametrize ("inpt_dtype" , [torch .bfloat16 , torch .float32 ])
540
+ @pytest .mark .parametrize ("use_triton_kernel" , [True , False ])
541
+ @pytest .mark .parametrize (
542
+ "shapes" ,
543
+ [
544
+ (128 , 64 , 256 ),
545
+ (256 , 128 , 512 ),
546
+ (157 , 64 , 256 ),
547
+ (128 , 96 , 256 ),
548
+ (128 , 160 , 256 ),
549
+ (64 , 64 , 256 ),
550
+ (200 , 192 , 256 ),
551
+ ],
552
+ ids = lambda s : f"{ s [0 ]} x{ s [1 ]} x{ s [2 ]} " ,
553
+ )
516
554
@torch .no_grad ()
517
555
@skip_if_rocm ("ROCm float4 gemm require gfx950" )
556
+ @pytest .mark .skipif (
557
+ not is_sm_at_least_90 (), reason = "CUDA capability >= 9.0 required for fp8e4nv"
558
+ )
518
559
def test_nvfp4_matmul_with_amax (
519
560
use_gelu : bool ,
520
561
mm_config : NVFP4MMConfig ,
521
562
compile : bool ,
522
563
bias : bool ,
523
564
inpt_dtype : torch .dtype ,
565
+ use_triton_kernel : bool ,
566
+ shapes : tuple ,
524
567
):
525
568
from torchao .prototype .mx_formats .nvfp4_tensor import (
526
569
NVFP4Tensor ,
@@ -537,7 +580,7 @@ def test_nvfp4_matmul_with_amax(
537
580
if mm_config == NVFP4MMConfig .WEIGHT_ONLY and compile :
538
581
pytest .skip ("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile" )
539
582
540
- m , k , n = 64 , 256 , 128
583
+ m , k , n = shapes
541
584
542
585
# Create activation tensor
543
586
if use_gelu :
@@ -559,12 +602,14 @@ def test_nvfp4_matmul_with_amax(
559
602
per_tensor_scale = a_scale ,
560
603
mm_config = mm_config ,
561
604
is_swizzled_scales = True ,
605
+ use_triton_kernel = use_triton_kernel ,
562
606
)
563
607
B_nvfp4 = NVFP4Tensor .to_nvfp4 (
564
608
B ,
565
609
per_tensor_scale = b_scale ,
566
610
mm_config = mm_config ,
567
611
is_swizzled_scales = True ,
612
+ use_triton_kernel = use_triton_kernel ,
568
613
)
569
614
570
615
func = torch .compile (F .linear , fullgraph = True ) if compile else F .linear
0 commit comments