src/llmcompressor/modifiers/awq: 1 file changed, +6 -3 lines

@@ -543,8 +543,9 @@ def smooth(module):

     def _get_flattened_output(self, module: Module) -> torch.Tensor:
         """
-        Returns output of running cached batch inputs through module
-        Output tensor is 1D, as shapes aren't necessary for calculating loss
+        Returns output of running cached batch inputs through module.
+        Outputs from all batches are concatenated and flattened into a 1D tensor,
+        as shapes aren't necessary for calculating loss.
         """
         with align_module_device(module):
             outputs = [
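The new docstring describes concatenating and flattening every batch's output into one 1D tensor. A minimal standalone sketch of that behavior is below; the list-of-tensors cache and the helper name `get_flattened_output` are assumptions for illustration, not the modifier's actual caching or device-alignment machinery.

```python
import torch
from torch.nn import Module


def get_flattened_output(module: Module, cached_batches: list[torch.Tensor]) -> torch.Tensor:
    # Hypothetical helper mirroring the docstring; the real method also aligns
    # the module's device via align_module_device before running the batches.
    with torch.no_grad():
        # Run every cached batch of inputs through the module.
        outputs = [module(batch) for batch in cached_batches]
    # Concatenate and flatten into a single 1D tensor; output shapes are
    # irrelevant because the loss only needs an element-wise difference.
    return torch.cat([out.reshape(-1) for out in outputs])
```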
@@ -665,7 +666,9 @@ def _compute_loss(
         fp16_output: torch.Tensor,
         int_w_output: torch.Tensor,
     ) -> torch.Tensor:
-        """Compute MSE loss for each batch"""
+        """
+        Compute MSE loss over the flattened output of all batches
+        """
         return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
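The updated `_compute_loss` docstring matches the existing one-line implementation, which reduces the two flattened outputs to a single scalar MSE rather than a per-batch average. A small usage sketch with made-up tensors, purely for illustration:

```python
import torch

# Toy stand-ins for the fp16 reference output and the quantized-weight output,
# already flattened and concatenated across batches as described above.
fp16_output = torch.randn(1024)
int_w_output = fp16_output + 0.01 * torch.randn(1024)

# Same expression as in the diff: flatten, square the error, take the mean.
loss = (fp16_output - int_w_output).view(-1).float().pow(2).mean()
print(loss.item())  # a small scalar, roughly 1e-4 for this toy perturbation
```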