diff --git a/src/llmcompressor/transformers/compression/quantization_format.py b/src/llmcompressor/transformers/compression/quantization_format.py index e0822bb9e..06edf3c59 100644 --- a/src/llmcompressor/transformers/compression/quantization_format.py +++ b/src/llmcompressor/transformers/compression/quantization_format.py @@ -3,10 +3,7 @@ from compressed_tensors import CompressionFormat from compressed_tensors.config import SparsityStructure from compressed_tensors.quantization import QuantizationStrategy, QuantizationType -from compressed_tensors.quantization.utils import ( - is_model_quantized, - is_module_quantized, -) +from compressed_tensors.quantization.utils import is_module_quantized __all__ = ["infer_quantization_format"] @@ -47,14 +44,14 @@ def infer_quantization_format( :param save_compressed: used to infer a quantization format if None is provided :return compression format appropriate for model """ - if not is_model_quantized(model): - return None - if quantization_format is not None: return quantization_format + weight_args, input_args = _get_unique_quant_args(model) + if len(weight_args) <= 0: + return None + if save_compressed: - weight_args, input_args = _get_unique_quant_args(model) is_24_structure = ( SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR )