import torch
from .subclass import (  # noqa
    Int8DynamicallyQuantizedLinearWeight,
    Int8WeightOnlyQuantizedLinearWeight,
    QuantizedLinearWeightBase,
)
from torch.utils._python_dispatch import return_and_correct_aliasing
from .quant_primitives import (
    quantize_activation_per_token_absmax,
    safe_int_mm,
)
import torch.nn.functional as F
from torch._inductor.utils import do_bench
aten = torch.ops.aten

AUTOQUANT_CACHE = {}

def check_cache(cls, shapes_and_dtype):
    return AUTOQUANT_CACHE.get((cls,) + shapes_and_dtype, None)

def update_cache(cls, shapes_and_dtype, res):
    AUTOQUANT_CACHE[(cls,) + shapes_and_dtype] = res

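# Illustrative sketch (not part of the original code): a cache key is the candidate
# weight class plus the logged (act_shape, w_shape, bias_shape, act_dtype) tuple, and
# the value is the benchmarked runtime in ms (None until it has been tuned). The
# shapes below are hypothetical:
#
#   key = (AQInt8DynamicallyQuantizedLinearWeight,
#          torch.Size([16, 1024]), torch.Size([4096, 1024]), torch.Size([4096]), torch.float16)
#   AUTOQUANT_CACHE[key] = 0.123  # median ms measured by do_autoquant_bench
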
class AutoQuantizableLinearWeight(torch.Tensor):
    """
    A tensor subclass that, when used as a linear weight, logs the shapes/dtypes it
    sees at runtime, benchmarks the candidate quantization types for this tensor,
    and swaps itself for the best-performing one.
    """
    @staticmethod
    def __new__(cls, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs):
        kwargs["device"] = weight.device
        kwargs["layout"] = (
            kwargs.get("layout") if kwargs.get("layout", False) else weight.layout
        )
        kwargs["dtype"] = (
            kwargs.get("dtype") if kwargs.get("dtype", False) else weight.dtype
        )
        kwargs["requires_grad"] = False
        shape = kwargs.pop("shape", weight.shape)
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, weight, qtensor_class_list, *args, mode=["relu", None], **kwargs):
        self.weight = weight
        self.qtensor_class_list = qtensor_class_list
        self.logged_data = {}
        self.mode = mode

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(data={self.weight}, shape={self.shape}, "
            f"device={self.device}, dtype={self.dtype}, qtensor_class_list={self.qtensor_class_list})"
        )

    @staticmethod
    def log_shape(act_mat, w_autoquant, bias):
        act_mat = act_mat.reshape(-1, act_mat.shape[-1])
        logged_dtype = act_mat.dtype
        logged_shapes = (act_mat.shape, w_autoquant.shape, None if bias is None else bias.shape,)
        shapes_and_dtype = logged_shapes + (logged_dtype,)
        w_autoquant.logged_data[shapes_and_dtype] = 1 + w_autoquant.logged_data.get(shapes_and_dtype, 0)
        for q_cls in w_autoquant.qtensor_class_list:
            if check_cache(q_cls, shapes_and_dtype) is None:
                update_cache(q_cls, shapes_and_dtype, None)

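    # Illustrative sketch (not part of the original code): after a few forward passes,
    # logged_data maps each observed (act_shape, w_shape, bias_shape, act_dtype)
    # tuple to the number of times it was seen, e.g. with hypothetical shapes:
    #   {(torch.Size([16, 1024]), torch.Size([4096, 1024]), None, torch.float16): 12}
    # to_quantized() below weights each benchmarked time by these counts.
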
    def tune_autoquant(self, q_cls, shapes_and_dtype, best_time):
        act_shape, w_shape, bias_shape, act_dtype = shapes_and_dtype
        if check_cache(q_cls, shapes_and_dtype) is None:
            with torch.no_grad():
                act_mat = torch.randn(act_shape, dtype=act_dtype, device=self.device)
                bias = None if bias_shape is None else torch.randn(bias_shape, dtype=act_dtype, device=self.device)
                res = q_cls._autoquant_test(act_mat, self.weight, bias, best_time, self.mode)
                update_cache(q_cls, shapes_and_dtype, res)

    def to_quantized(self, error_on_unseen, **kwargs):
        if error_on_unseen and self.logged_data == {}:
            raise RuntimeError("must run module normally to get shape, dtype info for autoquant")
        elif (self.logged_data == {}) and not error_on_unseen:
            # default back to the non-quantized weight if no shapes were seen
            self = AQFloatLinearWeight.from_float(self.weight)
            return self
        best_time = torch.inf
        best_cls = None
        do_print = False
        # check each candidate class
        for q_cls in self.qtensor_class_list:
            # for each logged shape+dtype, benchmark and accumulate a count-weighted total
            cls_res = 0
            for shapes_and_dtype, times_seen in self.logged_data.items():
                if check_cache(q_cls, shapes_and_dtype) is None:
                    do_print = True
                    self.tune_autoquant(q_cls, shapes_and_dtype, best_time)
                    torch._dynamo.reset()
                cls_res += check_cache(q_cls, shapes_and_dtype) * times_seen
            if best_time >= cls_res:
                best_time = cls_res
                best_cls = q_cls
        # only print if this is the first time seeing some cls+shape combo,
        # otherwise we will print the same thing for every layer.
        if do_print:
            print(f"for {self.logged_data}, best_cls={best_cls}")
        # TODO handle random cls args/kwargs? or should they be curried?
        self = best_cls.from_float(self.weight)
        return self

    def _apply_fn_to_data(self, fn):
        return self.__class__(
            fn(self.weight), self.qtensor_class_list, dtype=self.dtype, mode=self.mode
        )

    def __tensor_flatten__(self):
        return ["weight"], [self.qtensor_class_list, self.mode, self.dtype, self.shape]

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
        weight = tensor_data_dict["weight"]
        qtensor_class_list, mode, dtype, shape = tensor_attributes
        return cls(
            weight,
            qtensor_class_list,
            mode=mode,
            shape=shape if outer_size is None else outer_size,
            dtype=dtype,
            strides=outer_stride,
        )

    @classmethod
    def from_float(cls, weight, qtensor_class_list, **kwargs):
        return cls(weight, qtensor_class_list, **kwargs)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        if func is torch.nn.functional.linear:
            mat1, w_autoquant, bias = (
                args[0],
                args[1],
                args[2] if len(args) > 2 else None,
            )
            cls.log_shape(mat1, w_autoquant, bias)
            return func(mat1, w_autoquant.weight, bias)
        try:
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)
        except Exception:
            print(f"ERR: subclass doesn't implement {func}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        # only detach is handled here (needed e.g. for nn.Parameter wrapping);
        # other aten ops are not implemented by this subclass
        if func is aten.detach.default:
            return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach))

def do_autoquant_bench(op, *args, **kwargs):
    """
    Benchmarks op(*args, **kwargs) by capturing it in a CUDA graph and timing
    graph replays; returns the median runtime in ms.
    """
    rep = kwargs.pop("rep", 100)
    warmup = kwargs.pop("warmup", 25)
    with torch.no_grad():
        torch.cuda.synchronize()
        # warm up on a side stream before capture
        stream = torch.cuda.Stream()
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            op(*args, **kwargs)
        stream.synchronize()
        torch.cuda.current_stream().wait_stream(stream)
        torch.cuda.synchronize()

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, stream=stream):
            op(*args, **kwargs)
        res = do_bench(lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median")
    return res

def _is_interpolate_mode(mode):
    if isinstance(mode, list) and mode[0] == "interpolate" and len(mode) == 2 and isinstance(mode[1], float):
        return True
    return False
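
# Illustrative note (an assumption based on how `mode` is used in this file): the two
# supported settings look like
#   mode = ["relu", None]          # default: benchmark the op wrapped in relu's
#   mode = ["interpolate", 0.85]   # blend full-op time with bare int8 matmul time
# where the float becomes the INTERPOLATION_CONSTANT used by
# AQInt8DynamicallyQuantizedLinearWeight._autoquant_test below.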

class AQMixin:
    """
    Mixin to turn normal quantized subclasses into autoquantizable ones
    """
    @classmethod
    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
        w_qtensor = cls.from_float(weight)
        if _is_interpolate_mode(mode):
            q_c_op = torch.compile(cls._quantized_op, mode="max-autotune-no-cudagraphs")
        else:
            # default ("relu") mode: benchmark the op with a relu applied before and after it
            func = lambda a, b, c: F.relu(cls._quantized_op(F.relu(a), b, c))
            q_c_op = torch.compile(func, mode="max-autotune-no-cudagraphs")
        res = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias)
        if res < best_time * 1.1:
            # close to the best so far: re-benchmark with a longer rep for a more stable number
            res2 = do_autoquant_bench(q_c_op, act_mat, w_qtensor, bias, warmup=25, rep=900)
            res = (res2 * .9 + res * .1)
        print(f"time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms")
        return res


class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, Int8DynamicallyQuantizedLinearWeight):
    """
    AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight
    """
    @classmethod
    def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
        if not _is_interpolate_mode(mode):
            return super()._autoquant_test(act_mat, weight, bias, best_time, mode)

        # SAM performs best with an interpolation constant between .8 and 1,
        # SDXL also performs best in this range
        INTERPOLATION_CONSTANT = mode[1]
        w_qtensor = cls.from_float(weight)
        x_vals_int8, x_scales = quantize_activation_per_token_absmax(
            act_mat.reshape(-1, act_mat.shape[-1])
        )
        quantized_matmul = (
            lambda x_vals_int8, x_scales, w_vals_int8:
                safe_int_mm(x_vals_int8, w_vals_int8) * x_scales
        )
        q_c_matmul = torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
        with torch.no_grad():
            res_matmul = do_autoquant_bench(q_c_matmul, x_vals_int8, x_scales, w_qtensor.int_data)
        print(f"time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

        # if even the (much faster) bare int8 matmul can't beat best_time,
        # don't bother benchmarking the full op
        if res_matmul >= best_time:
            return res_matmul

        # calculate what time the full op needs to beat for dynamic quant to be best
        # given INTERPOLATION_CONSTANT
        to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
        res = super()._autoquant_test(act_mat, weight, bias, to_beat)
        max_int_const_win = (best_time - res_matmul) / (res - res_matmul)
        res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
        print(f"time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_int_const_win:0.2f}")
        return res_f

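# Worked example of the interpolation above (hypothetical numbers, not measured):
# with INTERPOLATION_CONSTANT = 0.85, best_time = 1.0 ms and res_matmul = 0.4 ms,
#   to_beat = 1.0 + 0.85 / 0.15 * (1.0 - 0.4) = 4.4 ms
# and if the full compiled op then benchmarks at res = 2.0 ms,
#   res_f = 0.85 * 2.0 + 0.15 * 0.4 = 1.76 ms
#   max_int_const_win = (1.0 - 0.4) / (2.0 - 0.4) = 0.375
# i.e. dynamic quant would only have won with an interpolation constant below ~0.38.
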
class AQWeightOnlyQuantizedLinearWeight(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
    """
    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight
    """

class AQWeightOnlyQuantizedLinearWeight2(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
    """
    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
    uses a different kernel
    """
    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        orig_dtype = act_mat.dtype
        orig_shape = act_mat.shape
        act_mat = act_mat.reshape(-1, act_mat.shape[-1], 1)
        y = (act_mat * w_qtensor.int_data.unsqueeze(0)).sum(dim=-2)
        y = y.reshape(*orig_shape[:-1], y.shape[-1]) * w_qtensor.q_scales
        if bias is not None:
            y += bias
        return y.to(orig_dtype)

    @classmethod
    def _autoquant_test(cls, act_mat, *args):
        # this kernel only helps for small effective batch sizes; skip it if act_mat
        # has more than 32 rows once flattened
        if act_mat.reshape(-1, act_mat.shape[-1]).shape[0] > 32:
            return torch.inf
        return super()._autoquant_test(act_mat, *args)

class AQWeightOnlyQuantizedLinearWeight3(Int8WeightOnlyQuantizedLinearWeight, AQMixin):
    """
    AutoQuantizable version of Int8WeightOnlyQuantizedLinearWeight that
    uses yet another kernel: a single matmul against the dequantized
    (int_data * q_scales) weight
    """
    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        orig_shape = act_mat.shape
        y = torch.mm(act_mat.reshape(-1, orig_shape[-1]), w_qtensor.int_data * w_qtensor.q_scales)
        y = y.reshape(*orig_shape[:-1], y.shape[-1])
        if bias is not None:
            y += bias
        return y

class AQFloatLinearWeight(torch.Tensor, AQMixin):
    """
    A class to be used in concert with AutoQuantizableLinearWeight to provide a
    default/non-quantized option. Only implements the bare minimum needed to work with the
    AutoQuantizableLinearWeight class using the same interfaces that would normally be
    used by QTensor subclasses but for a default linear op instead.
    """
    def __init__(self):
        super().__init__()

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        return torch.nn.functional.linear(act_mat, w_qtensor, bias)

    @classmethod
    def from_float(cls, weight):
        return weight

DEFAULT_CLASS_LIST = [
    AQFloatLinearWeight,
    AQInt8DynamicallyQuantizedLinearWeight,
    AQWeightOnlyQuantizedLinearWeight,
    AQWeightOnlyQuantizedLinearWeight2,
    # AQWeightOnlyQuantizedLinearWeight3,
    # the 3rd weight-only version is left out because it gets picked in situations
    # where it is actually slower when the interpolation mode is used
]
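
# Illustrative usage sketch, not part of the original module: one way to string
# together from_float, a logging forward pass, and to_quantized as defined above.
# `model` and `example_input` are hypothetical, a CUDA device is assumed (the
# benchmark helper uses CUDA graphs), and the nn.Parameter re-wrapping assumes the
# chosen subclass supports detach, as the classes in this file do.
def _example_autoquant_usage(model, example_input):
    # 1) wrap every linear weight so forward passes log (shape, dtype) info
    for mod in model.modules():
        if isinstance(mod, torch.nn.Linear):
            mod.weight = torch.nn.Parameter(
                AutoQuantizableLinearWeight.from_float(mod.weight, DEFAULT_CLASS_LIST),
                requires_grad=False,
            )
    # 2) run the model once; F.linear calls hit __torch_function__ and record shapes
    model(example_input)
    # 3) benchmark the candidates for each logged shape and swap in the fastest one
    for mod in model.modules():
        if isinstance(mod, torch.nn.Linear) and isinstance(
            mod.weight, AutoQuantizableLinearWeight
        ):
            mod.weight = torch.nn.Parameter(
                mod.weight.to_quantized(error_on_unseen=True), requires_grad=False
            )
    return model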