diff --git a/src/llmcompressor/transformers/compression/quantization_format.py b/src/llmcompressor/transformers/compression/quantization_format.py
index e0822bb9e..5fd89e361 100644
--- a/src/llmcompressor/transformers/compression/quantization_format.py
+++ b/src/llmcompressor/transformers/compression/quantization_format.py
@@ -3,10 +3,7 @@
 from compressed_tensors import CompressionFormat
 from compressed_tensors.config import SparsityStructure
 from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
-from compressed_tensors.quantization.utils import (
-    is_model_quantized,
-    is_module_quantized,
-)
+from compressed_tensors.quantization.utils import is_module_quantized
 
 __all__ = ["infer_quantization_format"]
 
@@ -47,57 +44,57 @@ def infer_quantization_format(
     :param save_compressed: used to infer a quantization format if None is provided
     :return compression format appropriate for model
     """
-    if not is_model_quantized(model):
-        return None
-
     if quantization_format is not None:
         return quantization_format
 
-    if save_compressed:
-        weight_args, input_args = _get_unique_quant_args(model)
-        is_24_structure = (
-            SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
+    if not save_compressed:
+        # format will be inferred from config
+        return None
+
+    weight_args, input_args = _get_unique_quant_args(model)
+    if len(weight_args) <= 0:
+        return None
+
+    is_24_structure = (
+        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
+    )
+    is_weight_only = len(input_args) == 0 and len(weight_args) > 0
+
+    if (
+        weight_args[0].num_bits == 4
+        and weight_args[0].type == QuantizationType.FLOAT.value
+    ):
+        return CompressionFormat.nvfp4_pack_quantized
+
+    if is_weight_only:  # w4a16 and w8a16
+        is_valid_pack = all(
+            weight_arg.num_bits in [4, 8]
+            and weight_arg.type == QuantizationType.INT.value
+            for weight_arg in weight_args
         )
-        is_weight_only = len(input_args) == 0 and len(weight_args) > 0
-
-        if (
-            weight_args[0].num_bits == 4
-            and weight_args[0].type == QuantizationType.FLOAT.value
-        ):
-            return CompressionFormat.nvfp4_pack_quantized
-
-        if is_weight_only:  # w4a16 and w8a16
-            is_valid_pack = all(
-                weight_arg.num_bits in [4, 8]
-                and weight_arg.type == QuantizationType.INT.value
-                for weight_arg in weight_args
-            )
-            if not is_valid_pack:  # packing only valid for int4 and int 8
-                return CompressionFormat.naive_quantized
-            if is_24_structure:
-                for arg in weight_args:
-                    if (
-                        arg.strategy is not QuantizationStrategy.CHANNEL.value
-                        and arg.strategy is not QuantizationStrategy.GROUP.value
-                    ):
-                        # marlin24 kernel only applicable for channel/group quantization
-                        return CompressionFormat.pack_quantized
-                return CompressionFormat.marlin_24
-            return CompressionFormat.pack_quantized
-        else:  # w8a8 float and int
-            if len(weight_args) == 1:
+        if not is_valid_pack:  # packing only valid for int4 and int8
+            return CompressionFormat.naive_quantized
+        if is_24_structure:
+            for arg in weight_args:
                 if (
-                    weight_args[0].type == QuantizationType.FLOAT.value
-                    and weight_args[0].num_bits == 8
+                    arg.strategy is not QuantizationStrategy.CHANNEL.value
+                    and arg.strategy is not QuantizationStrategy.GROUP.value
                 ):
-                    return CompressionFormat.float_quantized
-                if weight_args[0].type == QuantizationType.INT.value:
-                    return CompressionFormat.int_quantized
+                    # marlin24 kernel only applicable for channel/group quantization
+                    return CompressionFormat.pack_quantized
+            return CompressionFormat.marlin_24
+        return CompressionFormat.pack_quantized
+    else:  # w8a8 float and int
+        if len(weight_args) == 1:
+            if (
+                weight_args[0].type == QuantizationType.FLOAT.value
+                and weight_args[0].num_bits == 8
+            ):
+                return CompressionFormat.float_quantized
+            if weight_args[0].type == QuantizationType.INT.value:
+                return CompressionFormat.int_quantized
 
-            return CompressionFormat.naive_quantized
-    else:
-        # format will be inferred from config
-        return None
+        return CompressionFormat.naive_quantized
 
 
 def _get_unique_quant_args(model):
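
For reviewers, a minimal usage sketch of the new control flow. This is not part of the diff; it assumes the keyword names shown in the docstring and the hunk's function context, and `model` stands in for any module tree that has already been through a quantization run:

```python
from llmcompressor.transformers.compression.quantization_format import (
    infer_quantization_format,
)

# An explicitly requested format is still returned verbatim, before any
# inspection of the model.
fmt = infer_quantization_format(model, quantization_format="pack-quantized")

# save_compressed=False now short-circuits up front; the format is
# inferred from the model's config later instead.
assert infer_quantization_format(model, save_compressed=False) is None

# With save_compressed=True, an unquantized model produces no weight args,
# so the new len(weight_args) <= 0 guard returns None. This replaces the
# removed is_model_quantized(model) early-exit.
fmt = infer_quantization_format(model, save_compressed=True)
```

The net design change: instead of walking the model once in `is_model_quantized` and again in `_get_unique_quant_args`, the rewrite gathers `weight_args` first and treats an empty result as "not quantized", while the early returns flatten the happy path out of the old `if save_compressed:` nesting.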