From a4f9ba2e623e81c20e782685212437173590cb4d Mon Sep 17 00:00:00 2001 From: aladerran Date: Tue, 17 Jun 2025 21:38:50 +0800 Subject: [PATCH 1/4] Use torch.compile to speed up GPTQ algo Signed-off-by: aladerran --- .../modifiers/quantization/gptq/gptq_quantize.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py index 4392ed8cf..f9ef41696 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py @@ -3,6 +3,7 @@ from typing import Dict, Optional, Tuple, Union import torch +import torch._dynamo.config import transformers from compressed_tensors.quantization import ( ActivationOrdering, @@ -16,6 +17,8 @@ from llmcompressor.observers.base import Observer from llmcompressor.pytorch.utils.helpers import tensor_sparsity +torch._dynamo.config.capture_scalar_outputs = True + GPTQ_PRECISION = torch.float32 __all__ = ["make_empty_hessian", "accumulate_hessian", "quantize_weight"] @@ -68,6 +71,7 @@ def accumulate_hessian( return H, num_samples +@torch.compile def quantize_weight( module: torch.nn.Module, quant_args: QuantizationArgs, From 8fb026ee2c75ee5697ac637ce8c82ba03607bc2c Mon Sep 17 00:00:00 2001 From: aladerran Date: Sun, 22 Jun 2025 09:45:30 +0800 Subject: [PATCH 2/4] Upload torch.compiled GPTQ as an opt version Signed-off-by: aladerran --- .../quantization/gptq/gptq_quantize.py | 253 +++++++++++++++++- 1 file changed, 252 insertions(+), 1 deletion(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py index f9ef41696..3414e9dd7 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py @@ -4,6 +4,7 @@ import torch import torch._dynamo.config +import torch._inductor.config import transformers from compressed_tensors.quantization import ( ActivationOrdering, @@ -18,6 +19,8 @@ from llmcompressor.pytorch.utils.helpers import tensor_sparsity torch._dynamo.config.capture_scalar_outputs = True +torch._inductor.config.triton.tile_reductions = True +torch.set_float32_matmul_precision("high") GPTQ_PRECISION = torch.float32 @@ -71,7 +74,6 @@ def accumulate_hessian( return H, num_samples -@torch.compile def quantize_weight( module: torch.nn.Module, quant_args: QuantizationArgs, @@ -283,6 +285,255 @@ def quantize_weight( ) +@torch.compile(dynamic=True) +def _quantize_core( + W: torch.Tensor, + Hinv: torch.Tensor, + scale_map: torch.Tensor, + zero_map: torch.Tensor, + W_nz_mask: Optional[torch.Tensor], + blocksize: int, + quant_min: int, + quant_max: int, + sym: bool, + num_rows: int, + num_columns: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + losses = torch.zeros(num_rows, device=W.device, dtype=W.dtype) + + for i1 in range(0, num_columns, blocksize): + i2 = min(i1 + blocksize, num_columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone().contiguous() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2].contiguous() + + for i in range(count): + col_idx = i1 + i + w = W1[:, i] + d = Hinv1[i, i] + + s = scale_map[:, col_idx] + z = zero_map[:, col_idx] + + if sym: + z = torch.zeros_like(z) + + scaled = w / s + if not sym: + scaled -= z + q = torch.clamp(torch.round(scaled), quant_min, quant_max) + dq = q * s + if not sym: + dq += z * s + + # propagate 
column error + Q1[:, i] = dq + losses1[:, i] = (w - dq) ** 2 / d**2 + + err1 = (w - dq) / d + Err1[:, i] = err1 + + w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + if W_nz_mask is not None: + mask_slice = W_nz_mask[:, i1 + i : i2] + W1[:, i:] -= w1_err * mask_slice + else: + W1[:, i:] -= w1_err + + # propagate block error + W[:, i1:i2] = Q1 + losses += torch.sum(losses1.contiguous(), dim=1) / 2 + + w_err = Err1.matmul(Hinv[i1:i2, i2:]) + if W_nz_mask is not None: + mask_slice = W_nz_mask[:, i2:] + W[:, i2:] -= w_err * mask_slice + else: + W[:, i2:] -= w_err + + return W, losses + + +def quantize_weight_optimized( + module: torch.nn.Module, + quant_args: QuantizationArgs, + hessians_dict: Dict[torch.nn.Module, torch.Tensor], + blocksize: int = 128, + percdamp: float = 0.01, +) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]: + """ + Quantize a module weight according to the GPTQ algorithm + + This version is faster than the original one with torch.compile support + + :param module: module with weight being quantized + :param quant_args: quantization arguments used to find quantization parameters + :param hessian_dict: dictionary containing preaccumulated hessian for quantization + :param blocksize: chunk size of quantization updates + :param percdamp: dampening factor on hessian diagonal + :return: loss, quantized_weight, scale, zero_point, g_idx + """ + strategy = quant_args.strategy + actorder = quant_args.actorder + n_bits = quant_args.num_bits + sym = quant_args.symmetric + final_shape = module.weight.shape + final_dtype = module.weight.dtype + W = module.weight.clone() + H = hessians_dict[module] # unfortunately python does not have a `move` keyword + del hessians_dict[module] # so we have to delete the original reference manually + + # create observer for calculating quantization parameters + observer = Observer.load_from_registry( + quant_args.observer, + quantization_args=quant_args, + averaging_constant=1.0, # ignore moving average + ) + + # standardize shape and dtype + if isinstance(module, torch.nn.Conv2d): + W = W.flatten(1) + elif isinstance(module, transformers.Conv1D): + W.transpose_(0, 1) + W = W.to(dtype=GPTQ_PRECISION) + num_rows = W.shape[0] + num_columns = W.shape[1] + + if strategy == QuantizationStrategy.GROUP: + # mapping from column index to group index + g_idx = ( + torch.arange(num_columns, device=W.device, dtype=torch.int) + // quant_args.group_size + ) + + if actorder == ActivationOrdering.GROUP: + # permute by activation order first, then update groups + W, H, perm = _apply_activation_ordering(W, H) + scale, zero_point = observer(W, g_idx=None) + + # use identity g_idx (invert permutation later) + + elif actorder == ActivationOrdering.WEIGHT: + # update groups first, then permute by activation order + scale, zero_point = observer(W, g_idx=None) + W, H, perm = _apply_activation_ordering(W, H) + + # permute g_idx to maintain identity mapping after unpermutation + g_idx = g_idx[perm] + + else: + scale, zero_point = observer(W, g_idx=None) + else: + scale, zero_point = observer(W, g_idx=None) + + scale = scale.to(W.device) + zero_point = zero_point.to(W.device) + + # sparsity mask + sparsity = tensor_sparsity(W) + preserve_zeros = sparsity >= SPARSITY_THRESHOLD + W_nz_mask = ( + (~torch.isclose(W, torch.zeros(1, device=W.device, dtype=W.dtype))).float() + if preserve_zeros + else None + ) + + # mask dead hessian values + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + # compute inverse hessian 
in place to save memory + try: + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(H.shape[0], device=H.device) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + except torch._C._LinAlgError: + logger.warning( + "Failed to invert hessian due to numerical instability. Consider " + "increasing GPTQModifier.dampening_frac, increasing the number " + "of calibration samples, or shuffling the calibration dataset. " + "Falling back to round-to-nearest for this module." + ) + Hinv = H = torch.eye(num_columns, dtype=H.dtype, device=H.device) + + # See section 3.4 of https://arxiv.org/abs/2203.07259 + # quantize column + if strategy == QuantizationStrategy.TENSOR: + scale_map = scale.expand(num_rows, num_columns) + zero_map = zero_point.expand(num_rows, num_columns) + elif strategy == QuantizationStrategy.CHANNEL: + scale_map = scale.expand(-1, num_columns) + zero_map = zero_point.expand(-1, num_columns) + elif strategy == QuantizationStrategy.GROUP: + # get the group index for the current column + scale_map = scale[:, g_idx] + zero_map = zero_point[:, g_idx] + else: + raise ValueError(f"Quantization strategy is not supported for GPTQ: {strategy}") + + if sym: + quant_min = -(2 ** (n_bits - 1)) + quant_max = 2 ** (n_bits - 1) - 1 + else: + quant_min = 0 + quant_max = 2**n_bits - 1 + + W, losses = _quantize_core( + W=W, + Hinv=Hinv, + scale_map=scale_map, + zero_map=zero_map, + W_nz_mask=W_nz_mask, + blocksize=blocksize, + quant_min=quant_min, + quant_max=quant_max, + sym=sym, + num_rows=num_rows, + num_columns=num_columns, + ) + + has_gidx = False + if strategy == QuantizationStrategy.GROUP: + if actorder == ActivationOrdering.WEIGHT: + # restore original permutation + invperm = torch.argsort(perm) + W = W[:, invperm] + + elif actorder == ActivationOrdering.GROUP: + # restore original permutation + invperm = torch.argsort(perm) + W = W[:, invperm] + g_idx = g_idx[invperm] + + # only save g_idx if mapping is not identity + has_gidx = True + + if not has_gidx: + g_idx = None + + if isinstance(module, transformers.Conv1D): + W.transpose_(0, 1) + W = W.reshape(final_shape).to(final_dtype) + + loss = torch.sum(losses).item() + return ( + loss, + W, + scale.to(dtype=final_dtype), + zero_point.to(dtype=quant_args.pytorch_dtype()), + g_idx, + ) + + def _apply_activation_ordering( W: torch.Tensor, H: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: From 200ceac6a6c6c63a80e3ee3a910109f38ceb26a7 Mon Sep 17 00:00:00 2001 From: aladerran Date: Wed, 2 Jul 2025 22:56:08 +0800 Subject: [PATCH 3/4] Fix long compilation issue Signed-off-by: aladerran --- .../quantization/gptq/gptq_quantize.py | 105 +++++++++++------- 1 file changed, 63 insertions(+), 42 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py index 3414e9dd7..a2f353a7f 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py @@ -286,6 +286,56 @@ def quantize_weight( @torch.compile(dynamic=True) +def _process_block( + W1: torch.Tensor, + Hinv1: torch.Tensor, + scale_slice: torch.Tensor, + zero_slice: torch.Tensor, + mask_slice: Optional[torch.Tensor], + quant_min: int, + quant_max: int, + sym: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + count = W1.shape[1] + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + losses1 = 
torch.zeros_like(W1) + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + s = scale_slice[:, i] + z = zero_slice[:, i] + + if sym: + z = torch.zeros_like(z) + + scaled = w / s + if not sym: + scaled -= z + q = torch.clamp(torch.round(scaled), quant_min, quant_max) + dq = q * s + if not sym: + dq += z * s + + err1 = (w - dq) / d + loss_col = (w - dq) ** 2 / d**2 + + Q1[:, i] = dq + Err1[:, i] = err1 + losses1[:, i] = loss_col + + w1_err = err1.unsqueeze(1) @ Hinv1[i, i:].unsqueeze(0) + if mask_slice is not None: + mask_block = mask_slice[:, i:] + W1[:, i:] -= w1_err * mask_block + else: + W1[:, i:] -= w1_err + + return Q1, Err1, losses1 + + def _quantize_core( W: torch.Tensor, Hinv: torch.Tensor, @@ -303,55 +353,26 @@ def _quantize_core( for i1 in range(0, num_columns, blocksize): i2 = min(i1 + blocksize, num_columns) - count = i2 - i1 - W1 = W[:, i1:i2].clone().contiguous() - Q1 = torch.zeros_like(W1) - Err1 = torch.zeros_like(W1) - losses1 = torch.zeros_like(W1) + W1 = W[:, i1:i2].clone() Hinv1 = Hinv[i1:i2, i1:i2].contiguous() + scale_slice = scale_map[:, i1:i2] + zero_slice = zero_map[:, i1:i2] + mask_slice = None + if W_nz_mask is not None: + mask_slice = W_nz_mask[:, i1:i2] - for i in range(count): - col_idx = i1 + i - w = W1[:, i] - d = Hinv1[i, i] - - s = scale_map[:, col_idx] - z = zero_map[:, col_idx] - - if sym: - z = torch.zeros_like(z) - - scaled = w / s - if not sym: - scaled -= z - q = torch.clamp(torch.round(scaled), quant_min, quant_max) - dq = q * s - if not sym: - dq += z * s - - # propagate column error - Q1[:, i] = dq - losses1[:, i] = (w - dq) ** 2 / d**2 - - err1 = (w - dq) / d - Err1[:, i] = err1 - - w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - if W_nz_mask is not None: - mask_slice = W_nz_mask[:, i1 + i : i2] - W1[:, i:] -= w1_err * mask_slice - else: - W1[:, i:] -= w1_err + Q1, Err1, losses1 = _process_block( + W1, Hinv1, scale_slice, zero_slice, mask_slice, quant_min, quant_max, sym + ) - # propagate block error W[:, i1:i2] = Q1 - losses += torch.sum(losses1.contiguous(), dim=1) / 2 + losses += losses1.sum(dim=1) / 2 - w_err = Err1.matmul(Hinv[i1:i2, i2:]) + w_err = Err1 @ Hinv[i1:i2, i2:] if W_nz_mask is not None: - mask_slice = W_nz_mask[:, i2:] - W[:, i2:] -= w_err * mask_slice + mask_rest = W_nz_mask[:, i2:] + W[:, i2:] -= w_err * mask_rest else: W[:, i2:] -= w_err From 63bf122c1ba0458024ce18c783265a6699064c46 Mon Sep 17 00:00:00 2001 From: aladerran Date: Mon, 21 Jul 2025 23:27:07 +0800 Subject: [PATCH 4/4] Revise & add comments Signed-off-by: aladerran --- .../modifiers/quantization/gptq/gptq_quantize.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py index a2f353a7f..91534d645 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py +++ b/src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py @@ -18,10 +18,6 @@ from llmcompressor.observers.base import Observer from llmcompressor.pytorch.utils.helpers import tensor_sparsity -torch._dynamo.config.capture_scalar_outputs = True -torch._inductor.config.triton.tile_reductions = True -torch.set_float32_matmul_precision("high") - GPTQ_PRECISION = torch.float32 __all__ = ["make_empty_hessian", "accumulate_hessian", "quantize_weight"] @@ -74,7 +70,7 @@ def accumulate_hessian( return H, num_samples -def quantize_weight( +def quantize_weight_original( module: torch.nn.Module, quant_args: QuantizationArgs, 
hessians_dict: Dict[torch.nn.Module, torch.Tensor], @@ -296,6 +292,7 @@ def _process_block( quant_max: int, sym: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Process a single block of weight columns using with torch.compile support.""" count = W1.shape[1] Q1 = torch.zeros_like(W1) Err1 = torch.zeros_like(W1) @@ -349,11 +346,13 @@ def _quantize_core( num_rows: int, num_columns: int, ) -> Tuple[torch.Tensor, torch.Tensor]: + """Core GPTQ quantization loop processing weights in blocks.""" losses = torch.zeros(num_rows, device=W.device, dtype=W.dtype) for i1 in range(0, num_columns, blocksize): i2 = min(i1 + blocksize, num_columns) + # Extract current block and corresponding Hessian/quantization params W1 = W[:, i1:i2].clone() Hinv1 = Hinv[i1:i2, i1:i2].contiguous() scale_slice = scale_map[:, i1:i2] @@ -362,6 +361,7 @@ def _quantize_core( if W_nz_mask is not None: mask_slice = W_nz_mask[:, i1:i2] + # Quantize the current block Q1, Err1, losses1 = _process_block( W1, Hinv1, scale_slice, zero_slice, mask_slice, quant_min, quant_max, sym ) @@ -369,6 +369,7 @@ def _quantize_core( W[:, i1:i2] = Q1 losses += losses1.sum(dim=1) / 2 + # Propagate block error to remaining unprocessed columns w_err = Err1 @ Hinv[i1:i2, i2:] if W_nz_mask is not None: mask_rest = W_nz_mask[:, i2:] @@ -379,7 +380,7 @@ def _quantize_core( return W, losses -def quantize_weight_optimized( +def quantize_weight( module: torch.nn.Module, quant_args: QuantizationArgs, hessians_dict: Dict[torch.nn.Module, torch.Tensor],
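
Where the series ends up: patch 1 compiled all of quantize_weight and flipped torch._dynamo/_inductor config globally; patch 3 narrowed compilation to the per-block inner loop (_process_block) with dynamic=True, per its subject, to fix the long compilation issue; patch 4 dropped the global config toggles and made the compiled path the default quantize_weight, keeping the original as quantize_weight_original. The snippet below is a minimal, self-contained sketch of that final pattern — compile only the hot per-block kernel and let dynamic shapes absorb size differences between layers. The kernel, tensor names, and shapes are illustrative stand-ins, not the llmcompressor API.

import torch


# Toy stand-in for the compiled per-block GPTQ update (illustrative only).
# Compiling just this inner kernel keeps torch.compile's first-call cost small,
# and dynamic=True should avoid a recompile when the row count differs
# between modules.
@torch.compile(dynamic=True)
def process_block(W1, Hinv1, scale, quant_min, quant_max):
    # Quantize each column, then push its quantization error onto the
    # remaining columns of the block via the inverse-Hessian row.
    Q1 = torch.zeros_like(W1)
    Err1 = torch.zeros_like(W1)
    for i in range(W1.shape[1]):
        w = W1[:, i]
        d = Hinv1[i, i]
        q = torch.clamp(torch.round(w / scale[:, i]), quant_min, quant_max)
        dq = q * scale[:, i]
        err = (w - dq) / d
        Q1[:, i] = dq
        Err1[:, i] = err
        W1[:, i:] -= err.unsqueeze(1) @ Hinv1[i, i:].unsqueeze(0)
    return Q1, Err1


if __name__ == "__main__":
    torch.manual_seed(0)
    for rows in (8, 12):  # only the row count changes; the block width (loop bound) stays fixed
        W1 = torch.randn(rows, 16)
        Hinv1 = torch.eye(16)
        scale = torch.full((rows, 16), 0.05)
        Q1, Err1 = process_block(W1, Hinv1, scale, -8, 7)
        print(rows, float(Err1.abs().mean()))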