18
18
from llmcompressor .observers .base import Observer
19
19
from llmcompressor .pytorch .utils .helpers import tensor_sparsity
20
20
21
- torch ._dynamo .config .capture_scalar_outputs = True
22
- torch ._inductor .config .triton .tile_reductions = True
23
- torch .set_float32_matmul_precision ("high" )
24
-
25
21
GPTQ_PRECISION = torch .float32
26
22
27
23
__all__ = ["make_empty_hessian" , "accumulate_hessian" , "quantize_weight" ]
@@ -74,7 +70,7 @@ def accumulate_hessian(
74
70
return H , num_samples
75
71
76
72
77
- def quantize_weight (
73
+ def quantize_weight_original (
78
74
module : torch .nn .Module ,
79
75
quant_args : QuantizationArgs ,
80
76
hessians_dict : Dict [torch .nn .Module , torch .Tensor ],
@@ -296,6 +292,7 @@ def _process_block(
296
292
quant_max : int ,
297
293
sym : bool ,
298
294
) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
295
+ """Process a single block of weight columns using with torch.compile support."""
299
296
count = W1 .shape [1 ]
300
297
Q1 = torch .zeros_like (W1 )
301
298
Err1 = torch .zeros_like (W1 )
@@ -349,11 +346,13 @@ def _quantize_core(
349
346
num_rows : int ,
350
347
num_columns : int ,
351
348
) -> Tuple [torch .Tensor , torch .Tensor ]:
349
+ """Core GPTQ quantization loop processing weights in blocks."""
352
350
losses = torch .zeros (num_rows , device = W .device , dtype = W .dtype )
353
351
354
352
for i1 in range (0 , num_columns , blocksize ):
355
353
i2 = min (i1 + blocksize , num_columns )
356
354
355
+ # Extract current block and corresponding Hessian/quantization params
357
356
W1 = W [:, i1 :i2 ].clone ()
358
357
Hinv1 = Hinv [i1 :i2 , i1 :i2 ].contiguous ()
359
358
scale_slice = scale_map [:, i1 :i2 ]
@@ -362,13 +361,15 @@ def _quantize_core(
362
361
if W_nz_mask is not None :
363
362
mask_slice = W_nz_mask [:, i1 :i2 ]
364
363
364
+ # Quantize the current block
365
365
Q1 , Err1 , losses1 = _process_block (
366
366
W1 , Hinv1 , scale_slice , zero_slice , mask_slice , quant_min , quant_max , sym
367
367
)
368
368
369
369
W [:, i1 :i2 ] = Q1
370
370
losses += losses1 .sum (dim = 1 ) / 2
371
371
372
+ # Propagate block error to remaining unprocessed columns
372
373
w_err = Err1 @ Hinv [i1 :i2 , i2 :]
373
374
if W_nz_mask is not None :
374
375
mask_rest = W_nz_mask [:, i2 :]
@@ -379,7 +380,7 @@ def _quantize_core(
379
380
return W , losses
380
381
381
382
382
- def quantize_weight_optimized (
383
+ def quantize_weight (
383
384
module : torch .nn .Module ,
384
385
quant_args : QuantizationArgs ,
385
386
hessians_dict : Dict [torch .nn .Module , torch .Tensor ],
0 commit comments