@@ -69,13 +69,13 @@ def log_shape(act_mat, w_autoquant, bias):
             y += bias
         return y
 
-    def tune_autoquant(self, q_cls):
+    def tune_autoquant(self, q_cls, best_time):
         act_shape, w_shape, bias_shape = self.logged_shape
         if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
             with torch.no_grad():
                 act_mat = torch.randn(act_shape, dtype=self.logged_dtype, device=self.device)
                 bias = None if bias_shape is None else torch.randn(bias_shape, dtype=self.logged_dtype, device=self.device)
-                res = q_cls._autoquant_test(act_mat, self.weight, bias)
+                res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time)
                 update_cache(q_cls, self.logged_shape, self.logged_dtype, res)
 
     def to_quantized(self, error_on_unseen, **kwargs):
@@ -91,7 +91,7 @@ def to_quantized(self, error_on_unseen, **kwargs):
         for q_cls in self.qtensor_class_list:
             if check_cache(q_cls, self.logged_shape, self.logged_dtype) is None:
                 do_print = True
-                self.tune_autoquant(q_cls)
+                self.tune_autoquant(q_cls, best_time)
                 torch._dynamo.reset()
             cls_res = AUTOQUANT_CACHE.get((q_cls, self.logged_shape, self.logged_dtype), torch.inf)
             if best_time >= cls_res:
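The loop above now threads `best_time` (the fastest candidate seen so far) into tuning, so each `_autoquant_test` knows the time it has to beat. As a rough illustration of that selection pattern, reduced to a self-contained sketch (`pick_best` and `benchmark_fn` are hypothetical names, not part of torchao):

import math

def pick_best(candidates, benchmark_fn):
    best_time, best_cls = math.inf, None
    for q_cls in candidates:
        # benchmark_fn can bail out early (returning an over-estimate such as
        # math.inf) once it knows q_cls cannot beat best_time
        cls_res = benchmark_fn(q_cls, best_time)
        if cls_res < best_time:
            best_time, best_cls = cls_res, q_cls
    return best_cls, best_time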
@@ -149,14 +149,12 @@ class AQMixin():
     Mixin to turn normal quantized subclasses into autoquantizable ones
     """
     @classmethod
-    def _autoquant_test(cls, act_mat, weight, bias):
+    def _autoquant_test(cls, act_mat, weight, bias, best_time, *args, **kwargs):
         w_qtensor = cls.from_float(weight)
-        func = lambda a, b, c: F.relu(cls._quantized_op(F.relu(a), b, c))
-        q_c_op = torch.compile(func, mode="max-autotune")
-        # q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
+        q_c_op = torch.compile(cls._quantized_op, mode="max-autotune")
         with torch.no_grad():
             torch.cuda.synchronize()
-            res = benchmark(q_c_op, act_mat, w_qtensor, bias)
+            res = benchmark(q_c_op, act_mat, w_qtensor, bias, best_time=best_time)
         print(cls, res)
         return res
 
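Here `benchmark` additionally receives `best_time`. The helper itself is not shown in this diff, but the intent is that a candidate which is clearly slower than the current best does not need a full measurement. A minimal sketch of such a threshold-aware timer, assuming a CUDA device and using a hypothetical `benchmark_sketch` name (this is not the torchao `benchmark` implementation), could look like:

import torch

def benchmark_sketch(fn, *args, best_time=float("inf"), reps=20):
    # one probe run, timed with CUDA events
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    fn(*args)
    end.record()
    torch.cuda.synchronize()
    probe_ms = start.elapsed_time(end)
    # assumption: if a single run is already several times slower than the
    # best candidate so far, skip the longer measurement and return the probe
    if probe_ms > 5 * best_time:
        return probe_ms
    # otherwise average over a few repetitions
    start.record()
    for _ in range(reps):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / reps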
@@ -165,8 +163,9 @@ class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
     AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
     """
     @classmethod
-    def _autoquant_test(cls, act_mat, weight, bias):
-        res = super()._autoquant_test(act_mat, weight, bias)
+    def _autoquant_test(cls, act_mat, weight, bias, best_time):
+        # SAM best is between .51 and .60; SDXL also performs best in this range
+        INTERPOLATION_CONSTANT = .55
         w_qtensor = cls.from_float(weight)
         x_vals_int8, x_scales = quantize_activation_per_token_absmax(
             act_mat.reshape(-1, act_mat.shape[-1])
@@ -177,10 +176,18 @@ def _autoquant_test(cls, act_mat, weight, bias):
         )
         q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune")
         with torch.no_grad():
-            res2 = benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
-            print(cls, "matmul", res2)
-        # for SAM best is between .458-.499, SDXL .45=3.094 .47=2.880 .48=3.036 .5=2.930
-        return res
+            res_matmul = benchmark(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data, best_time=best_time)
+        print(cls, "matmul", res_matmul)
+
+        # if the (much faster) matmul kernel is already beaten, don't bother benchmarking the full op
+        if res_matmul >= best_time:
+            return res_matmul
+
+        # calculate what time the full op needs to beat for dynamic quant to be best, given INTERPOLATION_CONSTANT
+        to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
+        res = super()._autoquant_test(act_mat, weight, bias, to_beat)
+        print(cls, "full", INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul)
+        return INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
 
 
 class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
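The returned score interpolates between the full dynamic-quant op time and the int8-matmul-only time, and `to_beat` relaxes the threshold handed to the full-op benchmark by the same constant. A quick numeric check of the formulas above, with made-up timings in milliseconds:

INTERPOLATION_CONSTANT = .55
best_time = 1.0     # hypothetical: fastest candidate measured so far
res_matmul = 0.6    # hypothetical: int8 matmul kernel alone

to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
print(to_beat)      # ~1.49, the threshold passed to the full-op benchmark

res = 1.2           # hypothetical: measured time of the full dynamic-quant op
score = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
print(score)        # 0.93, below best_time, so this class would become the new best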
@@ -205,11 +212,11 @@ def _quantized_op(act_mat, w_qtensor, bias):
         return y.to(orig_dtype)
 
     @classmethod
-    def _autoquant_test(cls, act_mat, weight, bias):
+    def _autoquant_test(cls, act_mat, weight, bias, best_time):
         # if act_mat has batchsize>2 don't use this kernel
         if act_mat.reshape(-1, act_mat.shape[-1]).shape[0] > 2:
             return torch.inf
-        return super()._autoquant_test(act_mat, weight, bias)
+        return super()._autoquant_test(act_mat, weight, bias, best_time)
 
 class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
     def _quantized_op(act_mat, w_qtensor, bias):
@@ -246,42 +253,3 @@ def from_float(cls, weight):
     AQWeightOnlyQuantizedLinearWeight2,
     AQWeightOnlyQuantizedLinearWeight3,
 ]
-
-if False:
-    # def _get_to_kwargs(self, *args, **kwargs):
-    #     device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
-    #     device = self.device if device is None else device
-    #     dtype = self.dtype if dtype is None else dtype
-    #     memory_format = (
-    #         memory_format if memory_format is not None else torch.preserve_format
-    #     )
-    #     kwargs = {
-    #         "device": device,
-    #         "dtype": dtype,
-    #         "memory_format": memory_format,
-    #     }
-    #     return kwargs
-
-    # def to(self, *args, **kwargs):
-    #     kwargs = self._get_to_kwargs(*args, **kwargs)
-    #     return self.__class__(
-    #         self.int_data.to(kwargs["device"]),
-    #         self.q_scales.to(kwargs["device"]),
-    #         self.transposed,
-    #         self.shape,
-    #         **kwargs,
-    #     )
-
-    # def _apply_fn_to_data(self, fn):
-    #     return self.__class__(
-    #         fn(self.int_data), fn(self.q_scales), self.transposed, self.shape, dtype=self.dtype
-    #     )
-
-    # def _change_shape(self, shape):
-    #     return self.__class__(
-    #         self.int_data, self.q_scales, self.transposed, shape, dtype=self.dtype
-    #     )
-
-    # def half(self):
-    #     return self.to(torch.float16)
-    pass