Skip to content

Commit 2c688cb

Browse files
committed
Revise & add comments
1 parent 6a5e420 commit 2c688cb

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@
1818
from llmcompressor.observers.base import Observer
1919
from llmcompressor.pytorch.utils.helpers import tensor_sparsity
2020

21-
torch._dynamo.config.capture_scalar_outputs = True
22-
torch._inductor.config.triton.tile_reductions = True
23-
torch.set_float32_matmul_precision("high")
24-
2521
GPTQ_PRECISION = torch.float32
2622

2723
__all__ = ["make_empty_hessian", "accumulate_hessian", "quantize_weight"]
@@ -74,7 +70,7 @@ def accumulate_hessian(
7470
return H, num_samples
7571

7672

77-
def quantize_weight(
73+
def quantize_weight_original(
7874
module: torch.nn.Module,
7975
quant_args: QuantizationArgs,
8076
hessians_dict: Dict[torch.nn.Module, torch.Tensor],
@@ -296,6 +292,7 @@ def _process_block(
296292
quant_max: int,
297293
sym: bool,
298294
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
295+
"""Process a single block of weight columns using with torch.compile support."""
299296
count = W1.shape[1]
300297
Q1 = torch.zeros_like(W1)
301298
Err1 = torch.zeros_like(W1)
@@ -349,11 +346,13 @@ def _quantize_core(
349346
num_rows: int,
350347
num_columns: int,
351348
) -> Tuple[torch.Tensor, torch.Tensor]:
349+
"""Core GPTQ quantization loop processing weights in blocks."""
352350
losses = torch.zeros(num_rows, device=W.device, dtype=W.dtype)
353351

354352
for i1 in range(0, num_columns, blocksize):
355353
i2 = min(i1 + blocksize, num_columns)
356354

355+
# Extract current block and corresponding Hessian/quantization params
357356
W1 = W[:, i1:i2].clone()
358357
Hinv1 = Hinv[i1:i2, i1:i2].contiguous()
359358
scale_slice = scale_map[:, i1:i2]
@@ -362,13 +361,15 @@ def _quantize_core(
362361
if W_nz_mask is not None:
363362
mask_slice = W_nz_mask[:, i1:i2]
364363

364+
# Quantize the current block
365365
Q1, Err1, losses1 = _process_block(
366366
W1, Hinv1, scale_slice, zero_slice, mask_slice, quant_min, quant_max, sym
367367
)
368368

369369
W[:, i1:i2] = Q1
370370
losses += losses1.sum(dim=1) / 2
371371

372+
# Propagate block error to remaining unprocessed columns
372373
w_err = Err1 @ Hinv[i1:i2, i2:]
373374
if W_nz_mask is not None:
374375
mask_rest = W_nz_mask[:, i2:]
@@ -379,7 +380,7 @@ def _quantize_core(
379380
return W, losses
380381

381382

382-
def quantize_weight_optimized(
383+
def quantize_weight(
383384
module: torch.nn.Module,
384385
quant_args: QuantizationArgs,
385386
hessians_dict: Dict[torch.nn.Module, torch.Tensor],

0 commit comments

Comments
 (0)