@@ -657,3 +657,301 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
     assert x.t().dtype == x_reconstructed_t.dtype, (
         f"Transpose dtype mismatch: {x.t().dtype} vs {x_reconstructed_t.dtype}"
     )
+
+
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (128, 4),
+        (256, 8),
+        (100, 3),
+        (4, 4),
+        (50, 10),
+        (384, 12),
+    ],
+)
+@pytest.mark.parametrize(
+    "use_triton_kernel", [False, True] if torch.cuda.is_available() else [False]
+)
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+def test_to_blocked_from_blocked_roundtrip(shape, use_triton_kernel: bool):
+    from torchao.prototype.mx_formats.utils import from_blocked, to_blocked
+
+    rows, cols = shape
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    original = torch.randint(0, 255, (rows, cols), device=device, dtype=torch.uint8)
+
+    blocked = to_blocked(original, use_triton_kernel=use_triton_kernel)
+    reconstructed = from_blocked(blocked, rows, cols)
+
+    torch.testing.assert_close(
+        original,
+        reconstructed,
+        atol=0.0,
+        rtol=0.0,
+        msg=f"Roundtrip failed for shape {shape} with use_triton_kernel={use_triton_kernel}",
+    )
+
+
+@pytest.mark.parametrize("is_swizzled_scales", [False, True])
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (32, 64),
+        (16, 32),
+        (64, 128),
+        (384, 128),
+    ],
+)
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
+)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
+    """
+    Test that NVFP4Tensor can be constructed with swizzled scales and
+    that the _is_swizzled_scales flag is set correctly.
+    """
+    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+
+    M, K = shape
+    data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+
+    tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=is_swizzled_scales)
+    assert tensor._is_swizzled_scales == is_swizzled_scales
+    reconstructed = tensor.to_dtype(torch.bfloat16)
+    assert reconstructed.shape == data.shape
+
+
+@pytest.mark.parametrize(
+    "slice_dim,slice_spec",
+    [
+        # Row slicing - must align with 128-row boundaries
+        pytest.param(0, slice(0, 128), id="slice_rows[0:128]"),
+        pytest.param(0, slice(128, 256), id="slice_rows[128:256]"),
+        # Column slicing - must align with 64-column boundaries (4 scale columns * 16 block_size)
+        pytest.param(1, slice(0, 64), id="slice_cols[0:64]"),
+        pytest.param(1, slice(64, 128), id="slice_cols[64:128]"),
+        pytest.param(1, slice(0, 128), id="slice_cols[0:128]_full_width"),
+        # Test tensor parallelism patterns (half splits)
+        pytest.param(1, slice(0, 2048), id="slice_cols[0:2048]_tp_first_half"),
+        pytest.param(1, slice(2048, 4096), id="slice_cols[2048:4096]_tp_second_half"),
+        # Test quarter splits
+        pytest.param(1, slice(0, 1024), id="slice_cols[0:1024]_quarter"),
+        pytest.param(1, slice(1024, 2048), id="slice_cols[1024:2048]_quarter"),
+    ],
+)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+)
+def test_nvfp4_swizzled_scales_slicing(slice_dim, slice_spec):
+    """
+    Test that slicing works correctly with swizzled scales and maintains
+    the swizzled state in the output tensor.
+    """
+    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+
+    # Use larger tensor sizes that align with swizzled requirements
+    if slice_dim == 0:
+        # For row slicing, need at least 256 rows to test 128-row boundaries
+        M, K = 256, 4096
+    else:
+        # For column slicing, need multiples of 64 columns for alignment
+        M, K = 128, 4096
+
+    data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+
+    tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
+    assert tensor._is_swizzled_scales == True
+
+    if slice_dim == 0:
+        sliced_tensor = tensor[slice_spec, :]
+    else:
+        sliced_tensor = tensor[:, slice_spec]
+
+    # Verify sliced tensor maintains swizzled state
+    assert sliced_tensor._is_swizzled_scales == True
+
+    # Verify sliced tensor can be dequantized
+    sliced_reconstructed = sliced_tensor.to_dtype(torch.bfloat16)
+
+    # Compare with direct slicing of original data
+    original_reconstructed = tensor.to_dtype(torch.bfloat16)
+    if slice_dim == 0:
+        expected = original_reconstructed[slice_spec, :]
+    else:
+        expected = original_reconstructed[:, slice_spec]
+
+    torch.testing.assert_close(sliced_reconstructed, expected, atol=1e-6, rtol=1e-6)
+
+
+@pytest.mark.parametrize(
+    "slice_dim,slice_spec,expected_error",
+    [
+        # Row slicing with misaligned boundaries
+        pytest.param(
+            0,
+            slice(0, 100),
+            "Row slicing of NVFP4Tensor with swizzled scales requires",
+            id="misaligned_row_end",
+        ),
+        pytest.param(
+            0,
+            slice(50, 150),
+            "Row slicing of NVFP4Tensor with swizzled scales requires",
+            id="misaligned_row_start",
+        ),
+        # Column slicing with misaligned boundaries
+        pytest.param(
+            1,
+            slice(0, 32),
+            "Column slicing of NVFP4Tensor with swizzled scales requires",
+            id="misaligned_col_32",
+        ),
+        pytest.param(
+            1,
+            slice(16, 80),
+            "Column slicing of NVFP4Tensor with swizzled scales requires",
+            id="misaligned_col_start",
+        ),
+        pytest.param(
+            1,
+            slice(0, 100),
+            "Column slicing of NVFP4Tensor with swizzled scales requires",
+            id="misaligned_col_end",
+        ),
+        # Odd column boundaries (FP4 packing requirement)
+        pytest.param(
+            1,
+            slice(1, 65),
+            "start index to be a multiple of 64, got 1",
+            id="odd_start",
+        ),
+        pytest.param(
+            1,
+            slice(0, 65),
+            " multiple of 64 or equal to tensor size 4096, got 65",
+            id="odd_end",
+        ),
+    ],
+)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+)
+def test_nvfp4_swizzled_scales_slicing_errors(slice_dim, slice_spec, expected_error):
+    """
+    Test that slicing raises appropriate errors for misaligned boundaries.
+    """
+    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+
+    M, K = 256, 4096
+    data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
+
+    with pytest.raises(RuntimeError, match=expected_error):
+        if slice_dim == 0:
+            _ = tensor[slice_spec, :]
+        else:
+            _ = tensor[:, slice_spec]
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+)
+def test_nvfp4_swizzled_scales_view_semantics():
+    """
+    Test that slicing maintains proper view semantics where possible.
+    """
+    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+
+    M, K = 256, 4096
+    data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
+
+    # Test row slicing (should maintain views)
+    sliced_tensor = tensor[0:128, :]
+
+    # Test that the sliced tensor shares storage with the original for data
+    # (Note: scales might not share storage due to swizzled layout complexity)
+    assert sliced_tensor._data.data_ptr() == tensor._data.data_ptr()
+
+    # Test full-width column slicing (should maintain views)
+    full_width_slice = tensor[:, 0:K]
+    assert full_width_slice._scale_e4m3.data_ptr() == tensor._scale_e4m3.data_ptr()
+    assert full_width_slice._data.data_ptr() == tensor._data.data_ptr()
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+)
+def test_nvfp4_swizzled_scales_serialization():
+    """
+    Test that tensor flatten/unflatten preserves the swizzled scales state.
+    """
+    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+
+    M, K = 32, 64
+    data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+
+    # Create tensor with swizzled scales
+    original_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
+
+    # Test serialization
+    tensor_list, ctx = original_tensor.__tensor_flatten__()
+
+    # Verify swizzled flag is preserved in context
+    assert "_is_swizzled_scales" in ctx
+    assert ctx["_is_swizzled_scales"] == True
+
+    # Test deserialization
+    inner_tensors = {}
+    for name in tensor_list:
+        inner_tensors[name] = getattr(original_tensor, name)
+
+    reconstructed_tensor = NVFP4Tensor.__tensor_unflatten__(
+        inner_tensors, ctx, None, None
+    )
+
+    # Verify the swizzled state is preserved
+    assert reconstructed_tensor._is_swizzled_scales == True
+
+    # Verify functionality is preserved
+    original_dq = original_tensor.to_dtype(torch.bfloat16)
+    reconstructed_dq = reconstructed_tensor.to_dtype(torch.bfloat16)
+
+    torch.testing.assert_close(original_dq, reconstructed_dq, atol=1e-6, rtol=1e-6)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not TORCH_VERSION_AT_LEAST_2_8, reason="NVFP4 requires PyTorch 2.8+"
+)
+def test_nvfp4_swizzled_scales_get_scales_method():
+    """
+    Test that the get_hp_scales() method correctly unswizzles scales when needed.
+    """
+    from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+
+    M, K = 32, 64
+    data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+
+    # Create tensors with both storage methods
+    regular_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=False)
+    swizzled_tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=True)
+
+    # Get scales from both tensors and verify they are equal
+    regular_scales = regular_tensor.get_hp_scales()
+    swizzled_scales = swizzled_tensor.get_hp_scales()
+    torch.testing.assert_close(regular_scales, swizzled_scales, atol=0.0, rtol=0.0)
+
+    # Verify scales have the expected shape
+    expected_shape = (M, K // 16)
+    assert regular_scales.shape == expected_shape
+    assert swizzled_scales.shape == expected_shape