)
from torch.utils._python_dispatch import return_and_correct_aliasing
from .utils import benchmark
+from .quant_primitives import (
+    quantize_activation_per_token_absmax,
+    safe_int_mm,
+)
+import torch.nn.functional as F

aten = torch.ops.aten

@@ -70,23 +75,30 @@ def tune_autoquant(self, q_cls):
        with torch.no_grad():
            act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device)
            bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device)
-            print(q_cls, self.logged_shape, self.logged_dtype)
-            print("mem", torch.cuda.max_memory_allocated()/1e6, torch.cuda.memory_usage())
            res = q_cls._autoquant_test(act_mat, self.weight, bias)
            update_cache(q_cls, self.logged_shape, self.logged_dtype, res)

-    def to_quantized(self):
-        if self.logged_shape is None or self.logged_dtype is None:
+    def to_quantized(self, error_on_unseen, **kwargs):
+        if error_on_unseen and (self.logged_shape is None or self.logged_dtype is None):
            raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
+        elif (self.logged_shape is None or self.logged_dtype is None) and not error_on_unseen:
+            # default back to non-quantized weight if not seen
+            self = AQFloatLinearWeight.from_float(self.weight)
+            return self
        best_time = torch.inf
        best_cls = None
+        do_print = False
        for q_cls in self.qtensor_class_list:
            if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
+                do_print = True
                self.tune_autoquant(q_cls)
+                torch._dynamo.reset()
            cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf)
            if best_time >= cls_res:
                best_time = cls_res
                best_cls = q_cls
+        if do_print:
+            print(f"shape={self.logged_shape}, dtype={self.logged_dtype}, best_cls={best_cls}")
        # TODO handle random cls args/kwargs? or should they be curried
        self = best_cls.from_float(self.weight)
        return self
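The updated `to_quantized` benchmarks each candidate subclass once per logged (shape, dtype), caches the timing, and keeps the fastest; when nothing was logged and `error_on_unseen` is False it falls back to the float weight. A standalone sketch of that select-by-cached-benchmark pattern (illustrative names only, not the repo's API):

```python
import torch

AUTOQUANT_CACHE = {}  # (cls, shape, dtype) -> measured milliseconds

def pick_best(candidates, shape, dtype, run_benchmark, fallback):
    # Mirror of the loop in to_quantized: benchmark uncached candidates,
    # then take the argmin over cached timings, defaulting to `fallback`.
    if shape is None or dtype is None:
        return fallback
    best_time, best_cls = torch.inf, fallback
    for cls in candidates:
        key = (cls, shape, dtype)
        if key not in AUTOQUANT_CACHE:
            AUTOQUANT_CACHE[key] = run_benchmark(cls)
        if AUTOQUANT_CACHE[key] < best_time:
            best_time, best_cls = AUTOQUANT_CACHE[key], cls
    return best_cls
```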
@@ -132,26 +144,93 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
        if func is aten.detach.default:
            return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))

-
-class DefaultLinear(torch.Tensor):
+class AQMixin():
    """
-    An class to be used in concert with AutoQuantizableLinearWeight to provide a
-    default/non-quantized option. Only implements the bare minimum needed to work with the
-    AutoQuantizableLinearWeight class using the same interfaces that would normally be
-    used by QTensor subclasses but for a default linear op instead.
+    Mixin to turn normal quantized subclasses into autoquantizable ones
    """
-    def __init__(self):
-        super().__init__()
-
    @classmethod
    def _autoquant_test(cls, act_mat, weight, bias):
        w_qtensor = cls.from_float(weight)
-        q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
+        func = lambda act_mat, w_qtensor, bias: F.relu(cls._quantized_op(F.relu(act_mat), w_qtensor, bias))
+        q_c_op = torch.compile(func, mode="max-autotune")
+        # q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
        with torch.no_grad():
-            res = benchmark(q_c_op, act_mat, w_qtensor, bias)
+            torch.cuda.synchronize()
+            res = benchmark(q_c_op, act_mat, w_qtensor, bias)
        print(cls, res)
        return res
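Wrapping `_quantized_op` between two `F.relu` calls means the compiled benchmark also sees the pointwise ops a linear typically sits between, so candidates that fuse well with their neighbors score better. The `benchmark` helper comes from `.utils` and is not shown in this diff; a minimal CUDA-event timer in the same spirit might look like this (a sketch, assuming a CUDA device; the real helper may differ):

```python
import torch

def time_cuda_ms(fn, *args, warmup=3, iters=10):
    # Warm-up runs so torch.compile / autotuning cost is excluded from timing.
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call
```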

+class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
+    """
+    AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
+    """
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias):
+        res = super()._autoquant_test(act_mat, weight, bias)
+        w_qtensor = cls.from_float(weight)
+        x_vals_int8, x_scales = quantize_activation_per_token_absmax(
+            act_mat.reshape(-1, act_mat.shape[-1])
+        )
+        quantized_matmul = (
+            lambda x_vals_int8, x_scales, w_vals_int8:
+                safe_int_mm(x_vals_int8, w_vals_int8) * x_scales
+        )
+        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune")
+        with torch.no_grad():
+            res2 = benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
+        print(cls, "matmul", res2)
+        # for SAM best is between .458-.499, SDXL .45=3.094 .47=2.880 .48=3.036 .5=2.930
+        return res
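The extra `res2` measurement times only the int8 matmul (`safe_int_mm` plus the scale multiply), separating GEMM cost from the activation-quantization overhead that is included in `res`. For reference, per-token absmax quantization of the activations reduces to the following (an illustrative re-implementation; the actual `quantize_activation_per_token_absmax` may differ in dtype and edge-case handling):

```python
import torch

def per_token_absmax_quantize(x, eps=1e-5):
    # One scale per token (row): absmax / 127, then round-and-clamp into int8.
    scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / 127.0
    x_int8 = torch.clamp(torch.round(x / scales), -128, 127).to(torch.int8)
    return x_int8, scales
```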
+
+
+class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+    """
+    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
+    """
+
+class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+    """
+    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
+    uses a different kernel
+    """
+    @staticmethod
+    def _quantized_op(act_mat, w_qtensor, bias):
+        orig_dtype = act_mat.dtype
+        orig_shape = act_mat.shape
+        act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
+        y = (act_mat * w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
+        y = y.reshape(*orig_shape[:-1], y.shape[-1])
+        if bias is not None:
+            y += bias
+        return y.to(orig_dtype)
+
+    @classmethod
+    def _autoquant_test(cls, act_mat, weight, bias):
+        # if act_mat has batchsize>2 don't use this kernel
+        if act_mat.reshape(-1, act_mat.shape[-1]).shape[0] > 2:
+            return torch.inf
+        return super()._autoquant_test(act_mat, weight, bias)
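The "different kernel" in `AQWeightOnlyQuantizedLinearWeight2` replaces the GEMM with a broadcasted elementwise multiply followed by a reduction, which can win only for very small batch sizes (hence the `> 2` cutoff above). The two formulations compute the same product, as this small check illustrates (scales and int8 storage omitted for brevity):

```python
import torch

a = torch.randn(2, 8)      # a couple of activation rows
w = torch.randn(8, 16)     # stand-in for the stored weight matrix
y_mm = a @ w
y_bcast = (a.reshape(-1, 8, 1) * w.unsqueeze(0)).sum(dim=-2)
assert torch.allclose(y_mm, y_bcast, atol=1e-5)
```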
+
+class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
+    def _quantized_op(act_mat, w_qtensor, bias):
+        orig_shape = act_mat.shape
+        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data * w_qtensor.q_scales)
+        y = y.reshape(*orig_shape[:-1], y.shape[-1])
+        if bias is not None:
+            y += bias
+        return y
+
+
+class AQFloatLinearWeight(torch.Tensor, AQMixin):
+    """
+    A class to be used in concert with AutoQuantizableLinearWeight to provide a
+    default/non-quantized option. Only implements the bare minimum needed to work with the
+    AutoQuantizableLinearWeight class using the same interfaces that would normally be
+    used by QTensor subclasses but for a default linear op instead.
+    """
+    def __init__(self):
+        super().__init__()
+
    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        return torch.nn.functional.linear(act_mat, w_qtensor, bias)
@@ -161,10 +240,11 @@ def from_float(cls, weight):
        return weight

DEFAULT_CLASS_LIST = [
-    Int8DynamicallyQuantizedLinearWeight,
-    DefaultLinear,
-    Int8WeightOnlyQuantizedLinearWeight,
-
+    AQFloatLinearWeight,
+    AQInt8DynamicallyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight,
+    AQWeightOnlyQuantizedLinearWeight2,
+    AQWeightOnlyQuantizedLinearWeight3,
]

if False: