Use torch.compile to speed up GPTQ algo #1561

Open · wants to merge 6 commits into base: main
src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py (278 additions, 1 deletion)
@@ -3,6 +3,8 @@
from typing import Dict, Optional, Tuple, Union

import torch
import torch._dynamo.config
import torch._inductor.config
import transformers
from compressed_tensors.quantization import (
ActivationOrdering,
@@ -68,7 +70,7 @@ def accumulate_hessian(
return H, num_samples


def quantize_weight(
def quantize_weight_original(
module: torch.nn.Module,
quant_args: QuantizationArgs,
hessians_dict: Dict[torch.nn.Module, torch.Tensor],
@@ -279,6 +281,281 @@ def quantize_weight(
)


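# dynamic=True lets the compiled kernel handle varying block widths (the last
# block can be narrower than `blocksize`) without recompiling for every new shape.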
@torch.compile(dynamic=True)
def _process_block(
W1: torch.Tensor,
Hinv1: torch.Tensor,
scale_slice: torch.Tensor,
zero_slice: torch.Tensor,
mask_slice: Optional[torch.Tensor],
quant_min: int,
quant_max: int,
sym: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Process a single block of weight columns using with torch.compile support."""
count = W1.shape[1]
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
losses1 = torch.zeros_like(W1)

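# Column-by-column GPTQ inner loop: fake-quantize each column, measure its
# quantization error against the Cholesky diagonal entry d, and fold that error
# back into the columns of the block that have not been processed yet.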
for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]

s = scale_slice[:, i]
z = zero_slice[:, i]

if sym:
z = torch.zeros_like(z)

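# Fake-quantize the column. For example, with 4-bit symmetric and s = 0.1:
# w = 0.37 -> round(0.37 / 0.1) = 4 -> dq = 4 * 0.1 = 0.4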
scaled = w / s
if not sym:
scaled -= z
q = torch.clamp(torch.round(scaled), quant_min, quant_max)
dq = q * s
if not sym:
dq += z * s

err1 = (w - dq) / d
loss_col = (w - dq) ** 2 / d**2

Q1[:, i] = dq
Err1[:, i] = err1
losses1[:, i] = loss_col

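# Error feedback within the block: a rank-1 update spreads this column's
# quantization error over columns i: of the block, using row i of the
# inverse-Hessian Cholesky factor.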
w1_err = err1.unsqueeze(1) @ Hinv1[i, i:].unsqueeze(0)
if mask_slice is not None:
mask_block = mask_slice[:, i:]
W1[:, i:] -= w1_err * mask_block
else:
W1[:, i:] -= w1_err

return Q1, Err1, losses1


def _quantize_core(
W: torch.Tensor,
Hinv: torch.Tensor,
scale_map: torch.Tensor,
zero_map: torch.Tensor,
W_nz_mask: Optional[torch.Tensor],
blocksize: int,
quant_min: int,
quant_max: int,
sym: bool,
num_rows: int,
num_columns: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Core GPTQ quantization loop processing weights in blocks."""
losses = torch.zeros(num_rows, device=W.device, dtype=W.dtype)

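# Lazy batch updates from the GPTQ algorithm: quantize `blocksize` columns at a
# time and only then push the accumulated error into the columns to the right,
# keeping the expensive full-width update out of the inner loop.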
for i1 in range(0, num_columns, blocksize):
i2 = min(i1 + blocksize, num_columns)

# Extract current block and corresponding Hessian/quantization params
W1 = W[:, i1:i2].clone()
Hinv1 = Hinv[i1:i2, i1:i2].contiguous()
scale_slice = scale_map[:, i1:i2]
zero_slice = zero_map[:, i1:i2]
mask_slice = None
if W_nz_mask is not None:
mask_slice = W_nz_mask[:, i1:i2]

# Quantize the current block
Q1, Err1, losses1 = _process_block(
W1, Hinv1, scale_slice, zero_slice, mask_slice, quant_min, quant_max, sym
)

W[:, i1:i2] = Q1
losses += losses1.sum(dim=1) / 2

# Propagate block error to remaining unprocessed columns
w_err = Err1 @ Hinv[i1:i2, i2:]
if W_nz_mask is not None:
mask_rest = W_nz_mask[:, i2:]
W[:, i2:] -= w_err * mask_rest
else:
W[:, i2:] -= w_err

return W, losses


def quantize_weight(
module: torch.nn.Module,
quant_args: QuantizationArgs,
hessians_dict: Dict[torch.nn.Module, torch.Tensor],
blocksize: int = 128,
percdamp: float = 0.01,
) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]:
"""
Quantize a module weight according to the GPTQ algorithm

This version uses torch.compile to run the per-block inner loop faster than
the original implementation.

:param module: module with weight being quantized
:param quant_args: quantization arguments used to find quantization parameters
:param hessians_dict: dictionary containing pre-accumulated hessians for quantization
:param blocksize: chunk size of quantization updates
:param percdamp: dampening factor on hessian diagonal
:return: loss, quantized_weight, scale, zero_point, g_idx
"""
strategy = quant_args.strategy
actorder = quant_args.actorder
n_bits = quant_args.num_bits
sym = quant_args.symmetric
final_shape = module.weight.shape
final_dtype = module.weight.dtype
W = module.weight.clone()
H = hessians_dict[module] # unfortunately python does not have a `move` keyword
del hessians_dict[module] # so we have to delete the original reference manually

# create observer for calculating quantization parameters
observer = Observer.load_from_registry(
quant_args.observer,
quantization_args=quant_args,
averaging_constant=1.0, # ignore moving average
)

# standardize shape and dtype
if isinstance(module, torch.nn.Conv2d):
W = W.flatten(1)
elif isinstance(module, transformers.Conv1D):
W.transpose_(0, 1)
W = W.to(dtype=GPTQ_PRECISION)
num_rows = W.shape[0]
num_columns = W.shape[1]

if strategy == QuantizationStrategy.GROUP:
# mapping from column index to group index
g_idx = (
torch.arange(num_columns, device=W.device, dtype=torch.int)
// quant_args.group_size
)

if actorder == ActivationOrdering.GROUP:
# permute by activation order first, then update groups
W, H, perm = _apply_activation_ordering(W, H)
scale, zero_point = observer(W, g_idx=None)

# use identity g_idx (invert permutation later)

elif actorder == ActivationOrdering.WEIGHT:
# update groups first, then permute by activation order
scale, zero_point = observer(W, g_idx=None)
W, H, perm = _apply_activation_ordering(W, H)

# permute g_idx to maintain identity mapping after unpermutation
g_idx = g_idx[perm]

else:
scale, zero_point = observer(W, g_idx=None)
else:
scale, zero_point = observer(W, g_idx=None)

scale = scale.to(W.device)
zero_point = zero_point.to(W.device)

# sparsity mask
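# When the weight is already sparse past SPARSITY_THRESHOLD, remember which
# entries are zero so the error-propagation updates never turn them nonzero.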
sparsity = tensor_sparsity(W)
preserve_zeros = sparsity >= SPARSITY_THRESHOLD
W_nz_mask = (
(~torch.isclose(W, torch.zeros(1, device=W.device, dtype=W.dtype))).float()
if preserve_zeros
else None
)

# mask dead hessian values
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0

# compute inverse hessian in place to save memory
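# Sequence: dampen the diagonal by percdamp * mean(diag(H)), invert H via a
# Cholesky factorization plus cholesky_inverse, then take the upper Cholesky
# factor of H^-1; its diagonal supplies the per-column `d` used in _process_block.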
try:
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(H.shape[0], device=H.device)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
except torch._C._LinAlgError:
logger.warning(
"Failed to invert hessian due to numerical instability. Consider "
"increasing GPTQModifier.dampening_frac, increasing the number "
"of calibration samples, or shuffling the calibration dataset. "
"Falling back to round-to-nearest for this module."
)
Hinv = H = torch.eye(num_columns, dtype=H.dtype, device=H.device)

# See section 3.4 of https://arxiv.org/abs/2203.07259
# quantize column
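# Broadcast scale/zero_point into full (num_rows, num_columns) maps so the
# compiled block kernel can index quantization params per column without
# branching on the strategy.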
if strategy == QuantizationStrategy.TENSOR:
scale_map = scale.expand(num_rows, num_columns)
zero_map = zero_point.expand(num_rows, num_columns)
elif strategy == QuantizationStrategy.CHANNEL:
scale_map = scale.expand(-1, num_columns)
zero_map = zero_point.expand(-1, num_columns)
elif strategy == QuantizationStrategy.GROUP:
# get the group index for the current column
scale_map = scale[:, g_idx]
zero_map = zero_point[:, g_idx]
else:
raise ValueError(f"Quantization strategy is not supported for GPTQ: {strategy}")

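# Integer grid for n_bits, e.g. 4-bit: symmetric -> [-8, 7], asymmetric -> [0, 15]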
if sym:
quant_min = -(2 ** (n_bits - 1))
quant_max = 2 ** (n_bits - 1) - 1
else:
quant_min = 0
quant_max = 2**n_bits - 1

W, losses = _quantize_core(
W=W,
Hinv=Hinv,
scale_map=scale_map,
zero_map=zero_map,
W_nz_mask=W_nz_mask,
blocksize=blocksize,
quant_min=quant_min,
quant_max=quant_max,
sym=sym,
num_rows=num_rows,
num_columns=num_columns,
)

has_gidx = False
if strategy == QuantizationStrategy.GROUP:
if actorder == ActivationOrdering.WEIGHT:
# restore original permutation
invperm = torch.argsort(perm)
W = W[:, invperm]

elif actorder == ActivationOrdering.GROUP:
# restore original permutation
invperm = torch.argsort(perm)
W = W[:, invperm]
g_idx = g_idx[invperm]

# only save g_idx if mapping is not identity
has_gidx = True

if not has_gidx:
g_idx = None

if isinstance(module, transformers.Conv1D):
W.transpose_(0, 1)
W = W.reshape(final_shape).to(final_dtype)

loss = torch.sum(losses).item()
return (
loss,
W,
scale.to(dtype=final_dtype),
zero_point.to(dtype=quant_args.pytorch_dtype()),
g_idx,
)

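# Illustrative call sketch (hypothetical names; assumes the module's Hessian
# has already been accumulated via accumulate_hessian):
#   loss, W_q, scale, zero_point, g_idx = quantize_weight(
#       module=linear,                 # e.g. a torch.nn.Linear layer
#       quant_args=quant_args,         # e.g. 4-bit symmetric group quantization
#       hessians_dict={linear: H},
#       blocksize=128,
#       percdamp=0.01,
#   )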

def _apply_activation_ordering(
W: torch.Tensor, H: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: