
Commit 0213b9c

switch to F.mse_loss()
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent aecfac9 commit 0213b9c

File tree

1 file changed: +7 −17 lines
  • src/llmcompressor/modifiers/awq/base.py


src/llmcompressor/modifiers/awq/base.py

Lines changed: 7 additions & 17 deletions
@@ -2,6 +2,7 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 from compressed_tensors.quantization import (
     disable_quantization,
     find_name_or_class_matches,
@@ -593,9 +594,9 @@ def _compute_best_scale(
         x_mean = x_mean.view(-1).to(device)
         w_mean = w_mean.view(-1).to(device)
 
-        for ratio in range(n_grid):
+        for grid_idx in range(n_grid):
             # create new scales
-            ratio = ratio / n_grid
+            ratio = grid_idx / n_grid
 
             # NOTE: s^-1 * x is fused here, according to paper
             if self.duo_scaling:
@@ -632,7 +633,7 @@ def _compute_best_scale(
             int_w_output = self._get_flattened_output(parent_module)
 
             # compute mean squared error (L2 norm)
-            loss = _compute_loss(fp16_output, int_w_output)
+            loss = F.mse_loss(int_w_output, fp16_output).item()
 
             history.append(loss)
             if loss < best_error:
@@ -642,8 +643,9 @@ def _compute_best_scale(
         parent_module.load_state_dict(org_sd)
 
         if best_ratio == -1:
-            logger.debug(history)
-            raise Exception
+            raise Exception(
+                f"Loss during best scale computation never less than inf: {history}"
+            )
 
         assert (
             torch.isnan(best_scales).sum() == 0
@@ -660,18 +662,6 @@ def _assert_all_activations_consumed(self):
             raise RuntimeError("Some cached activations were not used")
 
 
-@torch.no_grad()
-@torch.compile()
-def _compute_loss(
-    fp16_output: torch.Tensor,
-    int_w_output: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Compute MSE loss over the flattened output of all batches
-    """
-    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
-
-
 @torch.compile()
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
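
Why the swap is safe: with its default reduction="mean", F.mse_loss averages the squared elementwise error over all elements, the same quantity the removed _compute_loss helper computed by hand. Below is a minimal sketch of the equivalence; the tensors are hypothetical stand-ins for the flattened fp16 and int-quantized module outputs compared in _compute_best_scale.

import torch
import torch.nn.functional as F

# Hypothetical stand-in tensors (shapes are arbitrary for illustration).
fp16_output = torch.randn(4, 128)
int_w_output = fp16_output + 0.01 * torch.randn(4, 128)

# Removed helper: mean of squared errors over the flattened difference,
# upcast to float32 before squaring.
manual_loss = (fp16_output - int_w_output).view(-1).float().pow(2).mean()

# Replacement: F.mse_loss with the default reduction="mean" averages the
# squared error over all elements, which is the same quantity.
builtin_loss = F.mse_loss(int_w_output, fp16_output)

assert torch.allclose(manual_loss, builtin_loss)

The .item() added at the call site converts the zero-dim loss tensor to a Python float, so history stores plain floats rather than tensors and loss < best_error compares floats. One caveat: the old helper upcast to float32 via .float() before squaring, whereas F.mse_loss computes in the inputs' dtype, so results can differ slightly for reduced-precision inputs.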
