@@ -2,6 +2,7 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 from compressed_tensors.quantization import (
     disable_quantization,
     find_name_or_class_matches,
@@ -593,9 +594,9 @@ def _compute_best_scale(
         x_mean = x_mean.view(-1).to(device)
         w_mean = w_mean.view(-1).to(device)
 
-        for ratio in range(n_grid):
+        for grid_idx in range(n_grid):
             # create new scales
-            ratio = ratio / n_grid
+            ratio = grid_idx / n_grid
 
             # NOTE: s^-1 * x is fused here, according to paper
             if self.duo_scaling:
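The rename above keeps the integer loop index (`grid_idx`) distinct from the fraction in [0, 1) derived from it (`ratio`); the old code reassigned `ratio` to itself, which was easy to misread. A minimal sketch of how each ratio becomes a candidate scale vector, following the AWQ-style duo-scaling formula; the clamp and normalization steps are assumptions taken from the AWQ reference implementation and do not appear in this diff:

    import torch

    def candidate_scales(x_mean, w_mean, n_grid=20, duo_scaling=True):
        # x_mean / w_mean: per-channel activation / weight magnitude means
        for grid_idx in range(n_grid):
            ratio = grid_idx / n_grid  # sweeps 0.0, 0.05, ..., 0.95 for n_grid=20
            if duo_scaling:
                # blend activation and weight statistics (AWQ duo scaling)
                scales = (x_mean.pow(ratio) * w_mean.pow(1 - ratio)).clamp(min=1e-4)
            else:
                scales = x_mean.pow(ratio).clamp(min=1e-4)
            # normalize so the extremes have geometric mean 1 (assumption)
            scales = scales / (scales.max() * scales.min()).sqrt()
            yield ratio, scales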
@@ -632,7 +633,7 @@ def _compute_best_scale(
             int_w_output = self._get_flattened_output(parent_module)
 
             # compute mean squared error (L2 norm)
-            loss = _compute_loss(fp16_output, int_w_output)
+            loss = F.mse_loss(int_w_output, fp16_output).item()
 
             history.append(loss)
             if loss < best_error:
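`F.mse_loss` with its default `reduction="mean"` computes the same quantity as the `_compute_loss` helper deleted in the last hunk below, the mean of elementwise squared differences, and the added `.item()` stores a plain Python float in `history` instead of a 0-dim tensor. One subtlety: the old helper upcast to float32 before squaring, while `F.mse_loss` works in the inputs' dtype, so half-precision outputs may round slightly differently. A quick equivalence check:

    import torch
    import torch.nn.functional as F

    fp16_output = torch.randn(4, 128)
    int_w_output = torch.randn(4, 128)

    old = (fp16_output - int_w_output).view(-1).float().pow(2).mean()
    new = F.mse_loss(int_w_output, fp16_output)  # reduction="mean" by default
    assert torch.allclose(old, new)  # same value; squared error is symmetric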
@@ -642,8 +643,9 @@ def _compute_best_scale(
         parent_module.load_state_dict(org_sd)
 
         if best_ratio == -1:
-            logger.debug(history)
-            raise Exception
+            raise Exception(
+                f"Loss during best scale computation never less than inf: {history}"
+            )
 
         assert (
             torch.isnan(best_scales).sum() == 0
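On the `best_ratio == -1` branch: the sentinel survives the whole sweep only if no loss ever compared less than `best_error`, and since NaN loses every `<` comparison, an all-NaN history is the typical culprit; the new message puts `history` in the exception itself rather than behind debug logging. A small illustration, assuming `best_error` starts at infinity as the new message implies:

    # NaN never wins a "<" comparison, so the sentinel is left untouched
    best_error, best_ratio = float("inf"), -1
    for grid_idx, loss in enumerate([float("nan")] * 20):
        if loss < best_error:  # False for every NaN loss
            best_error, best_ratio = loss, grid_idx
    assert best_ratio == -1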
@@ -660,18 +662,6 @@ def _assert_all_activations_consumed(self):
             raise RuntimeError("Some cached activations were not used")
 
 
-@torch.no_grad()
-@torch.compile()
-def _compute_loss(
-    fp16_output: torch.Tensor,
-    int_w_output: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Compute MSE loss over the flattened output of all batches
-    """
-    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
-
-
 @torch.compile()
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
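For context, `_pseudo_quantize_tensor` (untouched here except for losing its deleted neighbor) is the usual quantize-then-dequantize round trip that simulates low-bit weights in full precision. A hedged sketch of the asymmetric, per-tensor case; the real function's body, `group_size` handling, and symmetric path are not shown in this diff:

    import torch

    def pseudo_quantize(w: torch.Tensor, bit_width: int = 8) -> torch.Tensor:
        # asymmetric: map [w_min, w_max] onto the integer grid [0, 2^b - 1]
        qmax = 2**bit_width - 1
        w_min, w_max = w.min(), w.max()
        scale = (w_max - w_min).clamp(min=1e-5) / qmax
        zero_point = (-w_min / scale).round()
        q = (w / scale + zero_point).round().clamp(0, qmax)
        # dequantize: same shape as w, values snapped to the quantization grid
        return (q - zero_point) * scale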