 import torch
 import torch._dynamo.config
+import torch._inductor.config
 import transformers
 from compressed_tensors.quantization import (
     ActivationOrdering,
...
 from llmcompressor.pytorch.utils.helpers import tensor_sparsity

 torch._dynamo.config.capture_scalar_outputs = True
+torch._inductor.config.triton.tile_reductions = True
+torch.set_float32_matmul_precision("high")

 GPTQ_PRECISION = torch.float32

@@ -71,7 +74,6 @@ def accumulate_hessian(
     return H, num_samples


-@torch.compile
 def quantize_weight(
     module: torch.nn.Module,
     quant_args: QuantizationArgs,
@@ -283,6 +285,255 @@ def quantize_weight(
     )


+@torch.compile(dynamic=True)
+def _quantize_core(
+    W: torch.Tensor,
+    Hinv: torch.Tensor,
+    scale_map: torch.Tensor,
+    zero_map: torch.Tensor,
+    W_nz_mask: Optional[torch.Tensor],
+    blocksize: int,
+    quant_min: int,
+    quant_max: int,
+    sym: bool,
+    num_rows: int,
+    num_columns: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    losses = torch.zeros(num_rows, device=W.device, dtype=W.dtype)
+
+    for i1 in range(0, num_columns, blocksize):
+        i2 = min(i1 + blocksize, num_columns)
+        count = i2 - i1
+
+        W1 = W[:, i1:i2].clone().contiguous()
+        Q1 = torch.zeros_like(W1)
+        Err1 = torch.zeros_like(W1)
+        losses1 = torch.zeros_like(W1)
+        Hinv1 = Hinv[i1:i2, i1:i2].contiguous()
+
+        for i in range(count):
+            col_idx = i1 + i
+            w = W1[:, i]
+            d = Hinv1[i, i]
+
+            s = scale_map[:, col_idx]
+            z = zero_map[:, col_idx]
+
+            if sym:
+                z = torch.zeros_like(z)
+
+            scaled = w / s
+            if not sym:
+                scaled -= z
+            q = torch.clamp(torch.round(scaled), quant_min, quant_max)
+            dq = q * s
+            if not sym:
+                dq += z * s
+
+            # propagate column error
+            Q1[:, i] = dq
+            losses1[:, i] = (w - dq) ** 2 / d**2
+
+            err1 = (w - dq) / d
+            Err1[:, i] = err1
+
+            w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
+            if W_nz_mask is not None:
+                mask_slice = W_nz_mask[:, i1 + i : i2]
+                W1[:, i:] -= w1_err * mask_slice
+            else:
+                W1[:, i:] -= w1_err
+
+        # propagate block error
+        W[:, i1:i2] = Q1
+        losses += torch.sum(losses1.contiguous(), dim=1) / 2
+
+        w_err = Err1.matmul(Hinv[i1:i2, i2:])
+        if W_nz_mask is not None:
+            mask_slice = W_nz_mask[:, i2:]
+            W[:, i2:] -= w_err * mask_slice
+        else:
+            W[:, i2:] -= w_err
+
+    return W, losses
+
+
+def quantize_weight_optimized(
+    module: torch.nn.Module,
+    quant_args: QuantizationArgs,
+    hessians_dict: Dict[torch.nn.Module, torch.Tensor],
+    blocksize: int = 128,
+    percdamp: float = 0.01,
+) -> Tuple[float, torch.Tensor, torch.Tensor, Union[torch.Tensor, None], torch.Tensor]:
+    """
+    Quantize a module weight according to the GPTQ algorithm.
+
+    This version is faster than the original because the inner quantization
+    loop is compiled with torch.compile.
+
+    :param module: module with weight being quantized
+    :param quant_args: quantization arguments used to find quantization parameters
+    :param hessians_dict: dictionary containing pre-accumulated hessians for quantization
+    :param blocksize: chunk size of quantization updates
+    :param percdamp: dampening factor on hessian diagonal
+    :return: loss, quantized_weight, scale, zero_point, g_idx
+    """
+    strategy = quant_args.strategy
+    actorder = quant_args.actorder
+    n_bits = quant_args.num_bits
+    sym = quant_args.symmetric
+    final_shape = module.weight.shape
+    final_dtype = module.weight.dtype
+    W = module.weight.clone()
+    H = hessians_dict[module]  # unfortunately python does not have a `move` keyword
+    del hessians_dict[module]  # so we have to delete the original reference manually
+
+    # create observer for calculating quantization parameters
+    observer = Observer.load_from_registry(
+        quant_args.observer,
+        quantization_args=quant_args,
+        averaging_constant=1.0,  # ignore moving average
+    )
+
+    # standardize shape and dtype
+    if isinstance(module, torch.nn.Conv2d):
+        W = W.flatten(1)
+    elif isinstance(module, transformers.Conv1D):
+        W.transpose_(0, 1)
+    W = W.to(dtype=GPTQ_PRECISION)
+    num_rows = W.shape[0]
+    num_columns = W.shape[1]
+
+    if strategy == QuantizationStrategy.GROUP:
+        # mapping from column index to group index
+        g_idx = (
+            torch.arange(num_columns, device=W.device, dtype=torch.int)
+            // quant_args.group_size
+        )
+
+        if actorder == ActivationOrdering.GROUP:
+            # permute by activation order first, then update groups
+            W, H, perm = _apply_activation_ordering(W, H)
+            scale, zero_point = observer(W, g_idx=None)
+
+            # use identity g_idx (invert permutation later)
+
+        elif actorder == ActivationOrdering.WEIGHT:
+            # update groups first, then permute by activation order
+            scale, zero_point = observer(W, g_idx=None)
+            W, H, perm = _apply_activation_ordering(W, H)
+
+            # permute g_idx to maintain identity mapping after unpermutation
+            g_idx = g_idx[perm]
+
+        else:
+            scale, zero_point = observer(W, g_idx=None)
+    else:
+        scale, zero_point = observer(W, g_idx=None)
+
+    scale = scale.to(W.device)
+    zero_point = zero_point.to(W.device)
+
+    # sparsity mask
+    sparsity = tensor_sparsity(W)
+    preserve_zeros = sparsity >= SPARSITY_THRESHOLD
+    W_nz_mask = (
+        (~torch.isclose(W, torch.zeros(1, device=W.device, dtype=W.dtype))).float()
+        if preserve_zeros
+        else None
+    )
+
+    # mask dead hessian values
+    dead = torch.diag(H) == 0
+    H[dead, dead] = 1
+    W[:, dead] = 0
+
+    # compute inverse hessian in place to save memory
+    try:
+        damp = percdamp * torch.mean(torch.diag(H))
+        diag = torch.arange(H.shape[0], device=H.device)
+        H[diag, diag] += damp
+        H = torch.linalg.cholesky(H)
+        H = torch.cholesky_inverse(H)
+        H = torch.linalg.cholesky(H, upper=True)
+        Hinv = H
+    except torch._C._LinAlgError:
+        logger.warning(
+            "Failed to invert hessian due to numerical instability. Consider "
+            "increasing GPTQModifier.dampening_frac, increasing the number "
+            "of calibration samples, or shuffling the calibration dataset. "
+            "Falling back to round-to-nearest for this module."
+        )
+        Hinv = H = torch.eye(num_columns, dtype=H.dtype, device=H.device)
+
+    # See section 3.4 of https://arxiv.org/abs/2203.07259
+    # quantize column
+    if strategy == QuantizationStrategy.TENSOR:
+        scale_map = scale.expand(num_rows, num_columns)
+        zero_map = zero_point.expand(num_rows, num_columns)
+    elif strategy == QuantizationStrategy.CHANNEL:
+        scale_map = scale.expand(-1, num_columns)
+        zero_map = zero_point.expand(-1, num_columns)
+    elif strategy == QuantizationStrategy.GROUP:
+        # get the group index for the current column
+        scale_map = scale[:, g_idx]
+        zero_map = zero_point[:, g_idx]
+    else:
+        raise ValueError(f"Quantization strategy is not supported for GPTQ: {strategy}")
+
+    if sym:
+        quant_min = -(2 ** (n_bits - 1))
+        quant_max = 2 ** (n_bits - 1) - 1
+    else:
+        quant_min = 0
+        quant_max = 2**n_bits - 1
+
+    W, losses = _quantize_core(
+        W=W,
+        Hinv=Hinv,
+        scale_map=scale_map,
+        zero_map=zero_map,
+        W_nz_mask=W_nz_mask,
+        blocksize=blocksize,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        sym=sym,
+        num_rows=num_rows,
+        num_columns=num_columns,
+    )
+
+    has_gidx = False
+    if strategy == QuantizationStrategy.GROUP:
+        if actorder == ActivationOrdering.WEIGHT:
+            # restore original permutation
+            invperm = torch.argsort(perm)
+            W = W[:, invperm]
+
+        elif actorder == ActivationOrdering.GROUP:
+            # restore original permutation
+            invperm = torch.argsort(perm)
+            W = W[:, invperm]
+            g_idx = g_idx[invperm]
+
+            # only save g_idx if mapping is not identity
+            has_gidx = True
+
+    if not has_gidx:
+        g_idx = None
+
+    if isinstance(module, transformers.Conv1D):
+        W.transpose_(0, 1)
+    W = W.reshape(final_shape).to(final_dtype)
+
+    loss = torch.sum(losses).item()
+    return (
+        loss,
+        W,
+        scale.to(dtype=final_dtype),
+        zero_point.to(dtype=quant_args.pytorch_dtype()),
+        g_idx,
+    )
+
+
 def _apply_activation_ordering(
     W: torch.Tensor, H: torch.Tensor
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
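
For context, a minimal sketch of how the new entry point might be driven end to end. This is not part of the diff: the Linear layer, the hand-built Hessian, and the observer="minmax" argument are illustrative stand-ins for a real calibration pass that accumulates the Hessian via accumulate_hessian.

import torch
from compressed_tensors.quantization import QuantizationArgs

# assumes quantize_weight_optimized is importable from the module changed above
layer = torch.nn.Linear(4096, 4096)
quant_args = QuantizationArgs(
    num_bits=4,
    symmetric=True,
    strategy="group",
    group_size=128,
    observer="minmax",  # assumed registered observer name
)

# stand-in for a calibrated Hessian: H ~ 2/N * sum(x x^T) over calibration inputs
X = torch.randn(512, 4096)
H = 2.0 / X.shape[0] * (X.T @ X)

loss, W_q, scale, zero_point, g_idx = quantize_weight_optimized(
    module=layer,
    quant_args=quant_args,
    hessians_dict={layer: H},
    blocksize=128,
    percdamp=0.01,
)
print(f"loss={loss:.4f} scale shape={tuple(scale.shape)} g_idx is None: {g_idx is None}")

Because the compiled region is now the shape-agnostic _quantize_core (decorated with torch.compile(dynamic=True)) rather than quantize_weight itself, repeated calls across layers of different sizes should reuse the compiled graph instead of triggering a recompile per module.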