"""
Prototype of QAT with exact (instead of emulated) forward pass using
integer matrix multiply.

Quant spec:
* int4 symmetric weights w/ group size 32 or 256,
* int8 asymmetric per-token dynamic activations

"""
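
# Rough sketch of the difference vs emulated (fake-quant) QAT: the emulated
# forward computes dequant(quant(x)) @ dequant(quant(w)).T in high precision,
# while the exact forward below computes the matmul on the int8/int4 values
# directly and rescales the integer accumulator afterwards.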

import copy

import fire
import torch
import torch.nn as nn
from torch._higher_order_ops.out_dtype import out_dtype

from torchao.float8.float8_utils import compute_error
from torchao.prototype.qat_exact.triton_gemm import int8_matmul_triton
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.qat.fake_quantizer import FakeQuantizer
from torchao.quantization.quant_primitives import (
    _DTYPE_TO_QVALUE_BOUNDS,
    MappingType,
)
from torchao.quantization.utils import (
    _get_per_token_block_size,
)

torch.manual_seed(0)


def quantize_x(x_fp32):
    # Dynamic quantization of activation
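    # Per-token: for a 2D input of shape [M, K], each row should get its own
    # scale and (asymmetric) zero point, i.e. x_scale / x_zero_point come back
    # with shape [M, 1].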
    x_mapping_type = MappingType.ASYMMETRIC
    per_token_block_size = _get_per_token_block_size(x_fp32)
    x_quant_min, x_quant_max = _DTYPE_TO_QVALUE_BOUNDS[torch.int8]
    x_eps = torch.finfo(torch.float32).eps
    x_scales_type = torch.float32
    x_zero_points_type = torch.int32
    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
        x_fp32,
        x_mapping_type.name,
        per_token_block_size,
        torch.int8,
        x_quant_min,
        x_quant_max,
        x_eps,
        x_scales_type,
        x_zero_points_type,
    )
    x_i8 = torch.ops.torchao.quantize_affine(
        x_fp32,
        per_token_block_size,
        x_scale,
        x_zero_point,
        torch.int8,
        x_quant_min,
        x_quant_max,
    )
    return x_i8, x_scale, x_zero_point


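# Both integer-arithmetic implementations below rely on the same per-group
# rescaling identity (a sketch, with k ranging over group g of the K dim):
#
#   y[m, n] = x_scale[m] * sum_g( w_scale[n, g] *
#                 ( sum_k x_q[m, k] * w_q[n, k] - x_zp[m] * sum_k w_q[n, k] ) )
#             (+ bias, when present)
#
# i.e. the activation zero point is folded out via per-group column sums of
# the quantized weight, so the inner products can stay in integer arithmetic.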
def _xnnpack_integer_arithmetic_dqlinear_4bit_groupwise(
    x_i8,
    x_scale,
    x_zero_point,
    weight_int4,
    weight_scale,
    weight_zero_point,
    bias_fp32,
    group_size,
):
    # For groupwise quantization, the integer matmul and the rescaling have to
    # be done per group of the K dimension
    # weight_int4 shape: [out_features, in_features]
    # weight_scale shape: [out_features, in_features // group_size]
    # weight_zero_point shape: [out_features, in_features // group_size]
    out_features, in_features = weight_int4.shape
    num_groups = in_features // group_size

    # scales in xnnpack are stored as bf16 and converted to fp32 for computation
    weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)

    assert x_i8.dim() == 2, "x_i8 must be 2D tensor"
    # Reshape for group-wise processing
    # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
    batch_size = x_i8.shape[0]
    x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)

    # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
    weight_i4_grouped = weight_int4.view(out_features, num_groups, group_size)

    # Convert to int32 for computation
    x_i32_grouped = x_i8_grouped.to(torch.int32)
    weight_i32_grouped = weight_i4_grouped.to(torch.int32)

    # Perform groupwise integer linear operation
    acc_fp32 = torch.zeros(
        batch_size, out_features, dtype=torch.float32, device=x_i8.device
    )

    for group_idx in range(num_groups):
        # Extract current group
        x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
        weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]
        weight_group_col_sum = weight_group.sum(dim=-1)  # [out_features]

        # Get scale for this group
        weight_scale_group = weight_scale[:, group_idx]  # [out_features]

        # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
        group_acc = out_dtype(
            torch.ops.aten.linear.default,
            torch.int32,
            x_group,
            weight_group,
            None,
        )

        # The output has to be scaled by x_scale * weight_scale_group.
        # We apply weight_scale_group here, per group, and apply x_scale once
        # at the end, since x_scale is shared across all groups.
        acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
            1, -1
        )

        # We must also subtract x_zero_point * weight_group_col_sum,
        # since (X - x_zero_point) @ W.T = X @ W.T - x_zero_point * (col sums of W)
        weights_col_sum_adjusted = (
            weight_group_col_sum.to(torch.float32).view(1, -1)
            * x_zero_point.view(-1, 1)
            * weight_scale_group.view(1, -1)
        )
        acc_fp32 = acc_fp32 - weights_col_sum_adjusted
    x_scale_multiplier = x_scale.view(-1, 1)
    out_fp32 = acc_fp32 * x_scale_multiplier
    if bias_fp32 is not None:
        out_fp32 = out_fp32 + bias_fp32

    return out_fp32


def vasiliy_reference_i8i8f32_gemm(
    act_q,
    act_scale,
    act_zp,
    w_q,
    w_scale,
    w_group_size,
) -> torch.Tensor:
    #
    # Given the scales/zero_points/quantized values for both gemm operands,
    # below is a manual, slow gemm with integer operands and float rescaling,
    # implemented using eager PyTorch ops. It is slow, but should closely
    # (though not exactly) match a real int8,int8->int32 gemm with rescaling;
    # the main difference is that the per-group dot product here is summed by
    # eager ops (which accumulate in int64) and rescaled per group in float32,
    # instead of being accumulated in int32 inside a fused kernel.
    #
    q_output = torch.zeros(
        act_q.shape[0],
        w_q.shape[0],
        dtype=torch.float32,
        device=act_q.device,
    )
    for m_idx in range(act_q.shape[0]):
        for n_idx in range(w_q.shape[0]):
            for g_idx in range(w_q.shape[1] // w_group_size):
                k_start = g_idx * w_group_size
                k_end = k_start + w_group_size
                act_chunk = act_q[m_idx][k_start:k_end]
                w_chunk = w_q[n_idx][k_start:k_end]

                # (act_q - act_zp) * w_q
                # result still in int32
                elem_int32 = (act_chunk - act_zp[m_idx]) * w_chunk

                # sum((act_q - act_zp) * w_q)
                # torch.sum accumulates the integer values in int64 here,
                # whereas a real kernel would accumulate the dot product in
                # int32
                group_sum = torch.sum(elem_int32)

                # scale
                act_scale_tmp = act_scale[m_idx].squeeze(-1)
                w_scale_tmp = w_scale[n_idx][g_idx].squeeze(-1).bfloat16().float()
                sum_scaled = group_sum * act_scale_tmp * w_scale_tmp

                # accumulate
                q_output[m_idx][n_idx] += sum_scaled

    return q_output


class Int8PerTokenActivationInt4PerGroupWeightLinear(torch.nn.Linear):
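    """
    Linear that runs the quantized forward pass with integer arithmetic
    instead of fake quantization: the input is dynamically quantized to int8
    per token, the weight to int4 per group of 32, and the matmul is computed
    three ways (the slow eager reference above, the xnnpack-style reference,
    and the triton kernel), printing pairwise SQNRs. Returns the first two
    results.
    """
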
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # manually create fake quantizer configs
        activation_config = FakeQuantizeConfig(
            torch.int8, "per_token", is_symmetric=False
        )
        weight_config = FakeQuantizeConfig(torch.int4, group_size=32)

        # manually create fake quantizers
        # reference: `FakeQuantizedLinear` (https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/qat/linear.py)
        self.activation_fq = FakeQuantizer(activation_config)
        self.weight_fq = FakeQuantizer(weight_config)

    def forward(self, input):
        # quantize x
        input_i8, input_scale, input_zp = quantize_x(input)

        # quantize w
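        # running the fake quantizer populates self.weight_fq.scale and
        # self.weight_fq.zero_point as a side effect; the fake-quantized
        # output itself is not needed here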
        _ = self.weight_fq(self.weight)
        w_qmin, w_qmax = _DTYPE_TO_QVALUE_BOUNDS[torch.int4]
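        # the int4 values are stored in an int8 container below; w_qmin/w_qmax
        # restrict them to the int4 range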
        w_granularity = self.weight_fq.config.granularity
        w_group_size = w_granularity.group_size
        w_block_size = (1, w_group_size)
        weight_int4 = torch.ops.torchao.quantize_affine(
            self.weight,
            w_block_size,
            self.weight_fq.scale,
            self.weight_fq.zero_point,
            torch.int8,
            w_qmin,
            w_qmax,
        )

        # original reference
        q_output = vasiliy_reference_i8i8f32_gemm(
            input_i8.to(torch.int32),
            input_scale,
            input_zp,
            weight_int4.to(torch.int32),
            self.weight_fq.scale,
            w_group_size,
        )

        # now also check Kimish's implementation
        q_output2 = _xnnpack_integer_arithmetic_dqlinear_4bit_groupwise(
            input_i8.cpu(),
            input_scale.cpu(),
            input_zp.cpu(),
            weight_int4.cpu(),
            self.weight_fq.scale.cpu(),
            self.weight_fq.zero_point.cpu(),
            self.bias,
            self.weight_fq.config.granularity.group_size,
        ).cuda()
        sqnr = compute_error(q_output, q_output2)
        print("sqnr vasiliy_reference vs kimish_reference", sqnr)

        # finally, check vs triton gemm
        q_output3 = int8_matmul_triton(
            input_i8,
            weight_int4.t(),
            input_scale,
            input_zp,
            self.weight_fq.scale.t(),
            w_group_size,
        )

        sqnr3 = compute_error(q_output, q_output3)
        print("sqnr vasiliy_reference vs triton", sqnr3)

        sqnr4 = compute_error(q_output2, q_output3)
        print("sqnr kimish_reference vs triton", sqnr4)

        return q_output, q_output2

    @classmethod
    def from_float(cls, mod: torch.nn.Linear):
        new_mod = cls(mod.in_features, mod.out_features)
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        return new_mod


def run():
    M, K, N = 32, 64, 128

    # TODO(before land): also implement bias=True
    m_hp = nn.Sequential(nn.Linear(K, N, bias=False)).cuda()
    mq_ref = copy.deepcopy(m_hp)
    mq = copy.deepcopy(m_hp)

    # create a baseline: QAT with fake quants. Our exact QAT's output should
    # be close to this
    activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
    weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
    quantize_(
        mq_ref,
        IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
    )

    # create the experiment: forward pass with an integer gemm
    mq[0] = Int8PerTokenActivationInt4PerGroupWeightLinear.from_float(mq[0])
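    # note: from_float reuses the original module's weight/bias tensors, so the
    # swapped-in module sees the same parameters as m_hp and mq_ref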

    x_hp = torch.randn(M, K, device="cuda")
    xq_ref = copy.deepcopy(x_hp)
    xq = copy.deepcopy(x_hp)

    with torch.no_grad():
        y_hp = m_hp(x_hp)
        yq_ref = mq_ref(xq_ref)
        yq, yq2 = mq(xq)

    sqnr_hp_qref = compute_error(y_hp, yq_ref)
    sqnr_hp_q = compute_error(y_hp, yq)
    sqnr_qref_q = compute_error(yq_ref, yq)
    print("sqnr_hp_qref", sqnr_hp_qref)
    print("sqnr_hp_q", sqnr_hp_q)
    print("sqnr_qref_q", sqnr_qref_q)


if __name__ == "__main__":
    fire.Fire(run)