diff --git a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py
index 4392ed8cf..91534d645 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py
@@ -3,6 +3,8 @@
 from typing import Dict, Optional, Tuple, Union

 import torch
+import torch._dynamo.config
+import torch._inductor.config
 import transformers
 from compressed_tensors.quantization import (
     ActivationOrdering,
@@ -68,7 +70,7 @@ def accumulate_hessian(
     return H, num_samples


-def quantize_weight(
+def quantize_weight_original(
     module: torch.nn.Module,
     quant_args: QuantizationArgs,
     hessians_dict: Dict[torch.nn.Module, torch.Tensor],
@@ -279,6 +281,281 @@ def quantize_weight(
     )


+@torch.compile(dynamic=True)
+def _process_block(
+    W1: torch.Tensor,
+    Hinv1: torch.Tensor,
+    scale_slice: torch.Tensor,
+    zero_slice: torch.Tensor,
+    mask_slice: Optional[torch.Tensor],
+    quant_min: int,
+    quant_max: int,
+    sym: bool,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Process a single block of weight columns; compiled with torch.compile."""
+    count = W1.shape[1]
+    Q1 = torch.zeros_like(W1)
+    Err1 = torch.zeros_like(W1)
+    losses1 = torch.zeros_like(W1)
+
+    for i in range(count):
+        w = W1[:, i]
+        d = Hinv1[i, i]
+
+        s = scale_slice[:, i]
+        z = zero_slice[:, i]
+
+        if sym:
+            z = torch.zeros_like(z)
+
+        scaled = w / s
+        if not sym:
+            scaled += z  # compressed-tensors convention: q = round(w / s + z)
+        q = torch.clamp(torch.round(scaled), quant_min, quant_max)
+        dq = q * s
+        if not sym:
+            dq -= z * s  # dequantize: dq = (q - z) * s
+
+        err1 = (w - dq) / d
+        loss_col = (w - dq) ** 2 / d**2
+
+        Q1[:, i] = dq
+        Err1[:, i] = err1
+        losses1[:, i] = loss_col
+
+        w1_err = err1.unsqueeze(1) @ Hinv1[i, i:].unsqueeze(0)
+        if mask_slice is not None:
+            mask_block = mask_slice[:, i:]
+            W1[:, i:] -= w1_err * mask_block
+        else:
+            W1[:, i:] -= w1_err
+
+    return Q1, Err1, losses1
+
+
+def _quantize_core(
+    W: torch.Tensor,
+    Hinv: torch.Tensor,
+    scale_map: torch.Tensor,
+    zero_map: torch.Tensor,
+    W_nz_mask: Optional[torch.Tensor],
+    blocksize: int,
+    quant_min: int,
+    quant_max: int,
+    sym: bool,
+    num_rows: int,
+    num_columns: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Core GPTQ quantization loop processing weights in blocks."""
+    losses = torch.zeros(num_rows, device=W.device, dtype=W.dtype)
+
+    for i1 in range(0, num_columns, blocksize):
+        i2 = min(i1 + blocksize, num_columns)
+
+        # Extract current block and corresponding Hessian/quantization params
+        W1 = W[:, i1:i2].clone()
+        Hinv1 = Hinv[i1:i2, i1:i2].contiguous()
+        scale_slice = scale_map[:, i1:i2]
+        zero_slice = zero_map[:, i1:i2]
+        mask_slice = None
+        if W_nz_mask is not None:
+            mask_slice = W_nz_mask[:, i1:i2]
+
+        # Quantize the current block
+        Q1, Err1, losses1 = _process_block(
+            W1, Hinv1, scale_slice, zero_slice, mask_slice, quant_min, quant_max, sym
+        )
+
+        W[:, i1:i2] = Q1
+        losses += losses1.sum(dim=1) / 2
+
+        # Propagate block error to remaining unprocessed columns
+        w_err = Err1 @ Hinv[i1:i2, i2:]
+        if W_nz_mask is not None:
+            mask_rest = W_nz_mask[:, i2:]
+            W[:, i2:] -= w_err * mask_rest
+        else:
+            W[:, i2:] -= w_err
+
+    return W, losses
+
+
+def quantize_weight(
+    module: torch.nn.Module,
+    quant_args: QuantizationArgs,
+    hessians_dict: Dict[torch.nn.Module, torch.Tensor],
+    blocksize: int = 128,
+    percdamp: float = 0.01,
+) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]:
+    """
+    Quantize a module weight according to the GPTQ algorithm.
+
+    This version is faster than the original thanks to torch.compile support.
+
+    :param module: module with weight being quantized
+    :param quant_args: quantization arguments used to find quantization parameters
+    :param hessians_dict: dictionary containing preaccumulated hessian for quantization
+    :param blocksize: chunk size of quantization updates
+    :param percdamp: dampening factor on hessian diagonal
+    :return: loss, quantized_weight, scale, zero_point, g_idx
+    """
+    strategy = quant_args.strategy
+    actorder = quant_args.actorder
+    n_bits = quant_args.num_bits
+    sym = quant_args.symmetric
+    final_shape = module.weight.shape
+    final_dtype = module.weight.dtype
+    W = module.weight.clone()
+    H = hessians_dict[module]  # unfortunately python does not have a `move` keyword
+    del hessians_dict[module]  # so we have to delete the original reference manually
+
+    # create observer for calculating quantization parameters
+    observer = Observer.load_from_registry(
+        quant_args.observer,
+        quantization_args=quant_args,
+        averaging_constant=1.0,  # ignore moving average
+    )
+
+    # standardize shape and dtype
+    if isinstance(module, torch.nn.Conv2d):
+        W = W.flatten(1)
+    elif isinstance(module, transformers.Conv1D):
+        W.transpose_(0, 1)
+    W = W.to(dtype=GPTQ_PRECISION)
+    num_rows = W.shape[0]
+    num_columns = W.shape[1]
+
+    if strategy == QuantizationStrategy.GROUP:
+        # mapping from column index to group index
+        g_idx = (
+            torch.arange(num_columns, device=W.device, dtype=torch.int)
+            // quant_args.group_size
+        )
+
+        if actorder == ActivationOrdering.GROUP:
+            # permute by activation order first, then update groups
+            W, H, perm = _apply_activation_ordering(W, H)
+            scale, zero_point = observer(W, g_idx=None)
+
+            # use identity g_idx (invert permutation later)
+
+        elif actorder == ActivationOrdering.WEIGHT:
+            # update groups first, then permute by activation order
+            scale, zero_point = observer(W, g_idx=None)
+            W, H, perm = _apply_activation_ordering(W, H)
+
+            # permute g_idx to maintain identity mapping after unpermutation
+            g_idx = g_idx[perm]
+
+        else:
+            scale, zero_point = observer(W, g_idx=None)
+    else:
+        scale, zero_point = observer(W, g_idx=None)
+
+    scale = scale.to(W.device)
+    zero_point = zero_point.to(W.device)
+
+    # sparsity mask
+    sparsity = tensor_sparsity(W)
+    preserve_zeros = sparsity >= SPARSITY_THRESHOLD
+    W_nz_mask = (
+        (~torch.isclose(W, torch.zeros(1, device=W.device, dtype=W.dtype))).float()
+        if preserve_zeros
+        else None
+    )
+
+    # mask dead hessian values
+    dead = torch.diag(H) == 0
+    H[dead, dead] = 1
+    W[:, dead] = 0
+
+    # compute inverse hessian in place to save memory
+    try:
+        damp = percdamp * torch.mean(torch.diag(H))
+        diag = torch.arange(H.shape[0], device=H.device)
+        H[diag, diag] += damp
+        H = torch.linalg.cholesky(H)
+        H = torch.cholesky_inverse(H)
+        H = torch.linalg.cholesky(H, upper=True)
+        Hinv = H
+    except torch._C._LinAlgError:
+        logger.warning(
+            "Failed to invert hessian due to numerical instability. Consider "
+            "increasing GPTQModifier.dampening_frac, increasing the number "
+            "of calibration samples, or shuffling the calibration dataset. "
+            "Falling back to round-to-nearest for this module."
+        )
+        Hinv = H = torch.eye(num_columns, dtype=H.dtype, device=H.device)
+
+    # See section 3.4 of https://arxiv.org/abs/2203.07259
+    # build per-column scale and zero-point maps
+    if strategy == QuantizationStrategy.TENSOR:
+        scale_map = scale.expand(num_rows, num_columns)
+        zero_map = zero_point.expand(num_rows, num_columns)
+    elif strategy == QuantizationStrategy.CHANNEL:
+        scale_map = scale.expand(-1, num_columns)
+        zero_map = zero_point.expand(-1, num_columns)
+    elif strategy == QuantizationStrategy.GROUP:
+        # get the group index for the current column
+        scale_map = scale[:, g_idx]
+        zero_map = zero_point[:, g_idx]
+    else:
+        raise ValueError(f"Quantization strategy is not supported for GPTQ: {strategy}")
+
+    if sym:
+        quant_min = -(2 ** (n_bits - 1))
+        quant_max = 2 ** (n_bits - 1) - 1
+    else:
+        quant_min = 0
+        quant_max = 2**n_bits - 1
+
+    W, losses = _quantize_core(
+        W=W,
+        Hinv=Hinv,
+        scale_map=scale_map,
+        zero_map=zero_map,
+        W_nz_mask=W_nz_mask,
+        blocksize=blocksize,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        sym=sym,
+        num_rows=num_rows,
+        num_columns=num_columns,
+    )
+
+    has_gidx = False
+    if strategy == QuantizationStrategy.GROUP:
+        if actorder == ActivationOrdering.WEIGHT:
+            # restore original permutation
+            invperm = torch.argsort(perm)
+            W = W[:, invperm]
+
+        elif actorder == ActivationOrdering.GROUP:
+            # restore original permutation
+            invperm = torch.argsort(perm)
+            W = W[:, invperm]
+            g_idx = g_idx[invperm]
+
+            # only save g_idx if mapping is not identity
+            has_gidx = True
+
+    if not has_gidx:
+        g_idx = None
+
+    if isinstance(module, transformers.Conv1D):
+        W.transpose_(0, 1)
+    W = W.reshape(final_shape).to(final_dtype)
+
+    loss = torch.sum(losses).item()
+    return (
+        loss,
+        W,
+        scale.to(dtype=final_dtype),
+        zero_point.to(dtype=quant_args.pytorch_dtype()),
+        g_idx,
+    )
+
+
 def _apply_activation_ordering(
     W: torch.Tensor, H: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
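
Not part of the patch: below is a minimal, illustrative sketch of how the compiled quantize_weight path could be sanity-checked against quantize_weight_original on a toy Linear layer with a synthetic Hessian. The harness itself (layer sizes, QuantizationArgs settings, the randomly generated Hessian) is a hypothetical example, not something shipped with the change; it only shows how the two entry points line up and assumes llmcompressor and compressed-tensors are installed.

import torch
from compressed_tensors.quantization import QuantizationArgs

from llmcompressor.modifiers.quantization.gptq.gptq_quantize import (
    quantize_weight,
    quantize_weight_original,
)


def parity_check() -> None:
    torch.manual_seed(0)
    module = torch.nn.Linear(512, 256, bias=False)
    quant_args = QuantizationArgs(
        num_bits=4, symmetric=True, strategy="group", group_size=128
    )

    # synthetic positive semi-definite Hessian standing in for calibration data
    X = torch.randn(1024, 512)
    H = (X.T @ X) / X.shape[0]

    # each call consumes its Hessian entry, so pass independent copies
    loss_ref, W_ref, *_ = quantize_weight_original(
        module, quant_args, {module: H.clone()}, blocksize=128
    )
    loss_new, W_new, *_ = quantize_weight(
        module, quant_args, {module: H.clone()}, blocksize=128
    )

    # report rather than assert: small numerical differences between the eager
    # and compiled paths are expected
    print("loss (original / compiled):", loss_ref, loss_new)
    print("max |dW|:", (W_ref - W_new).abs().max().item())


if __name__ == "__main__":
    parity_check()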