@@ -531,8 +531,45 @@ def test_inference_mode(self):
         with torch.inference_mode(mode=True):
             m(x)

+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available")
+    @pytest.mark.parametrize(
+        "recipe_name",
+        [
+            Float8LinearRecipeName.ALL_TENSORWISE,
+            # TODO(future PR): enable axiswise recipes
+        ],
+    )
+    def test_zero_dim(self, recipe_name):
+        # Note: we only test M == 0 because we can assume that K == 0 and N == 0
+        # are not important
+        M, K, N = 0, 64, 128
+
+        x0_ref = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).requires_grad_()
+        x0_fp8 = copy.deepcopy(x0_ref)
+        x1 = torch.randn(32, K, device="cuda", dtype=torch.bfloat16)
+        config = recipe_name_to_linear_config(recipe_name)
+
+        m_ref = nn.Sequential(nn.Linear(K, N, device="cuda", dtype=torch.bfloat16))
+        m_fp8 = copy.deepcopy(m_ref)
+        m_fp8 = convert_to_float8_training(m_fp8, config=config)
+
+        y_ref = m_ref(x0_ref)
+        y_ref.sum().backward()
+
+        y_fp8 = m_fp8(x0_fp8)
+        y_fp8.sum().backward()
+
+        assert torch.allclose(
+            y_ref, y_fp8, rtol=0, atol=0)
+        assert torch.allclose(
+            m_ref[0].weight.grad, m_fp8[0].weight.grad, rtol=0, atol=0)
+        assert torch.allclose(
+            x0_ref.grad, x0_fp8.grad, rtol=0, atol=0)
+


 class TestScaledMM:
+
     @unittest.skipIf(
         not is_cuda_8_9,
         "CUDA not available",