@@ -9,20 +9,16 @@
 import unittest

 import torch
-from parameterized import parameterized

-from torchao.float8.float8_utils import EPS as float8_eps
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
     choose_qparams_affine,
-    choose_qparams_affine_float8,
+    choose_qparams_affine_tinygemm,
     dequantize_affine,
-    dequantize_affine_float8,
     fake_quantize_affine,
     fake_quantize_affine_cachemask,
     quantize_affine,
-    quantize_affine_float8,
 )

 # TODO: remove test for utils?
@@ -650,35 +646,6 @@ def test_raises(self):
         with self.assertRaisesRegex(RuntimeError, "is invalid for input of size 1"):
             _ = quantize_affine(input, block_size, scale, zero_point, dtype)

-    def test_not_preserve_zero_not_supported(self):
-        """Making sure preserve_zero == False is not supported for symmetric quant"""
-        input = torch.randn(10, 256)
-        n_bit = 4
-        mapping_type = MappingType.SYMMETRIC
-        dtype = torch.int8
-        block_size = (1, 128)
-        quant_min = 0
-        quant_max = 2**n_bit - 1
-        eps = 1e-6
-        scale_dtype = torch.bfloat16
-        zero_point_dtype = torch.bfloat16
-        with self.assertRaisesRegex(
-            ValueError,
-            "preserve_zero == False is not supported for symmetric quantization",
-        ):
-            choose_qparams_affine(
-                input,
-                mapping_type,
-                block_size,
-                dtype,
-                quant_min,
-                quant_max,
-                eps,
-                scale_dtype=scale_dtype,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=False,
-            )
-
     def test_get_groupwise_affine_qparams(self):
         input = torch.randn(10, 256)
         n_bit = 4
@@ -702,22 +669,33 @@ def test_get_groupwise_affine_qparams(self):
                 dtype=torch.bfloat16,
                 zero_point_domain=zero_point_domain,
             )
-            scale, zero_point = choose_qparams_affine(
-                input,
-                mapping_type,
-                block_size,
-                dtype,
-                quant_min,
-                quant_max,
-                eps,
-                scale_dtype=scale_dtype,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=zero_point_domain == ZeroPointDomain.INT,
-                zero_point_domain=zero_point_domain,
-            )
+            if zero_point_domain == ZeroPointDomain.FLOAT:
+                scale, zero_point = choose_qparams_affine_tinygemm(
+                    input,
+                    mapping_type,
+                    block_size,
+                    dtype,
+                    quant_min,
+                    quant_max,
+                    eps,
+                    scale_dtype=scale_dtype,
+                    zero_point_dtype=zero_point_dtype,
+                )
+            else:
+                scale, zero_point = choose_qparams_affine(
+                    input,
+                    mapping_type,
+                    block_size,
+                    dtype,
+                    quant_min,
+                    quant_max,
+                    eps,
+                    scale_dtype=scale_dtype,
+                    zero_point_dtype=zero_point_dtype,
+                )

-        self.assertTrue(torch.equal(scale, scale_ref))
-        self.assertTrue(torch.equal(zero_point, zero_point_ref))
+            self.assertTrue(torch.equal(scale, scale_ref))
+            self.assertTrue(torch.equal(zero_point, zero_point_ref))

     def test_groupwise_affine_quantize_tensor_from_qparams(self):
         input = torch.randn(10, 256)
@@ -847,120 +825,6 @@ def test_fake_quantize_affine_cachemask(self):
         torch.testing.assert_close(dequantized, fake_quantized)
         torch.testing.assert_close(expected_mask, mask)

-    def test_none_zero_point_domain(self):
-        """A None value for a ZeroPointDomain should not work, but ZeroPointDomain.NONE should"""
-        input = torch.randn(10, 256)
-        mapping_type = MappingType.SYMMETRIC
-        dtype = torch.int8
-        block_size = (1, 128)
-        quant_min = None
-        quant_max = None
-        eps = 1e-6
-        scale_dtype = torch.float32
-        zero_point_dtype = torch.int64
-        try:
-            _, zero_point = choose_qparams_affine(
-                input,
-                mapping_type,
-                block_size,
-                dtype,
-                quant_min,
-                quant_max,
-                eps,
-                scale_dtype=scale_dtype,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=True,
-                zero_point_domain=None,
-            )
-        except ValueError:
-            # This exception was expected
-            # Now test for ZeroPointDomain.NONE
-            _, zero_point = choose_qparams_affine(
-                input,
-                mapping_type,
-                block_size,
-                dtype,
-                quant_min,
-                quant_max,
-                eps,
-                scale_dtype=scale_dtype,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=True,
-                zero_point_domain=ZeroPointDomain.NONE,
-            )
-            self.assertTrue(zero_point is None)
-        else:
-            # An exception should have been thrown for zero_point_domain None
-            self.assertTrue(
-                False,
-                msg="A runtime exception should have been thrown for zero_point_domain None",
-            )
-
-    @parameterized.expand(
-        [
-            (
-                torch.float32,
-                torch.float8_e4m3fn,
-            ),
-            (
-                torch.float32,
-                torch.float8_e5m2,
-            ),
-            (
-                torch.bfloat16,
-                torch.float8_e4m3fn,
-            ),
-            (
-                torch.bfloat16,
-                torch.float8_e5m2,
-            ),
-        ]
-    )
-    def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
-        input = torch.randn(10, 10)
-
-        # float8 quantization primitives
-        scale = choose_qparams_affine_float8(input, float8_dtype=float8_dtype)
-        quantized = quantize_affine_float8(input, scale, float8_dtype=float8_dtype)
-        dequantized = dequantize_affine_float8(quantized, scale, output_dtype=hp_dtype)
-
-        # reference implementation using generic primitives
-        expected_scale, _ = choose_qparams_affine(
-            input,
-            MappingType.SYMMETRIC,
-            input.shape,
-            float8_dtype,
-            eps=float8_eps,  # use same EPS as float8 training
-            scale_dtype=torch.float32,
-            quant_min=torch.finfo(float8_dtype).min,
-            quant_max=torch.finfo(float8_dtype).max,
-        )
-        expected_quantized = quantize_affine(
-            input,
-            input.shape,
-            scale,
-            output_dtype=float8_dtype,
-            quant_min=torch.finfo(float8_dtype).min,
-            quant_max=torch.finfo(float8_dtype).max,
-            zero_point=None,
-            zero_point_domain=ZeroPointDomain.NONE,
-        )
-        expected_dequantized = dequantize_affine(
-            expected_quantized,
-            input.shape,
-            scale,
-            input_dtype=float8_dtype,
-            output_dtype=hp_dtype,
-            quant_min=torch.finfo(float8_dtype).min,
-            quant_max=torch.finfo(float8_dtype).max,
-            zero_point=None,
-            zero_point_domain=ZeroPointDomain.NONE,
-        )
-
-        self.assertTrue(torch.equal(expected_scale, scale))
-        torch.testing.assert_close(expected_quantized, quantized)
-        torch.testing.assert_close(expected_dequantized, dequantized)
-

 if __name__ == "__main__":
     unittest.main()
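
Note for reviewers: below is a minimal sketch of the call-site migration this diff implies, built only from the signatures visible in the hunks above. The helper name choose_qparams_for_domain is hypothetical (not part of torchao); it simply restates how the FLOAT zero-point domain now routes to choose_qparams_affine_tinygemm, while other domains keep using choose_qparams_affine without the preserve_zero / zero_point_domain arguments removed at this call site.

from torchao.quantization.quant_primitives import (
    ZeroPointDomain,
    choose_qparams_affine,
    choose_qparams_affine_tinygemm,
)


def choose_qparams_for_domain(
    input,
    mapping_type,
    block_size,
    dtype,
    quant_min,
    quant_max,
    eps,
    scale_dtype,
    zero_point_dtype,
    zero_point_domain,
):
    # Hypothetical helper mirroring the dispatch added in
    # test_get_groupwise_affine_qparams above.
    if zero_point_domain == ZeroPointDomain.FLOAT:
        # Tinygemm-style qparams (float zero point) use the dedicated function.
        return choose_qparams_affine_tinygemm(
            input,
            mapping_type,
            block_size,
            dtype,
            quant_min,
            quant_max,
            eps,
            scale_dtype=scale_dtype,
            zero_point_dtype=zero_point_dtype,
        )
    # Other domains call choose_qparams_affine without the
    # preserve_zero / zero_point_domain arguments this diff drops here.
    return choose_qparams_affine(
        input,
        mapping_type,
        block_size,
        dtype,
        quant_min,
        quant_max,
        eps,
        scale_dtype=scale_dtype,
        zero_point_dtype=zero_point_dtype,
    )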