@@ -531,6 +531,43 @@ def test_inference_mode(self):
        with torch.inference_mode(mode=True):
            m(x)

+    @unittest.skip(
+        "TODO enable this test after https://github.com/pytorch/pytorch/pull/140967 lands in CI"
+    )
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available")
+    @pytest.mark.parametrize(
+        "recipe_name",
+        [
+            Float8LinearRecipeName.ALL_TENSORWISE,
+            # TODO(future PR): enable axiswise recipes
+        ],
+    )
+    def test_zero_dim(self, recipe_name):
+        # Note: we only test M == 0 because we can assume that K == 0 and N == 0
+        # are not important
+        M, K, N = 0, 64, 128
+
+        x0_ref = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).requires_grad_()
+        x0_fp8 = copy.deepcopy(x0_ref)
+        config = recipe_name_to_linear_config(recipe_name)
+
+        m_ref = nn.Sequential(nn.Linear(K, N, device="cuda", dtype=torch.bfloat16))
+        m_fp8 = copy.deepcopy(m_ref)
+        m_fp8 = convert_to_float8_training(m_fp8, config=config)
+
+        y_ref = m_ref(x0_ref)
+        y_ref.sum().backward()
+
+        y_fp8 = m_fp8(x0_fp8)
+        y_fp8.sum().backward()
+
+        assert torch.allclose(y_ref, y_fp8, rtol=0, atol=0)
+        assert torch.allclose(
+            m_ref[0].weight.grad, m_fp8[0].weight.grad, rtol=0, atol=0
+        )
+        assert torch.allclose(x0_ref.grad, x0_fp8.grad, rtol=0, atol=0)
+


class TestScaledMM:
    @unittest.skipIf(
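For context, here is a minimal standalone sketch of the scenario the new test exercises: pushing a zero-row batch (M == 0) through a float8-converted linear and comparing it against the bf16 reference. This sketch assumes torchao's `convert_to_float8_training` with its default config rather than the `recipe_name_to_linear_config` parametrization used in the test, and it assumes a CUDA device with float8 support (SM 8.9+); it is an illustration, not the PR's code.

```python
# Minimal sketch of the zero-dim (M == 0) float8 vs. bf16 comparison.
# Assumes default torchao float8 config and an SM 8.9+ CUDA device.
import copy

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

M, K, N = 0, 64, 128

x_ref = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).requires_grad_()
x_fp8 = copy.deepcopy(x_ref)

m_ref = nn.Sequential(nn.Linear(K, N, device="cuda", dtype=torch.bfloat16))
m_fp8 = convert_to_float8_training(copy.deepcopy(m_ref))

y_ref = m_ref(x_ref)
y_ref.sum().backward()

y_fp8 = m_fp8(x_fp8)
y_fp8.sum().backward()

# With M == 0 the outputs and input grads are empty tensors and the weight
# grads are all zeros, so an exact (rtol=0, atol=0) comparison is expected.
assert torch.allclose(y_ref, y_fp8, rtol=0, atol=0)
assert torch.allclose(m_ref[0].weight.grad, m_fp8[0].weight.grad, rtol=0, atol=0)
assert torch.allclose(x_ref.grad, x_fp8.grad, rtol=0, atol=0)
```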