@@ -474,8 +474,8 @@ def _apply_smoothing(self, model: Module) -> None:
             with calibration_forward_context(model), HooksMixin.disable_hooks():
                 # [STEP 3]: Compute output of module
                 # could cache from hook, rather than recomputing here
-                fp16_outputs = self._run_samples(parent_module)
-                if len(fp16_outputs) == 0 or all(f.numel() == 0 for f in fp16_outputs):
+                fp16_output = self._get_flattened_output(parent_module)
+                if fp16_output.numel() == 0:
                     logger.info(
                         f"Skipping smooth_layer {mapping.smooth_name}, no activations "
                         "found to scale. This can occasionally occur in MoE models "
@@ -488,7 +488,7 @@ def _apply_smoothing(self, model: Module) -> None:
 
                 # [STEP 4]: Compute loss
                 best_scales = self._compute_best_scale(
-                    x_mean, w_mean, parent_module, balance_layers, fp16_outputs
+                    x_mean, w_mean, parent_module, balance_layers, fp16_output
                 )
 
     @torch.no_grad()
@torch .no_grad ()
@@ -541,25 +541,35 @@ def smooth(module):
             v.batch_intermediates.clear()
         self._assert_all_activations_consumed()
 
-    def _run_samples(self, module: Module) -> List[torch.Tensor]:
+    def _get_flattened_output(self, module: Module) -> torch.Tensor:
+        """
+        Returns output of running cached batch inputs through module
+        Output tensor is 1D, as shapes aren't necessary for calculating loss
+        """
         with align_module_device(module):
             outputs = [
                 module(**batch_kwargs)
                 for batch_kwargs in self._parent_args_cache[module]
             ]
-        return [
-            # If Tuple, assume that first argument is the input
-            output[0] if isinstance(output, Tuple) else output
-            for output in outputs
-        ]
+        return torch.cat(
+            [
+                # If Tuple, assume that first argument is the input
+                (
+                    output[0].reshape(-1)
+                    if isinstance(output, Tuple)
+                    else output.reshape(-1)
+                )
+                for output in outputs
+            ]
+        )
 
     def _compute_best_scale(
         self,
         x_mean: torch.Tensor,
         w_mean: torch.Tensor,
         parent_module: torch.nn.Module,
         linears2scale: List[torch.nn.Linear],
-        fp16_outputs: List[torch.Tensor],
+        fp16_output: torch.Tensor,
     ) -> torch.Tensor:
         """
         Compute loss and select best scales
@@ -618,10 +628,10 @@ def _compute_best_scale(
 
             # W * X
             with HooksMixin.disable_hooks():
-                int_w_outputs = self._run_samples(parent_module)
+                int_w_output = self._get_flattened_output(parent_module)
 
             # compute mean squared error (L2 norm)
-            loss = self._compute_loss(fp16_outputs, int_w_outputs, device)
+            loss = _compute_loss(fp16_output, int_w_output)
 
             history.append(loss)
             if loss < best_error:
@@ -640,34 +650,6 @@ def _compute_best_scale(
 
         return best_scales.detach().cpu()
 
-    @torch.no_grad()
-    def _compute_loss(
-        self,
-        fp16_outputs: List[torch.Tensor],
-        int_w_outputs: List[torch.Tensor],
-        device: torch.device,
-    ) -> torch.Tensor:
-        loss = 0.0
-        num_elements = 0
-
-        # Compute the MSE loss for each batch
-        for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
-            batch_loss = (
-                (fp16_batch.to(device) - int_w_batch.to(device))
-                .view(-1)
-                .float()
-                .pow(2)
-                .sum()
-                .item()
-            )
-            loss += batch_loss
-            num_elements += fp16_batch.numel()
-
-        # Normalize the loss by the total number of elements
-        loss /= num_elements
-
-        return loss
-
     def _assert_all_activations_consumed(self):
         """
         Confirm all activations have been consumed
@@ -677,6 +659,17 @@ def _assert_all_activations_consumed(self):
             raise RuntimeError("Some cached activations were not used")
 
 
+@torch.no_grad()
+@torch.compile()
+def _compute_loss(
+    fp16_output: torch.Tensor,
+    int_w_output: torch.Tensor,
+) -> torch.Tensor:
+    """Compute MSE loss between fp16 and quantized outputs"""
+    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
+
+
+@torch.compile()
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
 ):
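
Reviewer note (not part of the diff above): the refactor is loss-preserving, because the old per-batch accumulation sums squared differences over every batch and divides by the total element count, which is exactly the mean over the flattened, concatenated 1D tensors that the new code builds. A minimal standalone sketch of that equivalence, using arbitrary fake batch shapes:

import torch

torch.manual_seed(0)

# Fake per-batch module outputs, as the old _run_samples would return them
fp16_outputs = [torch.randn(2, 4, 8), torch.randn(3, 4, 8)]
int_w_outputs = [o + 0.01 * torch.randn_like(o) for o in fp16_outputs]

# Old approach: accumulate squared error per batch, normalize at the end
loss_old = sum(
    (a - b).view(-1).float().pow(2).sum().item()
    for a, b in zip(fp16_outputs, int_w_outputs)
) / sum(o.numel() for o in fp16_outputs)

# New approach: flatten and concatenate once, then take a single mean
fp16_output = torch.cat([o.reshape(-1) for o in fp16_outputs])
int_w_output = torch.cat([o.reshape(-1) for o in int_w_outputs])
loss_new = (fp16_output - int_w_output).view(-1).float().pow(2).mean()

assert abs(loss_old - loss_new.item()) < 1e-6

Note also that _compute_loss moves from a method to a module-level function, presumably so @torch.compile() traces a plain tensor-to-tensor function without self in the signature.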