Commit 200ceac

Fix long compilation issue
Signed-off-by: aladerran <aladerran@gmail.com>
1 parent f4a3419 commit 200ceac
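Before this change, the @torch.compile(dynamic=True) decorator sat directly above _quantize_core, so the entire blocked quantization routine was compiled as one unit. Since torch.compile unrolls Python loops during tracing, that graph spanned every column of every block, which is the likely source of the long compile times. The fix inserts a new helper, _process_block, between the decorator and _quantize_core: the decorator now compiles only the per-block inner loop, while the outer loop over column blocks runs eagerly, slices out per-block views of the scales, zero points, and optional sparsity mask, and hands them to the compiled helper.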

File tree

1 file changed

src/llmcompressor/modifiers/quantization/gptq/gptq_quantize.py

Lines changed: 63 additions & 42 deletions
@@ -286,6 +286,56 @@ def quantize_weight(
 
 
 @torch.compile(dynamic=True)
+def _process_block(
+    W1: torch.Tensor,
+    Hinv1: torch.Tensor,
+    scale_slice: torch.Tensor,
+    zero_slice: torch.Tensor,
+    mask_slice: Optional[torch.Tensor],
+    quant_min: int,
+    quant_max: int,
+    sym: bool,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    count = W1.shape[1]
+    Q1 = torch.zeros_like(W1)
+    Err1 = torch.zeros_like(W1)
+    losses1 = torch.zeros_like(W1)
+
+    for i in range(count):
+        w = W1[:, i]
+        d = Hinv1[i, i]
+
+        s = scale_slice[:, i]
+        z = zero_slice[:, i]
+
+        if sym:
+            z = torch.zeros_like(z)
+
+        scaled = w / s
+        if not sym:
+            scaled -= z
+        q = torch.clamp(torch.round(scaled), quant_min, quant_max)
+        dq = q * s
+        if not sym:
+            dq += z * s
+
+        err1 = (w - dq) / d
+        loss_col = (w - dq) ** 2 / d**2
+
+        Q1[:, i] = dq
+        Err1[:, i] = err1
+        losses1[:, i] = loss_col
+
+        w1_err = err1.unsqueeze(1) @ Hinv1[i, i:].unsqueeze(0)
+        if mask_slice is not None:
+            mask_block = mask_slice[:, i:]
+            W1[:, i:] -= w1_err * mask_block
+        else:
+            W1[:, i:] -= w1_err
+
+    return Q1, Err1, losses1
+
+
 def _quantize_core(
     W: torch.Tensor,
     Hinv: torch.Tensor,
@@ -303,55 +353,26 @@ def _quantize_core(
 
     for i1 in range(0, num_columns, blocksize):
         i2 = min(i1 + blocksize, num_columns)
-        count = i2 - i1
 
-        W1 = W[:, i1:i2].clone().contiguous()
-        Q1 = torch.zeros_like(W1)
-        Err1 = torch.zeros_like(W1)
-        losses1 = torch.zeros_like(W1)
+        W1 = W[:, i1:i2].clone()
         Hinv1 = Hinv[i1:i2, i1:i2].contiguous()
+        scale_slice = scale_map[:, i1:i2]
+        zero_slice = zero_map[:, i1:i2]
+        mask_slice = None
+        if W_nz_mask is not None:
+            mask_slice = W_nz_mask[:, i1:i2]
 
-        for i in range(count):
-            col_idx = i1 + i
-            w = W1[:, i]
-            d = Hinv1[i, i]
-
-            s = scale_map[:, col_idx]
-            z = zero_map[:, col_idx]
-
-            if sym:
-                z = torch.zeros_like(z)
-
-            scaled = w / s
-            if not sym:
-                scaled -= z
-            q = torch.clamp(torch.round(scaled), quant_min, quant_max)
-            dq = q * s
-            if not sym:
-                dq += z * s
-
-            # propagate column error
-            Q1[:, i] = dq
-            losses1[:, i] = (w - dq) ** 2 / d**2
-
-            err1 = (w - dq) / d
-            Err1[:, i] = err1
-
-            w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
-            if W_nz_mask is not None:
-                mask_slice = W_nz_mask[:, i1 + i : i2]
-                W1[:, i:] -= w1_err * mask_slice
-            else:
-                W1[:, i:] -= w1_err
+        Q1, Err1, losses1 = _process_block(
+            W1, Hinv1, scale_slice, zero_slice, mask_slice, quant_min, quant_max, sym
+        )
 
-        # propagate block error
         W[:, i1:i2] = Q1
-        losses += torch.sum(losses1.contiguous(), dim=1) / 2
+        losses += losses1.sum(dim=1) / 2
 
-        w_err = Err1.matmul(Hinv[i1:i2, i2:])
+        w_err = Err1 @ Hinv[i1:i2, i2:]
         if W_nz_mask is not None:
-            mask_slice = W_nz_mask[:, i2:]
-            W[:, i2:] -= w_err * mask_slice
+            mask_rest = W_nz_mask[:, i2:]
+            W[:, i2:] -= w_err * mask_rest
         else:
             W[:, i2:] -= w_err
 
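To make the new compile boundary concrete, here is a minimal, self-contained sketch of the pattern the commit adopts. The names _inner_kernel and _outer are illustrative stand-ins (not from the repository), and the kernel body is a toy replacement for the real per-column quantization:

import torch


@torch.compile(dynamic=True)
def _inner_kernel(block: torch.Tensor) -> torch.Tensor:
    # Stand-in for _process_block: the Python loop is unrolled at trace
    # time, so the compiled graph grows with the block width only,
    # never with the full matrix width.
    out = torch.empty_like(block)
    for i in range(block.shape[1]):
        out[:, i] = torch.round(block[:, i])
    return out


def _outer(W: torch.Tensor, blocksize: int = 128) -> torch.Tensor:
    # Eager outer loop, as in _quantize_core after this commit: its
    # trip count depends on W's full width, so keeping it outside the
    # compiled region keeps every traced graph small.
    for i1 in range(0, W.shape[1], blocksize):
        i2 = min(i1 + blocksize, W.shape[1])
        W[:, i1:i2] = _inner_kernel(W[:, i1:i2])
    return W


if __name__ == "__main__":
    print(_outer(torch.randn(4, 300)).shape)

Drawn this way, a traced graph covers at most blocksize columns (plus one extra compilation for a trailing partial block), which bounds compile time regardless of how wide the weight matrix is.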