 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
 )
 from torchao.quantization.quant_api import (
+    int4_weight_only,
+    int8_weight_only,
+    int8_dynamic_activation_int8_weight,
+    quantize_,
     _replace_with_custom_fn_if_matches_filter,
 )
 import copy

+# Wrapper APIs: dispatch to the unified quantize_ API on torch >= 2.4 and
+# fall back to the legacy change_linear_weights_to_* functions otherwise.
+def _int8wo_api(mod, **kwargs):
+    if TORCH_VERSION_AT_LEAST_2_4:
+        quantize_(mod, int8_weight_only(**kwargs), set_inductor_config=False)
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_woqtensors(mod, **kwargs)
+
+def _int8da_int8w_api(mod, **kwargs):
+    if TORCH_VERSION_AT_LEAST_2_4:
+        quantize_(mod, int8_dynamic_activation_int8_weight(**kwargs), set_inductor_config=False)
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_dqtensors(mod, **kwargs)
+
+def _int4wo_api(mod, **kwargs):
+    if TORCH_VERSION_AT_LEAST_2_4:
+        kwargs_copy = kwargs.copy()
+        # the new API spells the group size argument "group_size"
+        # instead of the legacy "groupsize"
+        if "groupsize" in kwargs_copy:
+            kwargs_copy["group_size"] = kwargs_copy["groupsize"]
+            del kwargs_copy["groupsize"]
+        quantize_(mod, int4_weight_only(**kwargs_copy), set_inductor_config=False)
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int4_woqtensors(mod, **kwargs)
+
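The three wrappers above share a shape: each takes a module plus quantization kwargs and quantizes it in place, routing through `quantize_` on torch 2.4+ (unwrapping tensor subclasses before 2.5) and through the legacy entry points otherwise. As a minimal standalone sketch of how one of them would be exercised outside this benchmark (the toy model and shapes here are illustrative, not part of the patch):

    import torch
    # assumes a CUDA machine with torchao installed
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).to("cuda")
    _int8wo_api(model)  # quantizes the Linear weight in place to int8 weight-only
    x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
    y = model(x)        # forward now runs with the quantized weight tensor subclass
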
 class ToyLinearModel(torch.nn.Module):
-    def __init__(self, m=64, n=32, k=64):
+    """Single linear for an m * k * n problem size"""
+    def __init__(self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"):
         super().__init__()
-        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
-        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
+        self.m = m
+        self.dtype = dtype
+        self.device = device
+        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(dtype=self.dtype, device=self.device)

-    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
-        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
+    def example_inputs(self):
+        return (torch.randn(self.m, self.linear.in_features, dtype=self.dtype, device=self.device),)

     def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
+        x = self.linear(x)
         return x

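A note on the shape convention, since `Linear(k, n)` is easy to misread: the model consumes an (m, k) activation and produces (m, n), matching the m * k * n problem size in the docstring. A quick illustrative check (values chosen to match the benchmark's shape list below):

    model = ToyLinearModel(m=20, n=2048, k=2048, dtype=torch.bfloat16, device="cuda")
    (x,) = model.example_inputs()  # x.shape == (20, 2048), i.e. (m, k)
    y = model(x)                   # y.shape == (20, 2048), i.e. (m, n)
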
 def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
@@ -69,14 +105,17 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
 _ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)


-def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
+# raise the dynamo cache limit so recompiles across shapes and APIs
+# do not fall back to eager
+torch._dynamo.config.cache_size_limit = 50000
+
+@torch.no_grad
+def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     if kwargs is None:
         kwargs = {}

-    m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
+    m = ToyLinearModel(M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda").eval()
+    m_bf16 = copy.deepcopy(m)
     m_ref = copy.deepcopy(m)
-    # setting batch_size to 20 to be compatible with the kernel
-    example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
+    example_inputs = m.example_inputs()

     api(m, **kwargs)

@@ -91,27 +130,41 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
     # perf comparison
     from torchao.utils import benchmark_model
     # warmup
-    WARMUP = 5
+    WARMUP = 20
     RUNS = 100
-    m = torch.compile(m, mode='max-autotune', fullgraph=True)
-
-    benchmark_model(m, WARMUP, example_inputs)
-    elapsed_time = benchmark_model(m, RUNS, example_inputs)

     m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
     benchmark_model(m_ref, WARMUP, example_inputs)
     ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)

-    print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
-    assert elapsed_time < 1.05 * ref_elapsed_time
+    m = torch.compile(m, mode='max-autotune', fullgraph=True)
+    benchmark_model(m, WARMUP, example_inputs)
+    elapsed_time = benchmark_model(m, RUNS, example_inputs)
+
+    m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
+    benchmark_model(m_bf16, WARMUP, example_inputs)
+    bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
+
+    print(f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}")

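The function now prints raw elapsed times for the quantized model, the reference subclass implementation, and the bf16 baseline instead of asserting a fixed threshold. If a derived pass/fail signal is still wanted, ratios are easy to compute from those three numbers; a hedged sketch, reusing the variable names from the function above:

    quant_vs_bf16 = bf16_elapsed_time / elapsed_time  # > 1.0 means quantized beats bf16
    quant_vs_ref = ref_elapsed_time / elapsed_time    # > 1.0 means the new path beats the reference
    print(f"{(M, N, K)}: {quant_vs_bf16:.2f}x vs bf16, {quant_vs_ref:.2f}x vs ref")
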
 if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available():
+    all_shapes = [
+        (20, 2048, 2048),
+    ]
+
+    print("_int8da_int8w_api")
     from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
-    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_dqtensors, _ref_change_linear_weights_to_int8_dqtensors)
+    for M, N, K in all_shapes:
+        _bench_quantized_tensor_subclass_perf(_int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K)

+    print("_int8wo_api")
     from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
-    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_woqtensors, _ref_change_linear_weights_to_int8_woqtensors)
+    for M, N, K in all_shapes:
+        _bench_quantized_tensor_subclass_perf(_int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K)

+    print("_int4wo_api")
     kwargs = {"groupsize": 32}
     from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
-    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int4_woqtensors, _ref_change_linear_weights_to_int4_woqtensors, kwargs)
+    for M, N, K in all_shapes:
+        _bench_quantized_tensor_subclass_perf(_int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs)
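all_shapes currently holds a single (M, N, K) problem size, so each API is benchmarked once. Widening the sweep is just a matter of extending the list; a sketch with purely illustrative extra shapes (not part of this patch):

    all_shapes = [
        (20, 2048, 2048),
        (1, 4096, 4096),    # hypothetical: batch-1 decode-style shape
        (128, 8192, 8192),  # hypothetical: larger prefill-style shape
    ]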