From 2b4676682ebeaa8e34d358a51d572c0bb794babb Mon Sep 17 00:00:00 2001
From: Brian Dellabetta
Date: Mon, 16 Jun 2025 20:10:26 +0000
Subject: [PATCH 1/3] AWQ minor performance improvements to smoothing

Signed-off-by: Brian Dellabetta
---
 src/llmcompressor/modifiers/awq/base.py | 73 +++++++++++--------
 1 file changed, 33 insertions(+), 40 deletions(-)

diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 1bec18e2a..28f3ced65 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -474,8 +474,8 @@ def _apply_smoothing(self, model: Module) -> None:
             with calibration_forward_context(model), HooksMixin.disable_hooks():
                 # [STEP 3]: Compute output of module
                 # could cache from hook, rather than recomputing here
-                fp16_outputs = self._run_samples(parent_module)
-                if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs):
+                fp16_output = self._get_flattened_output(parent_module)
+                if fp16_output.numel() == 0:
                     logger.info(
                         f"Skipping smooth_layer {mapping.smooth_name}, no activations "
                         "found to scale. This can occasionally occur in MoE models "
@@ -488,7 +488,7 @@ def _apply_smoothing(self, model: Module) -> None:
 
             # [STEP 4]: Compute loss
             best_scales = self._compute_best_scale(
-                x_mean, w_mean, parent_module, balance_layers, fp16_outputs
+                x_mean, w_mean, parent_module, balance_layers, fp16_output
             )
 
             @torch.no_grad()
@@ -541,17 +541,27 @@ def smooth(module):
                 v.batch_intermediates.clear()
         self._assert_all_activations_consumed()
 
-    def _run_samples(self, module: Module) -> List[torch.Tensor]:
+    def _get_flattened_output(self, module: Module) -> torch.Tensor:
+        """
+        Returns output of running cached batch inputs through module
+        Output tensor is 1D, as shapes aren't necessary for calculating loss
+        """
         with align_module_device(module):
             outputs = [
                 module(**batch_kwargs)
                 for batch_kwargs in self._parent_args_cache[module]
            ]
-        return [
-            # If Tuple, assume that first argument is the input
-            output[0] if isinstance(output, Tuple) else output
-            for output in outputs
-        ]
+        return torch.cat(
+            [
+                # If Tuple, assume that first argument is the input
+                (
+                    output[0].reshape(-1)
+                    if isinstance(output, Tuple)
+                    else output.reshape(-1)
+                )
+                for output in outputs
+            ]
+        )
 
     def _compute_best_scale(
         self,
@@ -559,7 +569,7 @@ def _compute_best_scale(
         w_mean: torch.Tensor,
         parent_module: torch.nn.Module,
         linears2scale: List[torch.nn.Linear],
-        fp16_outputs: List[torch.Tensor],
+        fp16_output: torch.Tensor,
     ) -> torch.Tensor:
         """
         Compute loss and select best scales
@@ -618,10 +628,10 @@ def _compute_best_scale(
 
             # W * X
             with HooksMixin.disable_hooks():
-                int_w_outputs = self._run_samples(parent_module)
+                int_w_output = self._get_flattened_output(parent_module)
 
             # compute mean squared error (L2 norm)
-            loss = self._compute_loss(fp16_outputs, int_w_outputs, device)
+            loss = _compute_loss(fp16_output, int_w_output)
 
             history.append(loss)
             if loss < best_error:
@@ -640,34 +650,6 @@ def _compute_best_scale(
 
         return best_scales.detach().cpu()
 
-    @torch.no_grad()
-    def _compute_loss(
-        self,
-        fp16_outputs: List[torch.Tensor],
-        int_w_outputs: List[torch.Tensor],
-        device: torch.device,
-    ) -> torch.Tensor:
-        loss = 0.0
-        num_elements = 0
-
-        # Compute the MSE loss for each batch
-        for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
-            batch_loss = (
-                (fp16_batch.to(device) - int_w_batch.to(device))
-                .view(-1)
-                .float()
-                .pow(2)
-                .sum()
-                .item()
-            )
-            loss += batch_loss
-            num_elements += fp16_batch.numel()
-
-        # Normalize the loss by the total number of elements
-        loss /= num_elements
-
-        return loss
-
     def _assert_all_activations_consumed(self):
         """
         Confirm all activations have been consumed
@@ -677,6 +659,17 @@
             raise RuntimeError("Some cached activations were not used")
 
 
+@torch.no_grad()
+@torch.compile()
+def _compute_loss(
+    fp16_output: torch.Tensor,
+    int_w_output: torch.Tensor,
+) -> torch.Tensor:
+    """Compute MSE loss for each batch"""
+    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
+
+
+@torch.compile()
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
 ):

From aecfac9035ff7f40c690d1634350557a0c0b485b Mon Sep 17 00:00:00 2001
From: Brian Dellabetta
Date: Mon, 16 Jun 2025 20:26:22 +0000
Subject: [PATCH 2/3] codeassist updates

Signed-off-by: Brian Dellabetta
---
 src/llmcompressor/modifiers/awq/base.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 28f3ced65..53dcadec8 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -543,8 +543,9 @@ def smooth(module):
 
     def _get_flattened_output(self, module: Module) -> torch.Tensor:
         """
-        Returns output of running cached batch inputs through module
-        Output tensor is 1D, as shapes aren't necessary for calculating loss
+        Returns output of running cached batch inputs through module.
+        Outputs from all batches are concatenated and flattened into a 1D tensor,
+        as shapes aren't necessary for calculating loss.
         """
         with align_module_device(module):
             outputs = [
@@ -665,7 +666,9 @@ def _compute_loss(
     fp16_output: torch.Tensor,
     int_w_output: torch.Tensor,
 ) -> torch.Tensor:
-    """Compute MSE loss for each batch"""
+    """
+    Compute MSE loss over the flattened output of all batches
+    """
     return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
 
 

From 0213b9c66f6318cf4aee3e1fdbe7414b1b0348af Mon Sep 17 00:00:00 2001
From: Brian Dellabetta
Date: Tue, 17 Jun 2025 18:54:34 +0000
Subject: [PATCH 3/3] switch to F.mse_loss()

Signed-off-by: Brian Dellabetta
---
 src/llmcompressor/modifiers/awq/base.py | 24 +++++++-----------------
 1 file changed, 7 insertions(+), 17 deletions(-)

diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 53dcadec8..0d678135b 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -2,6 +2,7 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 from compressed_tensors.quantization import (
     disable_quantization,
     find_name_or_class_matches,
@@ -593,9 +594,9 @@ def _compute_best_scale(
         x_mean = x_mean.view(-1).to(device)
         w_mean = w_mean.view(-1).to(device)
 
-        for ratio in range(n_grid):
+        for grid_idx in range(n_grid):
             # create new scales
-            ratio = ratio / n_grid
+            ratio = grid_idx / n_grid
 
             # NOTE: s^-1 * x is fused here, according to paper
             if self.duo_scaling:
@@ -632,7 +633,7 @@ def _compute_best_scale(
                 int_w_output = self._get_flattened_output(parent_module)
 
             # compute mean squared error (L2 norm)
-            loss = _compute_loss(fp16_output, int_w_output)
+            loss = F.mse_loss(int_w_output, fp16_output).item()
 
             history.append(loss)
             if loss < best_error:
@@ -642,8 +643,9 @@ def _compute_best_scale(
             parent_module.load_state_dict(org_sd)
 
         if best_ratio == -1:
-            logger.debug(history)
-            raise Exception
+            raise Exception(
+                f"Loss during best scale computation never less than inf: {history}"
+            )
 
         assert (
             torch.isnan(best_scales).sum() == 0
@@ -660,18 +662,6 @@ def _assert_all_activations_consumed(self):
             raise RuntimeError("Some cached activations were not used")
 
 
-@torch.no_grad()
-@torch.compile()
-def _compute_loss(
-    fp16_output: torch.Tensor,
-    int_w_output: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Compute MSE loss over the flattened output of all batches
-    """
-    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
-
-
 @torch.compile()
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1