
Commit 8fb026e

Upload torch.compiled GPTQ as an opt version
Signed-off-by: aladerran <aladerran@gmail.com>
1 parent a4f9ba2 commit 8fb026e

File tree

1 file changed: +252 -1 lines changed


src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py

Lines changed: 252 additions & 1 deletion
@@ -4,6 +4,7 @@

 import torch
 import torch._dynamo.config
+import torch._inductor.config
 import transformers
 from compressed_tensors.quantization import (
     ActivationOrdering,
@@ -18,6 +19,8 @@
 from llmcompressor.pytorch.utils.helpers import tensor_sparsity

 torch._dynamo.config.capture_scalar_outputs = True
+torch._inductor.config.triton.tile_reductions = True
+torch.set_float32_matmul_precision("high")

 GPTQ_PRECISION = torch.float32

@@ -71,7 +74,6 @@ def accumulate_hessian(
     return H, num_samples


-@torch.compile
 def quantize_weight(
     module: torch.nn.Module,
     quant_args: QuantizationArgs,
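
A note on the two global settings added in the second hunk above: torch.set_float32_matmul_precision("high") lets fp32 matmuls use faster TF32/reduced-precision kernels on supported GPUs, and torch._inductor.config.triton.tile_reductions is an inductor/Triton tuning knob for reduction kernels. The snippet below is a minimal, illustrative sketch of the matmul-precision trade-off only; it is not part of the patch, and it assumes a CUDA device with arbitrary tensor sizes.

import torch

a = torch.randn(2048, 2048, device="cuda")
b = torch.randn(2048, 2048, device="cuda")

torch.set_float32_matmul_precision("highest")  # strict fp32 matmul accumulation
exact = a @ b

torch.set_float32_matmul_precision("high")  # allow TF32/reduced-precision matmuls, as the patch does
fast = a @ b

# the results differ slightly; the patch accepts this trade-off for throughput
print((exact - fast).abs().max().item())
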
@@ -283,6 +285,255 @@ def quantize_weight(
     )


+@torch.compile(dynamic=True)
+def _quantize_core(
+    W: torch.Tensor,
+    Hinv: torch.Tensor,
+    scale_map: torch.Tensor,
+    zero_map: torch.Tensor,
+    W_nz_mask: Optional[torch.Tensor],
+    blocksize: int,
+    quant_min: int,
+    quant_max: int,
+    sym: bool,
+    num_rows: int,
+    num_columns: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    losses = torch.zeros(num_rows, device=W.device, dtype=W.dtype)
+
+    for i1 in range(0, num_columns, blocksize):
+        i2 = min(i1 + blocksize, num_columns)
+        count = i2 - i1
+
+        W1 = W[:, i1:i2].clone().contiguous()
+        Q1 = torch.zeros_like(W1)
+        Err1 = torch.zeros_like(W1)
+        losses1 = torch.zeros_like(W1)
+        Hinv1 = Hinv[i1:i2, i1:i2].contiguous()
+
+        for i in range(count):
+            col_idx = i1 + i
+            w = W1[:, i]
+            d = Hinv1[i, i]
+
+            s = scale_map[:, col_idx]
+            z = zero_map[:, col_idx]
+
+            if sym:
+                z = torch.zeros_like(z)
+
+            scaled = w / s
+            if not sym:
+                scaled -= z
+            q = torch.clamp(torch.round(scaled), quant_min, quant_max)
+            dq = q * s
+            if not sym:
+                dq += z * s
+
+            # propagate column error
+            Q1[:, i] = dq
+            losses1[:, i] = (w - dq) ** 2 / d**2
+
+            err1 = (w - dq) / d
+            Err1[:, i] = err1
+
+            w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
+            if W_nz_mask is not None:
+                mask_slice = W_nz_mask[:, i1 + i : i2]
+                W1[:, i:] -= w1_err * mask_slice
+            else:
+                W1[:, i:] -= w1_err
+
+        # propagate block error
+        W[:, i1:i2] = Q1
+        losses += torch.sum(losses1.contiguous(), dim=1) / 2
+
+        w_err = Err1.matmul(Hinv[i1:i2, i2:])
+        if W_nz_mask is not None:
+            mask_slice = W_nz_mask[:, i2:]
+            W[:, i2:] -= w_err * mask_slice
+        else:
+            W[:, i2:] -= w_err
+
+    return W, losses
+
+
+def quantize_weight_optimized(
+    module: torch.nn.Module,
+    quant_args: QuantizationArgs,
+    hessians_dict: Dict[torch.nn.Module, torch.Tensor],
+    blocksize: int = 128,
+    percdamp: float = 0.01,
+) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]:
+    """
+    Quantize a module weight according to the GPTQ algorithm
+
+    This version is faster than the original thanks to torch.compile support
+
+    :param module: module with weight being quantized
+    :param quant_args: quantization arguments used to find quantization parameters
+    :param hessians_dict: dictionary containing preaccumulated hessian for quantization
+    :param blocksize: chunk size of quantization updates
+    :param percdamp: dampening factor on hessian diagonal
+    :return: loss, quantized_weight, scale, zero_point, g_idx
+    """
+    strategy = quant_args.strategy
+    actorder = quant_args.actorder
+    n_bits = quant_args.num_bits
+    sym = quant_args.symmetric
+    final_shape = module.weight.shape
+    final_dtype = module.weight.dtype
+    W = module.weight.clone()
+    H = hessians_dict[module]  # unfortunately python does not have a `move` keyword
+    del hessians_dict[module]  # so we have to delete the original reference manually
+
+    # create observer for calculating quantization parameters
+    observer = Observer.load_from_registry(
+        quant_args.observer,
+        quantization_args=quant_args,
+        averaging_constant=1.0,  # ignore moving average
+    )
+
+    # standardize shape and dtype
+    if isinstance(module, torch.nn.Conv2d):
+        W = W.flatten(1)
+    elif isinstance(module, transformers.Conv1D):
+        W.transpose_(0, 1)
+    W = W.to(dtype=GPTQ_PRECISION)
+    num_rows = W.shape[0]
+    num_columns = W.shape[1]
+
+    if strategy == QuantizationStrategy.GROUP:
+        # mapping from column index to group index
+        g_idx = (
+            torch.arange(num_columns, device=W.device, dtype=torch.int)
+            // quant_args.group_size
+        )
+
+        if actorder == ActivationOrdering.GROUP:
+            # permute by activation order first, then update groups
+            W, H, perm = _apply_activation_ordering(W, H)
+            scale, zero_point = observer(W, g_idx=None)
+
+            # use identity g_idx (invert permutation later)
+
+        elif actorder == ActivationOrdering.WEIGHT:
+            # update groups first, then permute by activation order
+            scale, zero_point = observer(W, g_idx=None)
+            W, H, perm = _apply_activation_ordering(W, H)
+
+            # permute g_idx to maintain identity mapping after unpermutation
+            g_idx = g_idx[perm]

+        else:
+            scale, zero_point = observer(W, g_idx=None)
+    else:
+        scale, zero_point = observer(W, g_idx=None)
+
+    scale = scale.to(W.device)
+    zero_point = zero_point.to(W.device)
+
+    # sparsity mask
+    sparsity = tensor_sparsity(W)
+    preserve_zeros = sparsity >= SPARSITY_THRESHOLD
+    W_nz_mask = (
+        (~torch.isclose(W, torch.zeros(1, device=W.device, dtype=W.dtype))).float()
+        if preserve_zeros
+        else None
+    )
+
+    # mask dead hessian values
+    dead = torch.diag(H) == 0
+    H[dead, dead] = 1
+    W[:, dead] = 0
+
+    # compute inverse hessian in place to save memory
+    try:
+        damp = percdamp * torch.mean(torch.diag(H))
+        diag = torch.arange(H.shape[0], device=H.device)
+        H[diag, diag] += damp
+        H = torch.linalg.cholesky(H)
+        H = torch.cholesky_inverse(H)
+        H = torch.linalg.cholesky(H, upper=True)
+        Hinv = H
+    except torch._C._LinAlgError:
+        logger.warning(
+            "Failed to invert hessian due to numerical instability. Consider "
+            "increasing GPTQModifier.dampening_frac, increasing the number "
+            "of calibration samples, or shuffling the calibration dataset. "
+            "Falling back to round-to-nearest for this module."
+        )
+        Hinv = H = torch.eye(num_columns, dtype=H.dtype, device=H.device)
+
+    # See section 3.4 of https://arxiv.org/abs/2203.07259
+    # quantize column
+    if strategy == QuantizationStrategy.TENSOR:
+        scale_map = scale.expand(num_rows, num_columns)
+        zero_map = zero_point.expand(num_rows, num_columns)
+    elif strategy == QuantizationStrategy.CHANNEL:
+        scale_map = scale.expand(-1, num_columns)
+        zero_map = zero_point.expand(-1, num_columns)
+    elif strategy == QuantizationStrategy.GROUP:
+        # get the group index for the current column
+        scale_map = scale[:, g_idx]
+        zero_map = zero_point[:, g_idx]
+    else:
+        raise ValueError(f"Quantization strategy is not supported for GPTQ: {strategy}")
+
+    if sym:
+        quant_min = -(2 ** (n_bits - 1))
+        quant_max = 2 ** (n_bits - 1) - 1
+    else:
+        quant_min = 0
+        quant_max = 2**n_bits - 1
+
+    W, losses = _quantize_core(
+        W=W,
+        Hinv=Hinv,
+        scale_map=scale_map,
+        zero_map=zero_map,
+        W_nz_mask=W_nz_mask,
+        blocksize=blocksize,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        sym=sym,
+        num_rows=num_rows,
+        num_columns=num_columns,
+    )
+
+    has_gidx = False
+    if strategy == QuantizationStrategy.GROUP:
+        if actorder == ActivationOrdering.WEIGHT:
+            # restore original permutation
+            invperm = torch.argsort(perm)
+            W = W[:, invperm]
+
+        elif actorder == ActivationOrdering.GROUP:
+            # restore original permutation
+            invperm = torch.argsort(perm)
+            W = W[:, invperm]
+            g_idx = g_idx[invperm]
+
+            # only save g_idx if mapping is not identity
+            has_gidx = True
+
+    if not has_gidx:
+        g_idx = None
+
+    if isinstance(module, transformers.Conv1D):
+        W.transpose_(0, 1)
+    W = W.reshape(final_shape).to(final_dtype)
+
+    loss = torch.sum(losses).item()
+    return (
+        loss,
+        W,
+        scale.to(dtype=final_dtype),
+        zero_point.to(dtype=quant_args.pytorch_dtype()),
+        g_idx,
+    )
+
+
 def _apply_activation_ordering(
     W: torch.Tensor, H: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
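
For orientation, here is a minimal sketch of how the new quantize_weight_optimized entry point might be exercised directly. It is not part of the commit: the layer, calibration tensor, and QuantizationArgs fields below are illustrative assumptions (4-bit symmetric group-wise quantization with the "minmax" observer), and the Hessian is built inline as 2/n * XᵀX instead of going through the modifier's accumulation path.

import torch
from compressed_tensors.quantization import QuantizationArgs

from llmcompressor.modifiers.quantization.gptq.gptq_quantize import (
    quantize_weight_optimized,
)

# hypothetical layer and calibration activations (a real run uses a calibration dataset)
layer = torch.nn.Linear(512, 256)
calib = torch.randn(1024, 512)

# stand-in for the pre-accumulated GPTQ Hessian over the layer inputs: H = 2/n * X^T X
H = (2.0 / calib.shape[0]) * (calib.T @ calib)

# assumed settings: 4-bit, symmetric, group-wise (group_size=128), min-max observer
quant_args = QuantizationArgs(
    num_bits=4,
    symmetric=True,
    strategy="group",
    group_size=128,
    observer="minmax",
)

loss, W_q, scale, zero_point, g_idx = quantize_weight_optimized(
    module=layer,
    quant_args=quant_args,
    hessians_dict={layer: H},
    blocksize=128,
    percdamp=0.01,
)
print(f"GPTQ loss: {loss:.4f}, quantized weight shape: {tuple(W_q.shape)}")
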

0 commit comments