
AWQ minor performance improvements to smoothing #1557

Open · wants to merge 3 commits into base: main
74 changes: 30 additions & 44 deletions src/llmcompressor/modifiers/awq/base.py
@@ -2,6 +2,7 @@
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from compressed_tensors.quantization import (
disable_quantization,
find_name_or_class_matches,
@@ -474,8 +475,8 @@ def _apply_smoothing(self, model: Module) -> None:
with calibration_forward_context(model), HooksMixin.disable_hooks():
# [STEP 3]: Compute output of module
# could cache from hook, rather than recomputing here
fp16_outputs = self._run_samples(parent_module)
if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs):
fp16_output = self._get_flattened_output(parent_module)
if fp16_output.numel() == 0:
logger.info(
f"Skipping smooth_layer {mapping.smooth_name}, no activations "
"found to scale. This can occasionally occur in MoE models "
@@ -488,7 +489,7 @@ def _apply_smoothing(self, model: Module) -> None:

# [STEP 4]: Compute loss
best_scales = self._compute_best_scale(
x_mean, w_mean, parent_module, balance_layers, fp16_outputs
x_mean, w_mean, parent_module, balance_layers, fp16_output
)

@torch.no_grad()
@@ -541,25 +542,36 @@ def smooth(module):
v.batch_intermediates.clear()
self._assert_all_activations_consumed()

def _run_samples(self, module: Module) -> List[torch.Tensor]:
def _get_flattened_output(self, module: Module) -> torch.Tensor:
"""
Returns output of running cached batch inputs through module.
Outputs from all batches are concatenated and flattened into a 1D tensor,
as shapes aren't necessary for calculating loss.
"""
with align_module_device(module):
outputs = [
module(**batch_kwargs)
for batch_kwargs in self._parent_args_cache[module]
]
return [
# If Tuple, assume that first argument is the input
output[0] if isinstance(output, Tuple) else output
for output in outputs
]
return torch.cat(
[
# If Tuple, assume that first argument is the input
(
output[0].reshape(-1)
if isinstance(output, Tuple)
else output.reshape(-1)
)
for output in outputs
]
)

def _compute_best_scale(
self,
x_mean: torch.Tensor,
w_mean: torch.Tensor,
parent_module: torch.nn.Module,
linears2scale: List[torch.nn.Linear],
fp16_outputs: List[torch.Tensor],
fp16_output: torch.Tensor,
) -> torch.Tensor:
"""
Compute loss and select best scales
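
For readers comparing the two approaches, here is a small standalone sketch (hypothetical tensors and names, not code from this PR) of what the new _get_flattened_output does: each cached batch output is flattened and all batches are concatenated into a single 1D tensor, so the downstream loss computation no longer needs to track per-batch shapes.

import torch

# hypothetical cached batch outputs with different batch sizes
batch_outputs = [torch.randn(2, 5, 16), torch.randn(3, 5, 16)]

# flatten each batch and concatenate: one 1D tensor, shape information discarded
flattened = torch.cat([out.reshape(-1) for out in batch_outputs])
assert flattened.shape == (2 * 5 * 16 + 3 * 5 * 16,)
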
@@ -582,9 +594,9 @@ def _compute_best_scale(
x_mean = x_mean.view(-1).to(device)
w_mean = w_mean.view(-1).to(device)

for ratio in range(n_grid):
for grid_idx in range(n_grid):
# create new scales
ratio = ratio / n_grid
ratio = grid_idx / n_grid

# NOTE: s^-1 * x is fused here, according to paper
if self.duo_scaling:
@@ -618,10 +630,10 @@ def _compute_best_scale(

# W * X
with HooksMixin.disable_hooks():
int_w_outputs = self._run_samples(parent_module)
int_w_output = self._get_flattened_output(parent_module)

# compute mean squared error (L2 norm)
loss = self._compute_loss(fp16_outputs, int_w_outputs, device)
loss = F.mse_loss(int_w_output, fp16_output).item()

history.append(loss)
if loss < best_error:
@@ -631,43 +643,16 @@
parent_module.load_state_dict(org_sd)

if best_ratio == -1:
logger.debug(history)
raise Exception
raise Exception(
f"Loss during best scale computation never less than inf: {history}"
)

assert (
torch.isnan(best_scales).sum() == 0
), f"Nan found in scales: {best_scales}"

return best_scales.detach().cpu()

@torch.no_grad()
def _compute_loss(
self,
fp16_outputs: List[torch.Tensor],
int_w_outputs: List[torch.Tensor],
device: torch.device,
) -> torch.Tensor:
loss = 0.0
num_elements = 0

# Compute the MSE loss for each batch
for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
batch_loss = (
(fp16_batch.to(device) - int_w_batch.to(device))
.view(-1)
.float()
.pow(2)
.sum()
.item()
)
loss += batch_loss
num_elements += fp16_batch.numel()

# Normalize the loss by the total number of elements
loss /= num_elements

return loss

def _assert_all_activations_consumed(self):
"""
Confirm all activations have been consumed
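
Assuming the batches are run in the same order for both passes (they are, since both use the same cached inputs), the single F.mse_loss call above is numerically equivalent to the removed _compute_loss loop, which summed squared errors per batch and divided by the total element count. A small self-contained check of that equivalence, using made-up tensors rather than real activations:

import torch
import torch.nn.functional as F

fp16_batches = [torch.randn(2, 8), torch.randn(3, 8)]
int_w_batches = [torch.randn(2, 8), torch.randn(3, 8)]

# old style: accumulate squared error per batch, then normalize by total elements
squared_error = sum(
    (a - b).view(-1).float().pow(2).sum().item()
    for a, b in zip(fp16_batches, int_w_batches)
)
num_elements = sum(a.numel() for a in fp16_batches)
loss_old = squared_error / num_elements

# new style: one MSE over the concatenated, flattened outputs
fp16_output = torch.cat([b.reshape(-1) for b in fp16_batches])
int_w_output = torch.cat([b.reshape(-1) for b in int_w_batches])
loss_new = F.mse_loss(int_w_output, fp16_output).item()

assert abs(loss_old - loss_new) < 1e-6
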
@@ -677,6 +662,7 @@ def _assert_all_activations_consumed(self):
raise RuntimeError("Some cached activations were not used")


@torch.compile()
def _pseudo_quantize_tensor(
w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
):
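
The @torch.compile() decorator added above should let PyTorch trace and fuse the elementwise quantize/dequantize math that runs once per grid point. Purely as an illustration of the pattern (this is not the repository's _pseudo_quantize_tensor; the name and quantization details below are simplified assumptions), a decorated helper might look like:

import torch

@torch.compile()
def fake_group_quantize(w: torch.Tensor, bit_width: int = 4, group_size: int = 128) -> torch.Tensor:
    # reshape into groups, scale symmetrically, round, then dequantize
    orig_shape = w.shape
    w = w.reshape(-1, group_size)
    qmax = 2 ** (bit_width - 1) - 1
    scales = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / qmax
    w_q = (w / scales).round().clamp(-qmax - 1, qmax) * scales
    return w_q.reshape(orig_shape)

w = torch.randn(256, 512)
w_q = fake_group_quantize(w)  # first call compiles; later calls reuse the compiled kernel
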