Use torch.compile to speed up GPTQ algo #1561
Open
aladerran wants to merge 6 commits into vllm-project:main from aladerran:gptq_tc
+278 −1
Changes from 5 of 6 commits
Commits (6):
a4f9ba2  Use torch.compile to speed up GPTQ algo (aladerran)
8fb026e  Upload torch.compiled GPTQ as an opt version (aladerran)
f4a3419  Merge branch 'main' into gptq_tc (kylesayrs)
200ceac  Fix long compilation issue (aladerran)
6a5e420  Merge branch 'main' into gptq_tc (dsikka)
63bf122  Revise & add comments (aladerran)
@@ -3,6 +3,8 @@
from typing import Dict, Optional, Tuple, Union

import torch
import torch._dynamo.config
import torch._inductor.config
import transformers
from compressed_tensors.quantization import (
    ActivationOrdering,

@@ -16,6 +18,10 @@
from llmcompressor.observers.base import Observer
from llmcompressor.pytorch.utils.helpers import tensor_sparsity

torch._dynamo.config.capture_scalar_outputs = True
torch._inductor.config.triton.tile_reductions = True
torch.set_float32_matmul_precision("high")

Review comment: +1 on this -- you don't want to set these globally.

GPTQ_PRECISION = torch.float32

__all__ = ["make_empty_hessian", "accumulate_hessian", "quantize_weight"]

@@ -279,6 +285,276 @@ def quantize_weight(
    )


@torch.compile(dynamic=True)
def _process_block(
    W1: torch.Tensor,
    Hinv1: torch.Tensor,
    scale_slice: torch.Tensor,
    zero_slice: torch.Tensor,
    mask_slice: Optional[torch.Tensor],
    quant_min: int,
    quant_max: int,
    sym: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # quantize one block of columns, propagating the error within the block
    count = W1.shape[1]
    Q1 = torch.zeros_like(W1)
    Err1 = torch.zeros_like(W1)
    losses1 = torch.zeros_like(W1)

    for i in range(count):
        w = W1[:, i]
        d = Hinv1[i, i]

        s = scale_slice[:, i]
        z = zero_slice[:, i]

        if sym:
            z = torch.zeros_like(z)

        # fake-quantize the current column
        scaled = w / s
        if not sym:
            scaled -= z
        q = torch.clamp(torch.round(scaled), quant_min, quant_max)
        dq = q * s
        if not sym:
            dq += z * s

        err1 = (w - dq) / d
        loss_col = (w - dq) ** 2 / d**2

        Q1[:, i] = dq
        Err1[:, i] = err1
        losses1[:, i] = loss_col

        # propagate the quantization error to the remaining columns of the block
        w1_err = err1.unsqueeze(1) @ Hinv1[i, i:].unsqueeze(0)
        if mask_slice is not None:
            mask_block = mask_slice[:, i:]
            W1[:, i:] -= w1_err * mask_block
        else:
            W1[:, i:] -= w1_err

    return Q1, Err1, losses1


def _quantize_core(
    W: torch.Tensor,
    Hinv: torch.Tensor,
    scale_map: torch.Tensor,
    zero_map: torch.Tensor,
    W_nz_mask: Optional[torch.Tensor],
    blocksize: int,
    quant_min: int,
    quant_max: int,
    sym: bool,
    num_rows: int,
    num_columns: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # iterate over column blocks, quantizing each block and propagating its
    # accumulated error to the columns that have not been processed yet
    losses = torch.zeros(num_rows, device=W.device, dtype=W.dtype)

    for i1 in range(0, num_columns, blocksize):
        i2 = min(i1 + blocksize, num_columns)

        W1 = W[:, i1:i2].clone()
        Hinv1 = Hinv[i1:i2, i1:i2].contiguous()
        scale_slice = scale_map[:, i1:i2]
        zero_slice = zero_map[:, i1:i2]
        mask_slice = None
        if W_nz_mask is not None:
            mask_slice = W_nz_mask[:, i1:i2]

        Q1, Err1, losses1 = _process_block(
            W1, Hinv1, scale_slice, zero_slice, mask_slice, quant_min, quant_max, sym
        )

        W[:, i1:i2] = Q1
        losses += losses1.sum(dim=1) / 2

        # propagate the block error to all remaining columns
        w_err = Err1 @ Hinv[i1:i2, i2:]
        if W_nz_mask is not None:
            mask_rest = W_nz_mask[:, i2:]
            W[:, i2:] -= w_err * mask_rest
        else:
            W[:, i2:] -= w_err

    return W, losses


def quantize_weight_optimized(
    module: torch.nn.Module,
    quant_args: QuantizationArgs,
    hessians_dict: Dict[torch.nn.Module, torch.Tensor],
    blocksize: int = 128,
    percdamp: float = 0.01,
) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]:
    """
    Quantize a module weight according to the GPTQ algorithm.
    This version is faster than the original one thanks to torch.compile support.

    :param module: module with weight being quantized
    :param quant_args: quantization arguments used to find quantization parameters
    :param hessians_dict: dictionary containing the pre-accumulated hessian for quantization
    :param blocksize: chunk size of quantization updates
    :param percdamp: dampening factor on hessian diagonal
    :return: loss, quantized_weight, scale, zero_point, g_idx
    """
    strategy = quant_args.strategy
    actorder = quant_args.actorder
    n_bits = quant_args.num_bits
    sym = quant_args.symmetric
    final_shape = module.weight.shape
    final_dtype = module.weight.dtype
    W = module.weight.clone()
    H = hessians_dict[module]  # unfortunately python does not have a `move` keyword
    del hessians_dict[module]  # so we have to delete the original reference manually

    # create observer for calculating quantization parameters
    observer = Observer.load_from_registry(
        quant_args.observer,
        quantization_args=quant_args,
        averaging_constant=1.0,  # ignore moving average
    )

    # standardize shape and dtype
    if isinstance(module, torch.nn.Conv2d):
        W = W.flatten(1)
    elif isinstance(module, transformers.Conv1D):
        W.transpose_(0, 1)
    W = W.to(dtype=GPTQ_PRECISION)
    num_rows = W.shape[0]
    num_columns = W.shape[1]

    if strategy == QuantizationStrategy.GROUP:
        # mapping from column index to group index
        g_idx = (
            torch.arange(num_columns, device=W.device, dtype=torch.int)
            // quant_args.group_size
        )

        if actorder == ActivationOrdering.GROUP:
            # permute by activation order first, then update groups
            W, H, perm = _apply_activation_ordering(W, H)
            scale, zero_point = observer(W, g_idx=None)

            # use identity g_idx (invert permutation later)

        elif actorder == ActivationOrdering.WEIGHT:
            # update groups first, then permute by activation order
            scale, zero_point = observer(W, g_idx=None)
            W, H, perm = _apply_activation_ordering(W, H)

            # permute g_idx to maintain identity mapping after unpermutation
            g_idx = g_idx[perm]

        else:
            scale, zero_point = observer(W, g_idx=None)
    else:
        scale, zero_point = observer(W, g_idx=None)

    scale = scale.to(W.device)
    zero_point = zero_point.to(W.device)

    # sparsity mask
    sparsity = tensor_sparsity(W)
    preserve_zeros = sparsity >= SPARSITY_THRESHOLD
    W_nz_mask = (
        (~torch.isclose(W, torch.zeros(1, device=W.device, dtype=W.dtype))).float()
        if preserve_zeros
        else None
    )

    # mask dead hessian values
    dead = torch.diag(H) == 0
    H[dead, dead] = 1
    W[:, dead] = 0

    # compute inverse hessian in place to save memory
    try:
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(H.shape[0], device=H.device)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H
    except torch._C._LinAlgError:
        logger.warning(
            "Failed to invert hessian due to numerical instability. Consider "
            "increasing GPTQModifier.dampening_frac, increasing the number "
            "of calibration samples, or shuffling the calibration dataset. "
            "Falling back to round-to-nearest for this module."
        )
        Hinv = H = torch.eye(num_columns, dtype=H.dtype, device=H.device)

    # See section 3.4 of https://arxiv.org/abs/2203.07259
    # expand scale / zero_point into per-column maps used by the quantization core
    if strategy == QuantizationStrategy.TENSOR:
        scale_map = scale.expand(num_rows, num_columns)
        zero_map = zero_point.expand(num_rows, num_columns)
    elif strategy == QuantizationStrategy.CHANNEL:
        scale_map = scale.expand(-1, num_columns)
        zero_map = zero_point.expand(-1, num_columns)
    elif strategy == QuantizationStrategy.GROUP:
        # get the group index for the current column
        scale_map = scale[:, g_idx]
        zero_map = zero_point[:, g_idx]
    else:
        raise ValueError(f"Quantization strategy is not supported for GPTQ: {strategy}")

    if sym:
        quant_min = -(2 ** (n_bits - 1))
        quant_max = 2 ** (n_bits - 1) - 1
    else:
        quant_min = 0
        quant_max = 2**n_bits - 1

    W, losses = _quantize_core(
        W=W,
        Hinv=Hinv,
        scale_map=scale_map,
        zero_map=zero_map,
        W_nz_mask=W_nz_mask,
        blocksize=blocksize,
        quant_min=quant_min,
        quant_max=quant_max,
        sym=sym,
        num_rows=num_rows,
        num_columns=num_columns,
    )

    has_gidx = False
    if strategy == QuantizationStrategy.GROUP:
        if actorder == ActivationOrdering.WEIGHT:
            # restore original permutation
            invperm = torch.argsort(perm)
            W = W[:, invperm]

        elif actorder == ActivationOrdering.GROUP:
            # restore original permutation
            invperm = torch.argsort(perm)
            W = W[:, invperm]
            g_idx = g_idx[invperm]

        # only save g_idx if mapping is not identity
        has_gidx = True

    if not has_gidx:
        g_idx = None

    if isinstance(module, transformers.Conv1D):
        W.transpose_(0, 1)
    W = W.reshape(final_shape).to(final_dtype)

    loss = torch.sum(losses).item()
    return (
        loss,
        W,
        scale.to(dtype=final_dtype),
        zero_point.to(dtype=quant_args.pytorch_dtype()),
        g_idx,
    )


def _apply_activation_ordering(
    W: torch.Tensor, H: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
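For reference, a hypothetical call site for quantize_weight_optimized, sketched only from the signature and docstring above. The Linear module, the identity Hessian, and the QuantizationArgs values are illustrative assumptions and not part of this diff; a real run would accumulate the Hessian from calibration activations (e.g. via make_empty_hessian / accumulate_hessian) before quantizing.

import torch
from compressed_tensors.quantization import QuantizationArgs

# toy layer and placeholder Hessian, purely for illustration
linear = torch.nn.Linear(512, 512)
quant_args = QuantizationArgs(num_bits=4, symmetric=True, strategy="group", group_size=128)
hessians_dict = {linear: torch.eye(linear.in_features, dtype=torch.float32)}

loss, q_weight, scale, zero_point, g_idx = quantize_weight_optimized(
    module=linear,
    quant_args=quant_args,
    hessians_dict=hessians_dict,
    blocksize=128,
    percdamp=0.01,
)
# q_weight is the fake-quantized weight in the original shape and dtype
linear.weight.data = q_weight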
Review comment: Setting torch._dynamo.config.capture_scalar_outputs = True at the module level applies this configuration globally to any code that imports this module. While this might be necessary for torch.compile to function correctly with the quantize_weight function, it's a broad setting that could potentially affect other parts of the codebase in unexpected ways. Consider adding a brief comment explaining why this setting is needed specifically for this module/function and acknowledging its global scope.
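One way to address this, sketched here as an assumption rather than part of the PR, is to scope the settings to the GPTQ call instead of mutating them at import time. This relies on the torch._dynamo.config.patch and torch._inductor.config.patch context managers provided by PyTorch 2.x config modules:

import torch
import torch._dynamo.config
import torch._inductor.config

def quantize_weight_with_scoped_config(module, quant_args, hessians_dict):
    # hypothetical wrapper: apply the compile-related settings only for the
    # duration of the call, then restore the previous matmul precision
    prev_precision = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision("high")
    try:
        with torch._dynamo.config.patch(capture_scalar_outputs=True), \
                torch._inductor.config.patch({"triton.tile_reductions": True}):
            # _process_block is compiled lazily on first use, i.e. inside this
            # call, so compilation sees the patched configuration
            return quantize_weight_optimized(module, quant_args, hessians_dict)
    finally:
        torch.set_float32_matmul_precision(prev_precision)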