
Commit 2b46766

AWQ minor performance improvements to smoothing
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 1c4f639 commit 2b46766

File tree

1 file changed (+33, −40 lines):
  • src/llmcompressor/modifiers/awq

src/llmcompressor/modifiers/awq/base.py

Lines changed: 33 additions & 40 deletions
@@ -474,8 +474,8 @@ def _apply_smoothing(self, model: Module) -> None:
             with calibration_forward_context(model), HooksMixin.disable_hooks():
                 # [STEP 3]: Compute output of module
                 # could cache from hook, rather than recomputing here
-                fp16_outputs = self._run_samples(parent_module)
-                if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs):
+                fp16_output = self._get_flattened_output(parent_module)
+                if fp16_output.numel() == 0:
                     logger.info(
                         f"Skipping smooth_layer {mapping.smooth_name}, no activations "
                         "found to scale. This can occasionally occur in MoE models "
@@ -488,7 +488,7 @@ def _apply_smoothing(self, model: Module) -> None:
 
                 # [STEP 4]: Compute loss
                 best_scales = self._compute_best_scale(
-                    x_mean, w_mean, parent_module, balance_layers, fp16_outputs
+                    x_mean, w_mean, parent_module, balance_layers, fp16_output
                 )
 
     @torch.no_grad()
@@ -541,25 +541,35 @@ def smooth(module):
             v.batch_intermediates.clear()
         self._assert_all_activations_consumed()
 
-    def _run_samples(self, module: Module) -> List[torch.Tensor]:
+    def _get_flattened_output(self, module: Module) -> torch.Tensor:
+        """
+        Returns output of running cached batch inputs through module
+        Output tensor is 1D, as shapes aren't necessary for calculating loss
+        """
         with align_module_device(module):
             outputs = [
                 module(**batch_kwargs)
                 for batch_kwargs in self._parent_args_cache[module]
             ]
-            return [
-                # If Tuple, assume that first argument is the input
-                output[0] if isinstance(output, Tuple) else output
-                for output in outputs
-            ]
+            return torch.cat(
+                [
+                    # If Tuple, assume that first argument is the input
+                    (
+                        output[0].reshape(-1)
+                        if isinstance(output, Tuple)
+                        else output.reshape(-1)
+                    )
+                    for output in outputs
+                ]
+            )
 
     def _compute_best_scale(
         self,
         x_mean: torch.Tensor,
         w_mean: torch.Tensor,
         parent_module: torch.nn.Module,
         linears2scale: List[torch.nn.Linear],
-        fp16_outputs: List[torch.Tensor],
+        fp16_output: torch.Tensor,
     ) -> torch.Tensor:
         """
         Compute loss and select best scales
@@ -618,10 +628,10 @@ def _compute_best_scale(
 
             # W * X
             with HooksMixin.disable_hooks():
-                int_w_outputs = self._run_samples(parent_module)
+                int_w_output = self._get_flattened_output(parent_module)
 
             # compute mean squared error (L2 norm)
-            loss = self._compute_loss(fp16_outputs, int_w_outputs, device)
+            loss = _compute_loss(fp16_output, int_w_output)
 
             history.append(loss)
             if loss < best_error:
@@ -640,34 +650,6 @@ def _compute_best_scale(
 
         return best_scales.detach().cpu()
 
-    @torch.no_grad()
-    def _compute_loss(
-        self,
-        fp16_outputs: List[torch.Tensor],
-        int_w_outputs: List[torch.Tensor],
-        device: torch.device,
-    ) -> torch.Tensor:
-        loss = 0.0
-        num_elements = 0
-
-        # Compute the MSE loss for each batch
-        for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
-            batch_loss = (
-                (fp16_batch.to(device) - int_w_batch.to(device))
-                .view(-1)
-                .float()
-                .pow(2)
-                .sum()
-                .item()
-            )
-            loss += batch_loss
-            num_elements += fp16_batch.numel()
-
-        # Normalize the loss by the total number of elements
-        loss /= num_elements
-
-        return loss
-
     def _assert_all_activations_consumed(self):
         """
         Confirm all activations have been consumed
@@ -677,6 +659,17 @@ def _assert_all_activations_consumed(self):
             raise RuntimeError("Some cached activations were not used")
 
 
+@torch.no_grad()
+@torch.compile()
+def _compute_loss(
+    fp16_output: torch.Tensor,
+    int_w_output: torch.Tensor,
+) -> torch.Tensor:
+    """Compute MSE loss for each batch"""
+    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
+
+
+@torch.compile()
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
 ):
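
For context, the change replaces the per-batch Python loop over outputs with a single flattened tensor comparison and a module-level, torch.compile-decorated MSE. Below is a minimal standalone sketch of that pattern; the names (flattened_mse, batches_fp16, batches_intw) and the random data are illustrative only and not part of the repository, and the @torch.compile() decorator from the commit is omitted so the snippet runs on any backend.

import torch

# Illustrative only: mirrors the flatten-then-compare pattern from the diff.
@torch.no_grad()
def flattened_mse(fp16_output: torch.Tensor, int_w_output: torch.Tensor) -> torch.Tensor:
    """MSE between two flattened activation tensors, as in the new _compute_loss."""
    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()

# Hypothetical batches of parent-module outputs with differing batch sizes
batches_fp16 = [torch.randn(2, 8), torch.randn(3, 8), torch.randn(1, 8)]
batches_intw = [b + 0.01 * torch.randn_like(b) for b in batches_fp16]

# Flatten each batch to 1D and concatenate, as _get_flattened_output now does
fp16_flat = torch.cat([b.reshape(-1) for b in batches_fp16])
intw_flat = torch.cat([b.reshape(-1) for b in batches_intw])

print(flattened_mse(fp16_flat, intw_flat).item())

This yields the same value as the removed per-batch loop (a sum of squared differences divided by the total element count) but avoids the Python-level iteration and the per-batch .item() synchronizations, which is likely where the small speedup comes from.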

0 commit comments
