
Commit b457898

kylesayrs and dsikka authored
[Bugfix] infer_quantization_format when model only has activation quantization (#1635)
## Purpose ##
* Fix KV cache tests, whose models only have activation quantization

## Background ##
Previously, `is_model_quantized` would only check for quantization on leaf modules. Now it also checks attention modules, but since some attention modules carry only activation quantization, this exposed a bug in `infer_quantization_format`.

## Testing ##
* Requires neuralmagic/compressed-tensors#387 to pass KV cache tests

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
1 parent 395eedb commit b457898

File tree

1 file changed (+5, -8 lines)


src/llmcompressor/transformers/compression/quantization_format.py

Lines changed: 5 additions & 8 deletions
```diff
@@ -3,10 +3,7 @@
 from compressed_tensors import CompressionFormat
 from compressed_tensors.config import SparsityStructure
 from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
-from compressed_tensors.quantization.utils import (
-    is_model_quantized,
-    is_module_quantized,
-)
+from compressed_tensors.quantization.utils import is_module_quantized


 __all__ = ["infer_quantization_format"]

@@ -47,14 +44,14 @@ def infer_quantization_format(
     :param save_compressed: used to infer a quantization format if None is provided
     :return compression format appropriate for model
     """
-    if not is_model_quantized(model):
-        return None
-
     if quantization_format is not None:
         return quantization_format

+    weight_args, input_args = _get_unique_quant_args(model)
+    if len(weight_args) <= 0:
+        return None
+
     if save_compressed:
-        weight_args, input_args = _get_unique_quant_args(model)
         is_24_structure = (
             SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
         )
```
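The reordered checks can be sketched as a standalone function. This is a minimal sketch, not the real implementation: the model is modeled as a list of per-module quant-arg dicts, and `get_unique_quant_args` is a hypothetical stand-in for the private `_get_unique_quant_args` helper referenced in the diff. Only the control flow mirrors the change: the weight-args check now happens before any format selection, so a model with only activation quantization (the KV cache case) returns `None` instead of tripping the old `is_model_quantized` path.

```python
def get_unique_quant_args(model):
    """Stand-in (hypothetical) for _get_unique_quant_args: collect the
    distinct weight and input-activation quantization args in the model."""
    weight_args = [m["weights"] for m in model if "weights" in m]
    input_args = [m["input_activations"] for m in model if "input_activations" in m]
    return weight_args, input_args


def infer_quantization_format(model, quantization_format=None, save_compressed=True):
    # An explicitly requested format always wins.
    if quantization_format is not None:
        return quantization_format

    # New behavior from the fix: bail out when no module quantizes weights,
    # e.g. KV-cache models that only quantize attention activations.
    weight_args, input_args = get_unique_quant_args(model)
    if len(weight_args) <= 0:
        return None

    if save_compressed:
        # Placeholder for the real format selection (sparsity structure,
        # strategy, and type checks elided in this sketch).
        return "compressed"
    return None


# A KV-cache-style model with only activation quantization yields no format.
kv_cache_model = [{"input_activations": "int8"}]
print(infer_quantization_format(kv_cache_model))  # None
```

The key design point is the ordering: the explicit-format early return stays first, and the weight-args check replaces the old whole-model `is_model_quantized` gate, which wrongly treated activation-only models as quantized.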

0 commit comments
