-"""Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs
-"""
+"""Benchmarks for affine quantized tensor, this includes int8 dynamic quant, int8 weight only quant and int4 weight only quant APIs"""
+
+import copy
+
import torch
+
+from torchao.quantization.quant_api import (
+    _replace_with_custom_fn_if_matches_filter,
+    int4_weight_only,
+    int8_dynamic_activation_int8_weight,
+    int8_weight_only,
+    quantize_,
+)
from torchao.quantization.subclass import (
-    Int8WeightOnlyQuantizedLinearWeight,
    Int4WeightOnlyQuantizedLinearWeight,
+    Int8WeightOnlyQuantizedLinearWeight,
)
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_4,
    TORCH_VERSION_AT_LEAST_2_5,
+    unwrap_tensor_subclass,
)
-from torchao.quantization.quant_api import (
-    int4_weight_only,
-    int8_weight_only,
-    int8_dynamic_activation_int8_weight,
-    quantize_,
-    _replace_with_custom_fn_if_matches_filter,
-)
-import copy
-from torchao.utils import unwrap_tensor_subclass
+

def _int8wo_api(mod, **kwargs):
    if TORCH_VERSION_AT_LEAST_2_4:
@@ -27,14 +30,20 @@ def _int8wo_api(mod, **kwargs):
    else:
        change_linear_weights_to_int8_woqtensors(mod, **kwargs)

+
def _int8da_int8w_api(mod, **kwargs):
    if TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(mod, int8_dynamic_activation_int8_weight(**kwargs), set_inductor_config=False)
+        quantize_(
+            mod,
+            int8_dynamic_activation_int8_weight(**kwargs),
+            set_inductor_config=False,
+        )
        if not TORCH_VERSION_AT_LEAST_2_5:
            unwrap_tensor_subclass(mod)
    else:
        change_linear_weights_to_int8_dqtensors(mod, **kwargs)

+
def _int4wo_api(mod, **kwargs):
    if TORCH_VERSION_AT_LEAST_2_4:
        kwargs_copy = kwargs.copy()
@@ -47,31 +56,43 @@ def _int4wo_api(mod, **kwargs):
    else:
        change_linear_weights_to_int4_woqtensors(mod, **kwargs)

+
class ToyLinearModel(torch.nn.Module):
-    """Single linear for m * k * n problem size
-    """
-    def __init__(self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"):
+    """Single linear for m * k * n problem size"""
+
+    def __init__(
+        self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"
+    ):
        super().__init__()
        self.m = m
        self.dtype = dtype
        self.device = device
-        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(dtype=self.dtype, device=self.device)
+        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(
+            dtype=self.dtype, device=self.device
+        )

    def example_inputs(self):
-        return (torch.randn(self.m, self.linear.in_features, dtype=self.dtype, device=self.device),)
+        return (
+            torch.randn(
+                self.m, self.linear.in_features, dtype=self.dtype, device=self.device
+            ),
+        )

    def forward(self, x):
        x = self.linear(x)
        return x

+
def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
    """
    The deprecated implementation for int8 dynamic quant API, used as a reference for
    numerics and performance
    """
-    from torchao.quantization.quant_api import _in_features_greater_than_16
-    from torchao.quantization.quant_api import _is_linear
-    from torchao.quantization.quant_api import _get_subclass_inserter
+    from torchao.quantization.quant_api import (
+        _get_subclass_inserter,
+        _in_features_greater_than_16,
+        _is_linear,
+    )
    from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight

    if filter_fn is None:
@@ -80,40 +101,54 @@ def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
        )

    _replace_with_custom_fn_if_matches_filter(
-        model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
+        model,
+        _get_subclass_inserter(
+            Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs
+        ),
+        filter_fn,
    )

+
def _get_ref_change_linear_weights_to_woqtensors(deprecated_tensor_subclass):
    def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
        """
        The deprecated implementation for weight only quant API, used as a reference for
        numerics and performance
        """
-        from torchao.quantization.quant_api import _is_linear
-        from torchao.quantization.quant_api import _get_subclass_inserter
+        from torchao.quantization.quant_api import _get_subclass_inserter, _is_linear

        filter_fn = kwargs.pop("filter_fn", _is_linear)

        _replace_with_custom_fn_if_matches_filter(
            model,
-            _get_subclass_inserter(deprecated_tensor_subclass, enable_parametrization=True, **kwargs),
+            _get_subclass_inserter(
+                deprecated_tensor_subclass, enable_parametrization=True, **kwargs
+            ),
            filter_fn,
        )

    return _ref_change_linear_weights_to_woqtensors

-_ref_change_linear_weights_to_int8_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
-_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
+
+_ref_change_linear_weights_to_int8_woqtensors = (
+    _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
+)
+_ref_change_linear_weights_to_int4_woqtensors = (
+    _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
+)


torch._dynamo.config.cache_size_limit = 50000

+
@torch.no_grad
def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
    if kwargs is None:
        kwargs = {}

-    m = ToyLinearModel(M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda").eval()
+    m = ToyLinearModel(
+        M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda"
+    ).eval()
    m_bf16 = copy.deepcopy(m)
    m_ref = copy.deepcopy(m)
    example_inputs = m.example_inputs()
@@ -130,26 +165,30 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):

    # perf comparison
    from torchao.utils import benchmark_model
+
    # warmup
    WARMUP = 20
    RUNS = 100

    torch._dynamo.reset()
-    m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
+    m_ref = torch.compile(m_ref, mode="max-autotune", fullgraph=True)
    benchmark_model(m_ref, WARMUP, example_inputs)
    ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)

    torch._dynamo.reset()
-    m = torch.compile(m, mode='max-autotune', fullgraph=True)
+    m = torch.compile(m, mode="max-autotune", fullgraph=True)
    benchmark_model(m, WARMUP, example_inputs)
    elapsed_time = benchmark_model(m, RUNS, example_inputs)

    torch._dynamo.reset()
-    m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
+    m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True)
    benchmark_model(m_bf16, WARMUP, example_inputs)
    bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)

-    print(f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}")
+    print(
+        f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}"
+    )
+

if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available():
    all_shapes = [
@@ -158,16 +197,25 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):

    print("_int8da_int8w_api")
    from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
+
    for M, N, K in all_shapes:
-        _bench_quantized_tensor_subclass_perf(_int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K)
+        _bench_quantized_tensor_subclass_perf(
+            _int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K
+        )

    print("_int8wo_api")
    from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
+
    for M, N, K in all_shapes:
-        _bench_quantized_tensor_subclass_perf(_int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K)
+        _bench_quantized_tensor_subclass_perf(
+            _int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K
+        )

    print("_int4wo_api")
    kwargs = {"groupsize": 32}
    from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
+
    for M, N, K in all_shapes:
-        _bench_quantized_tensor_subclass_perf(_int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs)
+        _bench_quantized_tensor_subclass_perf(
+            _int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs
+        )
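
For reference, the benchmark helpers above all funnel into the quantize_ entry point from torchao.quantization.quant_api. The following is a minimal standalone sketch of that flow, assuming a CUDA device and a torchao build that provides the imports shown in this diff; the shapes and the Sequential wrapper are illustrative only, not part of the benchmark.

import torch

from torchao.quantization.quant_api import int8_weight_only, quantize_

# Toy single-linear model in bf16 on CUDA, mirroring ToyLinearModel above.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(
    dtype=torch.bfloat16, device="cuda"
)

# quantize_ swaps each matching linear weight for an affine quantized tensor
# subclass in place; int8_weight_only() selects the int8 weight-only variant,
# and int4_weight_only() / int8_dynamic_activation_int8_weight() are the
# analogous configs exercised by the other two benchmark paths.
quantize_(model, int8_weight_only())

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
out = model(x)  # forward now runs through the quantized weight subclass

On PyTorch versions before 2.5, the benchmark additionally calls unwrap_tensor_subclass(model) before compiling, as the _int8da_int8w_api helper above shows.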