Skip to content

Commit 3967e84

Browse files
decrease memory when calculating w_mean
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 6687eff commit 3967e84

File tree

1 file changed

+5
-3
lines changed
  • src/llmcompressor/modifiers/awq

1 file changed

+5
-3
lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -465,11 +465,13 @@ def _apply_smoothing(self, model: Module) -> None:
465465
# Calculates the relative magnitude of the weights within
466466
# each of the quantization groups, and rescales each group
467467
# individually so that each group has weights on a 0-1 scale.
468-
w_scale = weight.abs() / (weight.abs().amax(dim=1, keepdim=True) + 1e-6)
468+
weight.abs_()
469+
weight.div_(weight.amax(dim=1, keepdim=True) + 1e-6)
469470
# Resizes the rescaled weight matrix back up to its original dimensions
470-
w_scale = w_scale.view(org_shape)
471+
weight = weight.view(org_shape)
471472
# Gets the average rescaled magnitude for each output channel
472-
w_mean = w_scale.mean(0)
473+
w_mean = weight.mean(0)
474+
del weight
473475

474476
with calibration_forward_context(model), HooksMixin.disable_hooks():
475477
# [STEP 3]: Compute output of module

0 commit comments

Comments
 (0)