Commit df12bd9

codeassist updates
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 2eceee5 commit df12bd9

File tree

1 file changed: +6 -3 lines changed

  • src/llmcompressor/modifiers/awq

src/llmcompressor/modifiers/awq/base.py

Lines changed: 6 additions & 3 deletions
@@ -543,8 +543,9 @@ def smooth(module):
 
     def _get_flattened_output(self, module: Module) -> torch.Tensor:
         """
-        Returns output of running cached batch inputs through module
-        Output tensor is 1D, as shapes aren't necessary for calculating loss
+        Returns output of running cached batch inputs through module.
+        Outputs from all batches are concatenated and flattened into a 1D tensor,
+        as shapes aren't necessary for calculating loss.
         """
         with align_module_device(module):
             outputs = [

@@ -665,7 +666,9 @@ def _compute_loss(
         fp16_output: torch.Tensor,
         int_w_output: torch.Tensor,
     ) -> torch.Tensor:
-        """Compute MSE loss for each batch"""
+        """
+        Compute MSE loss over the flattened output of all batches
+        """
        return (fp16_output - int_w_output).view(-1).float().pow(2).mean()
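The updated docstrings describe the loss setup used here: per-batch module outputs are flattened and concatenated into a single 1D tensor, and the MSE is taken over that flattened output rather than per batch. Below is a minimal standalone sketch of that pattern, assuming hypothetical helper names (flatten_batched_outputs, mse_loss); it is not the llmcompressor implementation, only an illustration of the formula in the diff's return statement.

import torch

def flatten_batched_outputs(batched_outputs):
    # Hypothetical helper: concatenate every batch's output and flatten to a
    # single 1D tensor; the original shapes are not needed for a scalar loss.
    return torch.cat([out.reshape(-1) for out in batched_outputs])

def mse_loss(fp16_output, int_w_output):
    # Same expression as the diff's return statement: mean squared error over
    # the flattened difference, computed in float32.
    return (fp16_output - int_w_output).view(-1).float().pow(2).mean()

# Toy usage: two "batches" of fake outputs, plus a slightly perturbed copy
# standing in for the quantized-weight output.
fp16 = flatten_batched_outputs([torch.randn(2, 4), torch.randn(3, 4)])
int_w = fp16 + 0.01 * torch.randn_like(fp16)
print(mse_loss(fp16, int_w))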
