@@ -286,6 +286,56 @@ def quantize_weight(
286
286
287
287
288
288
@torch.compile(dynamic=True)
def _process_block(
    W1: torch.Tensor,
    Hinv1: torch.Tensor,
    scale_slice: torch.Tensor,
    zero_slice: torch.Tensor,
    mask_slice: Optional[torch.Tensor],
    quant_min: int,
    quant_max: int,
    sym: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Quantize one block of columns, feeding each column's error forward.

    Columns are processed left to right. Each column is fake-quantized
    (divide by scale, optionally shift by the zero-point, round, clamp,
    then map back), and the resulting error is subtracted from the current
    and all later columns of the block, weighted by row ``i`` of ``Hinv1``.

    NOTE: ``W1`` is modified in place; callers that need the original
    values must pass a copy.

    Args:
        W1: 2-D weight block; one column is quantized per iteration.
        Hinv1: Square matrix indexed as ``Hinv1[i, i]`` / ``Hinv1[i, i:]``
            (presumably the matching block of the inverse-Hessian factor;
            confirm against the caller).
        scale_slice: Per-element scales, indexed column-wise like ``W1``.
        zero_slice: Per-element zero-points, indexed column-wise like
            ``W1``. Only read when ``sym`` is False.
        mask_slice: Optional multiplicative mask with the same column layout
            as ``W1``; the propagated error is multiplied by it, so zero
            entries leave the corresponding weights untouched.
        quant_min: Lower clamp bound applied after rounding.
        quant_max: Upper clamp bound applied after rounding.
        sym: If True, the zero-point is ignored entirely.

    Returns:
        ``(Q1, Err1, losses1)``, each shaped like ``W1``: the dequantized
        columns, the per-column error divided by ``Hinv1[i, i]``, and the
        squared error divided by ``Hinv1[i, i] ** 2``.
    """
    count = W1.shape[1]
    Q1 = torch.zeros_like(W1)
    Err1 = torch.zeros_like(W1)
    losses1 = torch.zeros_like(W1)

    for i in range(count):
        # ``w`` is a view into W1; everything derived from it below is
        # computed before W1 is updated at the end of the iteration.
        w = W1[:, i]
        d = Hinv1[i, i]
        s = scale_slice[:, i]

        scaled = w / s
        if not sym:
            # The zero-point is only needed on the asymmetric path, so it is
            # only loaded here (no throw-away zeros tensor for sym=True).
            z = zero_slice[:, i]
            scaled -= z
        q = torch.clamp(torch.round(scaled), quant_min, quant_max)
        dq = q * s
        if not sym:
            dq += z * s

        # Quantization residual, computed once and reused for both outputs.
        residual = w - dq
        err1 = residual / d
        loss_col = residual ** 2 / d ** 2

        Q1[:, i] = dq
        Err1[:, i] = err1
        losses1[:, i] = loss_col

        # Rank-1 update: spread this column's error over columns i..end.
        w1_err = err1.unsqueeze(1) @ Hinv1[i, i:].unsqueeze(0)
        if mask_slice is not None:
            W1[:, i:] -= w1_err * mask_slice[:, i:]
        else:
            W1[:, i:] -= w1_err

    return Q1, Err1, losses1
337
+
338
+
289
339
def _quantize_core (
290
340
W : torch .Tensor ,
291
341
Hinv : torch .Tensor ,
@@ -303,55 +353,26 @@ def _quantize_core(
303
353
304
354
for i1 in range (0 , num_columns , blocksize ):
305
355
i2 = min (i1 + blocksize , num_columns )
306
- count = i2 - i1
307
356
308
- W1 = W [:, i1 :i2 ].clone ().contiguous ()
309
- Q1 = torch .zeros_like (W1 )
310
- Err1 = torch .zeros_like (W1 )
311
- losses1 = torch .zeros_like (W1 )
357
+ W1 = W [:, i1 :i2 ].clone ()
312
358
Hinv1 = Hinv [i1 :i2 , i1 :i2 ].contiguous ()
359
+ scale_slice = scale_map [:, i1 :i2 ]
360
+ zero_slice = zero_map [:, i1 :i2 ]
361
+ mask_slice = None
362
+ if W_nz_mask is not None :
363
+ mask_slice = W_nz_mask [:, i1 :i2 ]
313
364
314
- for i in range (count ):
315
- col_idx = i1 + i
316
- w = W1 [:, i ]
317
- d = Hinv1 [i , i ]
318
-
319
- s = scale_map [:, col_idx ]
320
- z = zero_map [:, col_idx ]
321
-
322
- if sym :
323
- z = torch .zeros_like (z )
324
-
325
- scaled = w / s
326
- if not sym :
327
- scaled -= z
328
- q = torch .clamp (torch .round (scaled ), quant_min , quant_max )
329
- dq = q * s
330
- if not sym :
331
- dq += z * s
332
-
333
- # propagate column error
334
- Q1 [:, i ] = dq
335
- losses1 [:, i ] = (w - dq ) ** 2 / d ** 2
336
-
337
- err1 = (w - dq ) / d
338
- Err1 [:, i ] = err1
339
-
340
- w1_err = err1 .unsqueeze (1 ).matmul (Hinv1 [i , i :].unsqueeze (0 ))
341
- if W_nz_mask is not None :
342
- mask_slice = W_nz_mask [:, i1 + i : i2 ]
343
- W1 [:, i :] -= w1_err * mask_slice
344
- else :
345
- W1 [:, i :] -= w1_err
365
+ Q1 , Err1 , losses1 = _process_block (
366
+ W1 , Hinv1 , scale_slice , zero_slice , mask_slice , quant_min , quant_max , sym
367
+ )
346
368
347
- # propagate block error
348
369
W [:, i1 :i2 ] = Q1
349
- losses += torch .sum (losses1 . contiguous (), dim = 1 ) / 2
370
+ losses += losses1 .sum (dim = 1 ) / 2
350
371
351
- w_err = Err1 . matmul ( Hinv [i1 :i2 , i2 :])
372
+ w_err = Err1 @ Hinv [i1 :i2 , i2 :]
352
373
if W_nz_mask is not None :
353
- mask_slice = W_nz_mask [:, i2 :]
354
- W [:, i2 :] -= w_err * mask_slice
374
+ mask_rest = W_nz_mask [:, i2 :]
375
+ W [:, i2 :] -= w_err * mask_rest
355
376
else :
356
377
W [:, i2 :] -= w_err
357
378
0 commit comments