@@ -552,6 +552,43 @@ def test_quantize(self):
         with torch.no_grad():
             m(x)
 
+    @unittest.skip(
+        "TODO enable this test after https://github.com/pytorch/pytorch/pull/140967 lands in CI"
+    )
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @unittest.skipIf(not is_sm_at_least_89, "CUDA 8.9 not available")
+    @pytest.mark.parametrize(
+        "recipe_name",
+        [
+            Float8LinearRecipeName.ALL_TENSORWISE,
+            # TODO(future PR): enable axiswise recipes
+        ],
+    )
+    def test_zero_dim(self, recipe_name):
+        # Note: we only test M == 0 because we can assume that K == 0 and N == 0
+        # are not important
+        M, K, N = 0, 64, 128
+
+        x0_ref = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).requires_grad_()
+        x0_fp8 = copy.deepcopy(x0_ref)
+        config = recipe_name_to_linear_config(recipe_name)
+
+        m_ref = nn.Sequential(nn.Linear(K, N, device="cuda", dtype=torch.bfloat16))
+        m_fp8 = copy.deepcopy(m_ref)
+        m_fp8 = convert_to_float8_training(m_fp8, config=config)
+
+        y_ref = m_ref(x0_ref)
+        y_ref.sum().backward()
+
+        y_fp8 = m_fp8(x0_fp8)
+        y_fp8.sum().backward()
+
+        assert torch.allclose(y_ref, y_fp8, rtol=0, atol=0)
+        assert torch.allclose(
+            m_ref[0].weight.grad, m_fp8[0].weight.grad, rtol=0, atol=0
+        )
+        assert torch.allclose(x0_ref.grad, x0_fp8.grad, rtol=0, atol=0)
+
 
 class TestScaledMM:
     @unittest.skipIf(
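For context, the M == 0 case exercised above is an input batch with zero rows. Below is a minimal sketch (not part of this PR; plain bfloat16 on CPU, no float8 conversion) of the baseline shape and gradient behavior that the new test expects the float8 training path to match:

```python
import torch
import torch.nn as nn

M, K, N = 0, 64, 128
x = torch.randn(M, K, dtype=torch.bfloat16, requires_grad=True)  # zero-row input
m = nn.Linear(K, N, dtype=torch.bfloat16)

y = m(x)            # empty (0, 128) output, no error
y.sum().backward()  # sum of an empty tensor is 0; backward still runs

print(y.shape)              # torch.Size([0, 128])
print(m.weight.grad.shape)  # torch.Size([128, 64]), filled with zeros
print(x.grad.shape)         # torch.Size([0, 64])
```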