|
8 | 8 | import torch
|
9 | 9 | import torch.nn.functional as F
|
10 | 10 |
|
| 11 | +from torchao.core.config import AOBaseConfig |
11 | 12 | from torchao.dtypes import to_affine_quantized_intx
|
12 | 13 | from torchao.dtypes.uintx.uintx_layout import UintxLayout
|
| 14 | +from torchao.quantization import Int8DynamicActivationIntxWeightConfig |
13 | 15 | from torchao.quantization.granularity import Granularity
|
14 | 16 | from torchao.quantization.observer import (
|
15 | 17 | AffineQuantizedObserverBase,
|
|
18 | 20 | MappingType,
|
19 | 21 | ZeroPointDomain,
|
20 | 22 | )
|
| 23 | +from torchao.quantization.transform_module import ( |
| 24 | + _QUANTIZE_CONFIG_HANDLER, |
| 25 | +) |
| 26 | +from torchao.utils import DummyModule |
21 | 27 |
|
22 | 28 |
|
23 | 29 | class AWQObserver(AffineQuantizedObserverBase):
|
@@ -145,6 +151,134 @@ def calculate_qparams(self):
|
145 | 151 | return best_scales.detach()
|
146 | 152 |
|
147 | 153 |
|
| 154 | +class AWQObserver2(AffineQuantizedObserverBase): |
| 155 | + def __init__( |
| 156 | + self, |
| 157 | + weight: torch.Tensor, |
| 158 | + bias: torch.Tensor, |
| 159 | + config: AOBaseConfig, |
| 160 | + n_validation_examples: int, |
| 161 | + validation_sequence_len: int, |
| 162 | + scale_search_space_size: int = 20, |
| 163 | + base_config: Optional[AOBaseConfig] = None, |
| 164 | + ): |
| 165 | + """ |
| 166 | + A custom observer for Activation aware Weight Quantization (AWQ) |
| 167 | +
|
| 168 | + Args: |
| 169 | + weight: The weight tensor to be observed. |
| 170 | + bias: The bias tensor to be observed. |
| 171 | + quantization_granularity: Granularity which specifies how many weights share the same scale/zero point |
| 172 | + input_dtype: The data type of the input tensor. |
| 173 | + mapping_type: Always set to asymmetric |
| 174 | + target_dtype: The target data type of the quantized tensor |
| 175 | + n_validation_examples: Number of examples used to calibrate observer |
| 176 | + validation_sequence_len: Number of tokens in each example |
| 177 | + scale_search_space_size: The number of scales to search for. |
| 178 | + quant_min: The minimum quantized value |
| 179 | + quant_max: The maximum quantized value |
| 180 | + eps: The minimum scale. |
| 181 | + scale_dtype: The data type of the scale tensor. |
| 182 | + zero_point_dtype: The data type of the zero point tensor. |
| 183 | + preserve_zero: A flag to indicate whether we need zero to be exactly |
| 184 | + representable or not. |
| 185 | + zero_point_domain: The domain of the zero point. |
| 186 | + """ |
| 187 | + self.base_config = base_config |
| 188 | + quant_min = getattr(config, "quant_min", None) |
| 189 | + quant_max = getattr(config, "quant_max", None) |
| 190 | + |
| 191 | + assert isinstance(base_config, Int8DynamicActivationIntxWeightConfig) |
| 192 | + # TODO: |
| 193 | + quantization_granularity = base_config.weight_granularity |
| 194 | + target_dtype = base_config.weight_dtype |
| 195 | + mapping_type = base_config.weight_mapping_type |
| 196 | + |
| 197 | + # TODO: |
| 198 | + super().__init__( |
| 199 | + mapping_type, |
| 200 | + target_dtype, |
| 201 | + quantization_granularity, |
| 202 | + quant_min=quant_min, |
| 203 | + quant_max=quant_max, |
| 204 | + ) |
| 205 | + self.quantization_granularity = quantization_granularity |
| 206 | + self.weight = weight |
| 207 | + self.bias = bias |
| 208 | + self.n_validation_examples = n_validation_examples |
| 209 | + self.validation_sequence_len = validation_sequence_len |
| 210 | + self.calibration_token_count = 0 |
| 211 | + self.inputs = [] |
| 212 | + self.outputs = [] |
| 213 | + self.scale_options = scale_search_space_size |
| 214 | + self.device = self.weight.device |
| 215 | + self.average = torch.zeros((1, weight.shape[1]), device=self.device) |
| 216 | + if self.bias is not None: |
| 217 | + self.bias.to(self.device) |
| 218 | + |
| 219 | + @torch.no_grad() |
| 220 | + def forward(self, input: torch.Tensor, output: torch.Tensor): |
| 221 | + # import pdb |
| 222 | + # pdb.set_trace() |
| 223 | + # print(input.shape, input.abs().sum(1).shape, self.average.shape) |
| 224 | + if len(self.inputs) < self.n_validation_examples: |
| 225 | + self.inputs.append(input.to("cpu")) |
| 226 | + self.outputs.append(output.to("cpu")) |
| 227 | + self.calibration_token_count += input.shape[-2] |
| 228 | + self.average += input.abs().sum(-2) |
| 229 | + |
| 230 | + def calculate_qparams(self): |
| 231 | + # import pdb |
| 232 | + # pdb.set_trace() |
| 233 | + assert self.outputs != None, ( |
| 234 | + "calibrate observer first by running model on exemplar data" |
| 235 | + ) |
| 236 | + self.average /= self.calibration_token_count |
| 237 | + for i in range(self.n_validation_examples): |
| 238 | + self.inputs[i] = self.inputs[i].to(self.device) |
| 239 | + self.outputs[i] = self.outputs[i].to(self.device) |
| 240 | + |
| 241 | + best_loss = float("inf") |
| 242 | + best_scales = None |
| 243 | + for i in range(self.scale_options): |
| 244 | + ratio = i * 1 / self.scale_options |
| 245 | + scales = self.average.pow(ratio).to(self.weight.dtype) |
| 246 | + scales = scales / (scales.max() * scales.min()).sqrt() |
| 247 | + # layout = UintxLayout(self.target_dtype) |
| 248 | + # # regardless of weight dtype, we have to store as packed uint8 tensors |
| 249 | + # tensor_dtype = torch.uint8 |
| 250 | + # w = to_affine_quantized_intx( |
| 251 | + # self.weight * scales, |
| 252 | + # self.mapping_type, |
| 253 | + # (1, self.quantization_granularity.group_size), |
| 254 | + # tensor_dtype, |
| 255 | + # quant_min=self.quant_min, |
| 256 | + # quant_max=self.quant_max, |
| 257 | + # eps=self.eps, |
| 258 | + # scale_dtype=self.scale_dtype, |
| 259 | + # zero_point_dtype=self.zero_point_dtype, |
| 260 | + # preserve_zero=self.preserve_zero, |
| 261 | + # zero_point_domain=self.zero_point_domain, |
| 262 | + # _layout=layout, |
| 263 | + # ) |
| 264 | + base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(self.base_config)] |
| 265 | + dummy_mod = DummyModule(self.weight * scales) |
| 266 | + quant_mod = base_config_handler(dummy_mod, self.base_config) |
| 267 | + w = quant_mod.weight |
| 268 | + |
| 269 | + loss = 0 |
| 270 | + for i in range(self.n_validation_examples): |
| 271 | + q_out = F.linear(self.inputs[i] / scales, w, self.bias) |
| 272 | + loss += (self.outputs[i] - q_out).pow(2).mean().item() |
| 273 | + if loss < best_loss: |
| 274 | + best_scales = scales |
| 275 | + best_loss = loss |
| 276 | + for i in range(self.n_validation_examples): |
| 277 | + self.inputs[i].to("cpu") |
| 278 | + self.outputs[i].to("cpu") |
| 279 | + return best_scales.detach() |
| 280 | + |
| 281 | + |
148 | 282 | class AWQObservedLinear(torch.nn.Linear):
|
149 | 283 | def __init__(
|
150 | 284 | self,
|
|
0 commit comments