 import torch
-
+import os
+from subprocess import check_output
 from .subclass import (  # noqa
     Int8DynamicallyQuantizedLinearWeight,
     Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
 )
 from torch.utils._python_dispatch import return_and_correct_aliasing
-from .utils import benchmark
 from .quant_primitives import (
     quantize_activation_per_token_absmax,
     safe_int_mm,
 )
 import torch.nn.functional as F
-
+from torch._inductor.utils import do_bench
 aten = torch.ops.aten

 AUTOQUANT_CACHE = {}

-def check_cache(cls, shape, dtype):
-    return AUTOQUANT_CACHE.get((cls, shape, dtype), None)
+def check_cache(cls, shapes_and_dtype):
+    return AUTOQUANT_CACHE.get((cls,) + shapes_and_dtype, None)

-def update_cache(cls, shape, dtype, res):
-    AUTOQUANT_CACHE[(cls, shape, dtype)] = res
+def update_cache(cls, shapes_and_dtype, res):
+    AUTOQUANT_CACHE[(cls,) + shapes_and_dtype] = res

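Note on the key layout above: the cache is now keyed by a single flat tuple, with the candidate class prepended to the logged (act_shape, w_shape, bias_shape, dtype) tuple. A standalone sketch of writing and reading an entry; the class names, shapes, and timing below are invented for illustration only.

# Standalone sketch of the flattened cache key; values are illustrative only.
AUTOQUANT_CACHE = {}

def update_cache(cls, shapes_and_dtype, res):
    AUTOQUANT_CACHE[(cls,) + shapes_and_dtype] = res

def check_cache(cls, shapes_and_dtype):
    return AUTOQUANT_CACHE.get((cls,) + shapes_and_dtype, None)

shapes_and_dtype = ((16, 1024), (4096, 1024), (4096,), "bfloat16")
update_cache("AQInt8DynamicallyQuantizedLinearWeight", shapes_and_dtype, 0.42)  # 0.42 ms, made up
assert check_cache("AQInt8DynamicallyQuantizedLinearWeight", shapes_and_dtype) == 0.42
assert check_cache("AQWeightOnlyQuantizedLinearWeight", shapes_and_dtype) is None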
 class AutoQuantizableLinearWeight(torch.Tensor):
     """
     when run, finds best type of quantization for this tensor and swaps itself with that
     """
     @staticmethod
-    def __new__(cls, weight, qtensor_class_list, *args, **kwargs):
+    def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs):
         kwargs["device"] = weight.device
         kwargs["layout"] = (
             kwargs.get("layout") if kwargs.get("layout", False) else weight.layout
@@ -40,11 +40,11 @@ def __new__(cls, weight, qtensor_class_list, *args, **kwargs):
         shape = kwargs.pop("shape", weight.shape)
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

-    def __init__(self, weight, qtensor_class_list, *args, **kwargs):
+    def __init__(self, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs):
         self.weight = weight
         self.qtensor_class_list = qtensor_class_list
-        self.logged_shape = None
-        self.logged_dtype = None
+        self.logged_data = {}
+        self.mode = mode

     def __repr__(self):
         return (
@@ -54,72 +54,72 @@ def __repr__(self):

     @staticmethod
     def log_shape(act_mat, w_autoquant, bias):
-        orig_shape = act_mat.shape
         act_mat = act_mat.reshape(-1, act_mat.shape[-1])
-        logged_shape = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape)
         logged_dtype = act_mat.dtype
-        w_autoquant.logged_shape = logged_shape
-        w_autoquant.logged_dtype = logged_dtype
+        logged_shapes = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape,)
+        shapes_and_dtype = logged_shapes + (logged_dtype,)
+        w_autoquant.logged_data[shapes_and_dtype] = 1 + w_autoquant.logged_data.get(shapes_and_dtype, 0)
         for q_cls in w_autoquant.qtensor_class_list:
-            if check_cache(q_cls, logged_shape, logged_dtype) is None:
-                update_cache(q_cls, logged_shape, logged_dtype, None)
-        y = torch.mm(act_mat, w_autoquant.weight.t())
-        y = y.reshape(*orig_shape[:-1], y.shape[-1])
-        if bias is not None:
-            y += bias
-        return y
+            if check_cache(q_cls, shapes_and_dtype) is None:
+                update_cache(q_cls, shapes_and_dtype, None)

-    def tune_autoquant(self, q_cls, best_time):
-        act_shape, w_shape, bias_shape = self.logged_shape
-        if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
+    def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
+        act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype
+        if check_cache(q_cls, shapes_and_dtype) is None:
             with torch.no_grad():
-                act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device)
-                bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device)
-                res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time)
-                update_cache(q_cls, self.logged_shape, self.logged_dtype, res)
+                act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device)
+                bias = None if bias_shape is None else torch.randn(bias_shape, dtype=act_dtype, device=self.device)
+                res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode)
+                update_cache(q_cls, shapes_and_dtype, res)

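For reference, a sketch of what logged_data holds after a few forward calls: it is a counter keyed by (act_shape, w_shape, bias_shape, dtype), and tune_autoquant later benchmarks each distinct key once per candidate class. The shapes below are invented.

# Invented example of the counter that log_shape() builds up.
logged_data = {}
for shapes_and_dtype in [
    ((16, 1024), (4096, 1024), (4096,), "bfloat16"),   # seen twice
    ((16, 1024), (4096, 1024), (4096,), "bfloat16"),
    ((1, 1024), (4096, 1024), (4096,), "bfloat16"),    # seen once
]:
    logged_data[shapes_and_dtype] = 1 + logged_data.get(shapes_and_dtype, 0)
assert sorted(logged_data.values()) == [1, 2]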
     def to_quantized(self, error_on_unseen, **kwargs):
-        if error_on_unseen and (self.logged_shape is None or self.logged_dtype is None):
+        if error_on_unseen and self.logged_data == {}:
             raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
-        elif (self.logged_shape is None or self.logged_dtype is None) and not error_on_unseen:
+        elif (self.logged_data == {}) and not error_on_unseen:
             # default back to non-quantized weight if not seen
             self = AQFloatLinearWeight.from_float(self.weight)
-            return self
+            return self
         best_time = torch.inf
         best_cls = None
         do_print = False
+        # check each class
         for q_cls in self.qtensor_class_list:
-            if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
-                do_print = True
-                self.tune_autoquant(q_cls, best_time)
-                torch._dynamo.reset()
-            cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf)
+            # for each logged shape+dtype, benchmark
+            cls_res = 0
+            for shapes_and_dtype, times_seen in self.logged_data.items():
+                if check_cache(q_cls, shapes_and_dtype) is None:
+                    do_print = True
+                    self.tune_autoquant(q_cls, shapes_and_dtype, best_time)
+                    torch._dynamo.reset()
+                cls_res += check_cache(q_cls, shapes_and_dtype) * times_seen
             if best_time >= cls_res:
                 best_time = cls_res
                 best_cls = q_cls
+        # only print if this is the first time seeing some cls+shape combo,
+        # otherwise we will print the same thing for every layer.
         if do_print:
-            print(f"shape={self.logged_shape}, dtype={self.logged_dtype}, best_cls={best_cls}")
-        # TODO handle random cls args/kwargs? or should they be curried
+            print(f"for {self.logged_data}, best_cls={best_cls}")
+        # TODO handle random cls args/kwargs? or should they be curried?
         self = best_cls.from_float(self.weight)
         return self

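To make the weighted selection above concrete, a toy calculation with invented timings: each candidate's per-shape time is weighted by how often that shape was seen, and the class with the lowest total is chosen.

# Toy version of the cls_res accumulation in to_quantized(); all numbers are made up.
logged_data = {"decode_shape": 10, "prefill_shape": 1}          # key -> times_seen
measured = {                                                    # per-class, per-key times (ms)
    "AQInt8DynamicallyQuantizedLinearWeight": {"decode_shape": 0.30, "prefill_shape": 0.90},
    "AQWeightOnlyQuantizedLinearWeight":      {"decode_shape": 0.25, "prefill_shape": 2.00},
}
totals = {
    cls: sum(t * logged_data[k] for k, t in per_key.items())
    for cls, per_key in measured.items()
}
best_cls = min(totals, key=totals.get)
assert best_cls == "AQInt8DynamicallyQuantizedLinearWeight"     # 3.9 ms total vs 4.5 ms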
     def _apply_fn_to_data(self, fn):
         return self.__class__(
-            fn(self.weight), self.qtensor_class_list, dtype=self.dtype
+            fn(self.weight), self.qtensor_class_list, dtype=self.dtype, mode=self.mode
         )

     def __tensor_flatten__(self):
-        return ["weight"], [self.qtensor_class_list, self.dtype, self.shape]
+        return ["weight"], [self.qtensor_class_list, self.mode, self.dtype, self.shape]

     @classmethod
     def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
         weight = tensor_data_dict["weight"]
-        qtensor_class_list, dtype, shape = tensor_attributes[0]
-        return cls(weight, qtensor_class_list, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride)
+        qtensor_class_list, mode, dtype, shape = tensor_attributes[0]
+        return cls(weight, qtensor_class_list, mode, shape=shape if outer_size is None else outer_size, dtype=dtype, strides=outer_stride)

     @classmethod
-    def from_float(cls, weight, qtensor_class_list):
-        return cls(weight, qtensor_class_list)
+    def from_float(cls, weight, qtensor_class_list, **kwargs):
+        return cls(weight, qtensor_class_list, **kwargs)

     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -131,8 +131,8 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                 args[1],
                 args[2] if len(args)>2 else None
             )
-            return cls.log_shape(mat1, w_autoquant, bias)
-
+            cls.log_shape(mat1, w_autoquant, bias)
+            return func(mat1, w_autoquant.weight, bias)
         try:
             with torch._C.DisableTorchFunctionSubclass():
                 return func(*args, **kwargs)
@@ -144,28 +144,60 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
         if func is aten.detach.default:
             return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))

+def do_autoquant_bench(op, *args, **kwargs):
+    rep = kwargs.pop("rep", 100)
+    warmup = kwargs.pop("warmup", 25)
+    with torch.no_grad():
+        torch.cuda.synchronize()
+        stream = torch.cuda.Stream()
+        stream.wait_stream(torch.cuda.current_stream())
+        with torch.cuda.stream(stream):
+            op(*args)
+        stream.synchronize()
+        torch.cuda.current_stream().wait_stream(stream)
+        torch.cuda.synchronize()
+
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph, stream=stream):
+            op(*args)
+        res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median")
+    return res
+
+def _is_interpolate_mode(mode):
+    if isinstance(mode, list) and mode[0]=="interpolate" and len(mode)==2 and isinstance(mode[1], float):
+        return True
+    return False
+
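A minimal usage sketch of do_autoquant_bench, assuming a CUDA device; the op and shapes are arbitrary. The helper warms the op up on a side stream, captures it into a CUDA graph, and lets inductor's do_bench time graph.replay(), returning the median in milliseconds.

# Assumes CUDA is available; the linear shapes here are arbitrary examples.
import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    act = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(4096, 1024, device="cuda", dtype=torch.bfloat16)
    op = torch.compile(lambda a, b: F.linear(a, b), mode="max-autotune-no-cudagraphs")
    median_ms = do_autoquant_bench(op, act, w, warmup=25, rep=100)
    print(f"linear: {median_ms:0.3f}ms")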
 class AQMixin():
     """
     Mixin to turn normal quantized subclasses into autoquantizable ones
     """
     @classmethod
-    def _autoquant_test(cls, act_mat, weight, bias, best_time, *args, **kwargs):
+    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
         w_qtensor = cls.from_float(weight)
-        q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
-        with torch.no_grad():
-            torch.cuda.synchronize()
-            res = benchmark(q_c_op, act_mat, w_qtensor, bias, best_time=best_time)
-        print(cls, res)
+        if _is_interpolate_mode(mode):
+            q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs")
+        else:
+            func = lambda a, b, c: F.relu(cls._quantized_op(F.relu(a), b, c))
+            q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs")
+        res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias)
+        if res < best_time * 1.1:
+            res2 = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=900)
+            res = (res2 * .9 + res * .1)
+        print(f"time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms")
         return res

 class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
     """
     AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
     """
     @classmethod
-    def _autoquant_test(cls, act_mat, weight, bias, best_time):
-        # SAM best is between .51 to .60, SDXL also performs best in this range
-        INTERPOLATION_CONSTANT = .55
+    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
+        if not _is_interpolate_mode(mode):
+            return super()._autoquant_test(act_mat, weight, bias, best_time, mode)
+
+        # SAM best is between .8 to 1, SDXL also performs best in this range
+        INTERPOLATION_CONSTANT = mode[1]
         w_qtensor = cls.from_float(weight)
         x_vals_int8, x_scales = quantize_activation_per_token_absmax(
             act_mat.reshape(-1, act_mat.shape[-1])
@@ -174,10 +206,10 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time):
             lambda x_vals_int8, x_scales, w_vals_int8:
                 safe_int_mm(x_vals_int8, w_vals_int8) * x_scales
         )
-        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune")
+        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
         with torch.no_grad():
-            res_matmul = benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data, best_time=best_time)
-        print(cls, "matmul", res_matmul)
+            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
+        print(f"time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

         # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
         if res_matmul >= best_time:
@@ -186,9 +218,10 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time):
         # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
         to_beat = best_time + INTERPOLATION_CONSTANT/(1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
         res = super()._autoquant_test(act_mat, weight, bias, to_beat)
-        print(cls, "full", INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul)
-        return INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
-
+        max_int_const_win = (best_time - res_matmul) / (res - res_matmul)
+        res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
+        print(f"time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
+        return res_f

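A worked example of the interpolation math above, with all times invented and mode=["interpolate", 0.85]: the int8 matmul alone sets an optimistic floor, the full op is benchmarked only if that floor beats best_time, and the reported score blends the two.

# Invented numbers illustrating to_beat, res_f and the breakeven constant.
INTERPOLATION_CONSTANT = 0.85       # mode[1]
best_time = 1.00                    # ms, best candidate seen so far
res_matmul = 0.40                   # ms, bare int8 matmul kernel
to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
res = 2.00                          # ms, full dynamic-quant op (made up)
res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
max_int_const_win = (best_time - res_matmul) / (res - res_matmul)
print(f"{to_beat:0.2f} {res_f:0.2f} {max_int_const_win:0.3f}")   # 4.40 1.76 0.375 -> loses to 1.00 ms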
 class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
     """
@@ -206,17 +239,17 @@ def _quantized_op(act_mat, w_qtensor, bias):
         orig_shape = act_mat.shape
         act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
         y = (act_mat * w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
-        y = y.reshape(*orig_shape[:-1], y.shape[-1])
+        y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales
         if bias is not None:
             y += bias
         return y.to(orig_dtype)

     @classmethod
-    def _autoquant_test(cls, act_mat, weight, bias, best_time):
+    def _autoquant_test(cls, act_mat, *args):
         # if act_mat has batchsize>2 don't use this kernel
-        if act_mat.reshape(-1, act_mat.shape[-1]).shape[0]>2:
+        if act_mat.reshape(-1, act_mat.shape[-1]).shape[0]>32:
             return torch.inf
-        return super()._autoquant_test(act_mat, weight, bias, best_time)
+        return super()._autoquant_test(act_mat, *args)

 class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
     def _quantized_op(act_mat, w_qtensor, bias):
@@ -227,7 +260,6 @@ def _quantized_op(act_mat, w_qtensor, bias):
             y += bias
         return y

-
 class AQFloatLinearWeight(torch.Tensor, AQMixin):
     """
     A class to be used in concert with AutoQuantizableLinearWeight to provide a
@@ -251,5 +283,6 @@ def from_float(cls, weight):
     AQInt8DynamicallyQuantizedLinearWeight,
     AQWeightOnlyQuantizedLinearWeight,
     AQWeightOnlyQuantizedLinearWeight2,
-    AQWeightOnlyQuantizedLinearWeight3,
+    # AQWeightOnlyQuantizedLinearWeight3,
+    # 3rd version gets picked in situations where it is slower for the interpolation mode
 ]
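For orientation, a hedged end-to-end sketch of how these pieces combine on a single nn.Linear, assuming the list above is bound to a module-level name such as DEFAULT_CLASS_LIST (the exact name and any higher-level swap helpers sit outside this diff): run the wrapped layer once to log shapes, then call to_quantized to pick and install the fastest candidate.

# Hedged sketch, not the library's public entry point; assumes CUDA and that the
# candidate list above is importable as DEFAULT_CLASS_LIST.
import torch

lin = torch.nn.Linear(1024, 4096, device="cuda", dtype=torch.bfloat16)
lin.weight = torch.nn.Parameter(
    AutoQuantizableLinearWeight.from_float(lin.weight, DEFAULT_CLASS_LIST),
    requires_grad=False,
)
x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
lin(x)  # logs (act_shape, w_shape, bias_shape, dtype) into weight.logged_data
lin.weight = torch.nn.Parameter(
    lin.weight.to_quantized(error_on_unseen=True), requires_grad=False
)
lin(x)  # now runs with the selected quantized subclass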