diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 1bec18e2a..0d678135b 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -2,6 +2,7 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 from compressed_tensors.quantization import (
     disable_quantization,
     find_name_or_class_matches,
@@ -474,8 +475,8 @@ def _apply_smoothing(self, model: Module) -> None:
             with calibration_forward_context(model), HooksMixin.disable_hooks():
                 # [STEP 3]: Compute output of module
                 # could cache from hook, rather than recomputing here
-                fp16_outputs = self._run_samples(parent_module)
-                if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs):
+                fp16_output = self._get_flattened_output(parent_module)
+                if fp16_output.numel() == 0:
                     logger.info(
                         f"Skipping smooth_layer {mapping.smooth_name}, no activations "
                         "found to scale. This can occasionally occur in MoE models "
@@ -488,7 +489,7 @@ def _apply_smoothing(self, model: Module) -> None:
 
                 # [STEP 4]: Compute loss
                 best_scales = self._compute_best_scale(
-                    x_mean, w_mean, parent_module, balance_layers, fp16_outputs
+                    x_mean, w_mean, parent_module, balance_layers, fp16_output
                 )
 
             @torch.no_grad()
@@ -541,17 +542,28 @@ def smooth(module):
             v.batch_intermediates.clear()
         self._assert_all_activations_consumed()
 
-    def _run_samples(self, module: Module) -> List[torch.Tensor]:
+    def _get_flattened_output(self, module: Module) -> torch.Tensor:
+        """
+        Returns output of running cached batch inputs through module.
+        Outputs from all batches are concatenated and flattened into a 1D tensor,
+        as shapes aren't necessary for calculating loss.
+        """
         with align_module_device(module):
             outputs = [
                 module(**batch_kwargs)
                 for batch_kwargs in self._parent_args_cache[module]
            ]
-        return [
-            # If Tuple, assume that first argument is the input
-            output[0] if isinstance(output, Tuple) else output
-            for output in outputs
-        ]
+        return torch.cat(
+            [
+                # If Tuple, assume that first argument is the input
+                (
+                    output[0].reshape(-1)
+                    if isinstance(output, Tuple)
+                    else output.reshape(-1)
+                )
+                for output in outputs
+            ]
+        )
 
     def _compute_best_scale(
         self,
@@ -559,7 +571,7 @@ def _compute_best_scale(
         w_mean: torch.Tensor,
         parent_module: torch.nn.Module,
         linears2scale: List[torch.nn.Linear],
-        fp16_outputs: List[torch.Tensor],
+        fp16_output: torch.Tensor,
     ) -> torch.Tensor:
         """
         Compute loss and select best scales
@@ -582,9 +594,9 @@ def _compute_best_scale(
         x_mean = x_mean.view(-1).to(device)
         w_mean = w_mean.view(-1).to(device)
 
-        for ratio in range(n_grid):
+        for grid_idx in range(n_grid):
             # create new scales
-            ratio = ratio / n_grid
+            ratio = grid_idx / n_grid
 
             # NOTE: s^-1 * x is fused here, according to paper
             if self.duo_scaling:
@@ -618,10 +630,10 @@ def _compute_best_scale(
 
             # W * X
             with HooksMixin.disable_hooks():
-                int_w_outputs = self._run_samples(parent_module)
+                int_w_output = self._get_flattened_output(parent_module)
 
             # compute mean squared error (L2 norm)
-            loss = self._compute_loss(fp16_outputs, int_w_outputs, device)
+            loss = F.mse_loss(int_w_output, fp16_output).item()
 
             history.append(loss)
             if loss < best_error:
@@ -631,8 +643,9 @@ def _compute_best_scale(
         parent_module.load_state_dict(org_sd)
 
         if best_ratio == -1:
-            logger.debug(history)
-            raise Exception
+            raise Exception(
+                f"Loss during best scale computation never less than inf: {history}"
+            )
 
         assert (
             torch.isnan(best_scales).sum() == 0
@@ -640,34 +653,6 @@ def _compute_best_scale(
 
         return best_scales.detach().cpu()
 
-    @torch.no_grad()
-    def _compute_loss(
-        self,
-        fp16_outputs: List[torch.Tensor],
-        int_w_outputs: List[torch.Tensor],
-        device: torch.device,
-    ) -> torch.Tensor:
-        loss = 0.0
-        num_elements = 0
-
-        # Compute the MSE loss for each batch
-        for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
-            batch_loss = (
-                (fp16_batch.to(device) - int_w_batch.to(device))
-                .view(-1)
-                .float()
-                .pow(2)
-                .sum()
-                .item()
-            )
-            loss += batch_loss
-            num_elements += fp16_batch.numel()
-
-        # Normalize the loss by the total number of elements
-        loss /= num_elements
-
-        return loss
-
     def _assert_all_activations_consumed(self):
         """
         Confirm all activations have been consumed
@@ -677,6 +662,7 @@ def _assert_all_activations_consumed(self):
             raise RuntimeError("Some cached activations were not used")
 
 
+@torch.compile()
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
 ):
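
Note (not part of the diff): the removed _compute_loss accumulated the squared error per batch and normalized by the total element count, which is what F.mse_loss with its default mean reduction computes over the concatenated, flattened outputs now returned by _get_flattened_output. Below is a minimal standalone sketch with synthetic tensors (shapes are made up for illustration) showing the equivalence that the new loss = F.mse_loss(int_w_output, fp16_output).item() line relies on:

    # Standalone sketch (hypothetical shapes, not from the repo): checks that one
    # F.mse_loss call over concatenated, flattened tensors matches the removed
    # per-batch sum-of-squares loss normalized by total element count.
    import torch
    import torch.nn.functional as F

    fp16_batches = [torch.randn(4, 8, 16) for _ in range(3)]                # stand-in fp16 outputs
    int_w_batches = [b + 0.01 * torch.randn_like(b) for b in fp16_batches]  # stand-in quantized outputs

    # Old approach (removed _compute_loss): per-batch squared error, normalized at the end
    loss_old = sum(
        (a - b).view(-1).float().pow(2).sum().item()
        for a, b in zip(fp16_batches, int_w_batches)
    ) / sum(a.numel() for a in fp16_batches)

    # New approach: flatten and concatenate once, then a single mse_loss call
    fp16_output = torch.cat([a.reshape(-1) for a in fp16_batches])
    int_w_output = torch.cat([b.reshape(-1) for b in int_w_batches])
    loss_new = F.mse_loss(int_w_output, fp16_output).item()

    assert abs(loss_old - loss_new) < 1e-6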