@@ -30,7 +30,7 @@
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_3,
     TORCH_VERSION_AT_LEAST_2_5,
-    benchmark_model,
+    TorchAOBaseTensor,
 )

 from torchao.quantization.granularity import (
@@ -61,6 +61,7 @@
     "autoquant_v2",
     "DEFAULT_AUTOQUANT_CLASS_LIST",
     "DEFAULT_INT4_AUTOQUANT_CLASS_LIST",
+    "DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST",
     "OTHER_AUTOQUANT_CLASS_LIST",
     "_is_linear",
 ]
@@ -288,7 +289,7 @@ def to_quantized(self, error_on_unseen, **kwargs):
             )
         elif (self.logged_data == {}) and not error_on_unseen:
             # default back to non-quantized weight if not seen
-            self = AQFloatLinearWeight.from_float(self.weight)
+            self = AQDefaultLinearWeight.from_float(self.weight)
             return self

     # only want to print shape (at start) and final result (at end)
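For context, the fallback this hunk touches is autoquant's graceful-degradation path: a linear layer that never saw an input during the shape-logging pass keeps its original floating-point weight instead of guessing a scheme. A toy, self-contained illustration of that rule (names and logic simplified; not the PR's code):

```python
import torch

def to_quantized_toy(weight, logged_data, error_on_unseen):
    # Mirrors the hunk above: unseen layers either raise or fall back to
    # the unquantized default (AQDefaultLinearWeight.from_float is an
    # identity wrapper in the real code).
    if logged_data == {} and error_on_unseen:
        raise RuntimeError("no input shapes were logged for this layer")
    if logged_data == {}:
        return weight
    raise NotImplementedError("the real method benchmarks candidate classes here")

w = torch.randn(4, 4)
assert to_quantized_toy(w, {}, error_on_unseen=False) is w
```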
@@ -360,7 +361,7 @@ def count_shapes(self, do_print=True):
         print(f"best_cls={best_cls}\n")
         # TODO handle random cls args/kwargs? or should they be curried?
         if best_cls is None:
-            best_cls = AQFloatLinearWeight
+            best_cls = AQDefaultLinearWeight

         self = best_cls.from_float(self.weight)
         return self
@@ -802,7 +803,7 @@ class AQInt4G256WeightOnlyQuantizedLinearWeight(
     group_size: int = 256


-class AQFloatLinearWeight(torch.Tensor, AQMixin):
+class AQDefaultLinearWeight(torch.Tensor, AQMixin):
     """
     A class to be used in concert with AutoQuantizableLinearWeight to provide a
     default/non-quantized option. Only implements the bare minimum needed to work with the
@@ -823,6 +824,130 @@ def from_float(cls, weight):
         return weight


+class Float32Tensor(TorchAOBaseTensor):
+    """Tensor subclass that keeps its weight in float32 and runs linear in float32."""
+
+    def __init__(self, weight):
+        self.weight = weight.to(torch.float32)
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        _DTYPE = torch.float32
+        orig_dtype = act_mat.dtype
+        return torch.nn.functional.linear(
+            act_mat.to(_DTYPE),
+            w_qtensor.weight,
+            bias.to(_DTYPE) if bias is not None else bias,
+        ).to(dtype=orig_dtype)
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.weight),
+        )
+
+    @classmethod
+    def from_float(cls, weight):
+        return cls(weight)
+
+
+@Float32Tensor.implements([torch.nn.functional.linear, aten.linear.default])
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
+
+
+@Float32Tensor.implements(aten.detach.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+    )
+
+
+@Float32Tensor.implements(aten.clone.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+    )
+
+
+@Float32Tensor.implements(aten._to_copy.default)
+def _(func, types, args, kwargs):
+    return return_and_correct_aliasing(
+        func,
+        args,
+        kwargs,
+        args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
+    )
+
+
+class BFloat16Tensor(Float32Tensor):
+    def __init__(self, weight):
+        self.weight = weight.to(torch.bfloat16)
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        _DTYPE = torch.bfloat16
+        orig_dtype = act_mat.dtype
+        return torch.nn.functional.linear(
+            act_mat.to(_DTYPE),
+            w_qtensor.weight,
+            bias.to(_DTYPE) if bias is not None else bias,
+        ).to(dtype=orig_dtype)
+
+
+class Float16Tensor(Float32Tensor):
+    def __init__(self, weight):
+        self.weight = weight.to(torch.float16)
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        _DTYPE = torch.float16
+        orig_dtype = act_mat.dtype
+        return torch.nn.functional.linear(
+            act_mat.to(_DTYPE),
+            w_qtensor.weight,
+            bias.to(_DTYPE) if bias is not None else bias,
+        ).to(dtype=orig_dtype)
+
+
+class AQFloat32LinearWeight(Float32Tensor, AQMixin):
+    """
+    AutoQuantizable version for float32 precision weight
+
+    (also converts input activation and bias to float32, and restores the
+    original precision after linear)
+    """
+
+    @classmethod
+    def from_float(cls, weight):
+        return super(AQFloat32LinearWeight, cls).from_float(weight)
+
+
+class AQBFloat16LinearWeight(BFloat16Tensor, AQMixin):
+    """
+    AutoQuantizable version for bfloat16 precision weight
+
+    (also converts input activation and bias to bfloat16, and restores the
+    original precision after linear)
+    """
+
+    @classmethod
+    def from_float(cls, weight):
+        return super(AQBFloat16LinearWeight, cls).from_float(weight)
+
+
+class AQFloat16LinearWeight(Float16Tensor, AQMixin):
+    """
+    AutoQuantizable version for float16 precision weight
+
+    (also converts input activation and bias to float16, and restores the
+    original precision after linear)
+    """
+
+    @classmethod
+    def from_float(cls, weight):
+        return super(AQFloat16LinearWeight, cls).from_float(weight)
+
+
 class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
     """
     AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
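Float32Tensor, BFloat16Tensor, and Float16Tensor all follow the same pattern: cast the activation and bias to a pinned compute dtype, run F.linear, then cast the result back to the activation's original dtype. A minimal, dependency-free sketch of that pattern (plain PyTorch; the function name is ours, not from this PR):

```python
import torch
import torch.nn.functional as F

def linear_in_dtype(act, weight, bias=None, compute_dtype=torch.float32):
    # Run the matmul in compute_dtype, then restore the caller's dtype,
    # as Float32Tensor._quantized_linear_op does above.
    orig_dtype = act.dtype
    out = F.linear(
        act.to(compute_dtype),
        weight.to(compute_dtype),
        bias.to(compute_dtype) if bias is not None else None,
    )
    return out.to(orig_dtype)

act = torch.randn(8, 16, dtype=torch.bfloat16)
weight = torch.randn(32, 16)  # fp32 weight, as stored by Float32Tensor
assert linear_in_dtype(act, weight).dtype == torch.bfloat16
```

Because the three subclasses differ only in the pinned dtype, autoquant can benchmark fp32, bf16, and fp16 against each other through identical dispatch code.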
@@ -936,7 +1061,7 @@ def get_weight_block_size(x):

 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
-    AQFloatLinearWeight,
+    AQDefaultLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight2,
     # AQInt8WeightOnlyQuantizedLinearWeight3,
@@ -945,11 +1070,17 @@ def get_weight_block_size(x):
 ]

 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
-    AQFloatLinearWeight,
+    AQDefaultLinearWeight,
     AQInt8DynamicallyQuantizedLinearWeight,
     AQInt4G64WeightOnlyQuantizedLinearWeight,
 ]

+DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST = [
+    AQFloat32LinearWeight,
+    AQBFloat16LinearWeight,
+    AQFloat16LinearWeight,
+]
+
 OTHER_AUTOQUANT_CLASS_LIST = [
     AQFloat8WeightOnlyQuantizedLinearWeight,
     AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
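The new DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST restricts the search space to floating-point precisions only. A hedged usage sketch, assuming autoquant_v2 accepts a qtensor_class_list keyword like torchao's autoquant does, and that the module lives at the prototype path below (adjust to where this PR actually lands):

```python
import torch
from torchao.prototype.quantization.autoquant_v2 import (
    DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
    autoquant_v2,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to(device)

# Restrict autoquant to the float-precision candidates added in this PR:
# it benchmarks fp32 vs bf16 vs fp16 per layer and keeps the fastest.
model = autoquant_v2(model, qtensor_class_list=DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST)
model(torch.randn(1, 64, device=device))  # shape-logging / benchmarking pass
```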