From 6d022f2eab1bd1dcbcd585bf1e92e27da5421067 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 9 Jul 2025 22:16:21 -0400 Subject: [PATCH 1/3] add check for if there are no weight quantizations Signed-off-by: Kyle Sayers --- .../compression/quantization_format.py | 89 ++++++++++--------- 1 file changed, 45 insertions(+), 44 deletions(-) diff --git a/src/llmcompressor/transformers/compression/quantization_format.py b/src/llmcompressor/transformers/compression/quantization_format.py index e0822bb9e..3feac42e6 100644 --- a/src/llmcompressor/transformers/compression/quantization_format.py +++ b/src/llmcompressor/transformers/compression/quantization_format.py @@ -4,7 +4,6 @@ from compressed_tensors.config import SparsityStructure from compressed_tensors.quantization import QuantizationStrategy, QuantizationType from compressed_tensors.quantization.utils import ( - is_model_quantized, is_module_quantized, ) @@ -47,57 +46,59 @@ def infer_quantization_format( :param save_compressed: used to infer a quantization format if None is provided :return compression format appropriate for model """ - if not is_model_quantized(model): - return None - if quantization_format is not None: return quantization_format + + if not save_compressed: + # format will be inferred from config + return None + + weight_args, input_args = _get_unique_quant_args(model) - if save_compressed: - weight_args, input_args = _get_unique_quant_args(model) - is_24_structure = ( - SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR - ) - is_weight_only = len(input_args) == 0 and len(weight_args) > 0 + # no quantization format if no weights are quantized + if len(weight_args) <= 0: + return None + + is_24_structure = ( + SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR + ) + is_weight_only = len(input_args) == 0 and len(weight_args) > 0 - if ( - weight_args[0].num_bits == 4 - and weight_args[0].type == QuantizationType.FLOAT.value - ): - return CompressionFormat.nvfp4_pack_quantized + if ( + weight_args[0].num_bits == 4 + and weight_args[0].type == QuantizationType.FLOAT.value + ): + return CompressionFormat.nvfp4_pack_quantized - if is_weight_only: # w4a16 and w8a16 - is_valid_pack = all( - weight_arg.num_bits in [4, 8] - and weight_arg.type == QuantizationType.INT.value - for weight_arg in weight_args - ) - if not is_valid_pack: # packing only valid for int4 and int 8 - return CompressionFormat.naive_quantized - if is_24_structure: - for arg in weight_args: - if ( - arg.strategy is not QuantizationStrategy.CHANNEL.value - and arg.strategy is not QuantizationStrategy.GROUP.value - ): - # marlin24 kernel only applicable for channel/group quantization - return CompressionFormat.pack_quantized - return CompressionFormat.marlin_24 - return CompressionFormat.pack_quantized - else: # w8a8 float and int - if len(weight_args) == 1: + if is_weight_only: # w4a16 and w8a16 + is_valid_pack = all( + weight_arg.num_bits in [4, 8] + and weight_arg.type == QuantizationType.INT.value + for weight_arg in weight_args + ) + if not is_valid_pack: # packing only valid for int4 and int 8 + return CompressionFormat.naive_quantized + if is_24_structure: + for arg in weight_args: if ( - weight_args[0].type == QuantizationType.FLOAT.value - and weight_args[0].num_bits == 8 + arg.strategy is not QuantizationStrategy.CHANNEL.value + and arg.strategy is not QuantizationStrategy.GROUP.value ): - return CompressionFormat.float_quantized - if weight_args[0].type == QuantizationType.INT.value: - return CompressionFormat.int_quantized + # marlin24 kernel only applicable for channel/group quantization + return CompressionFormat.pack_quantized + return CompressionFormat.marlin_24 + return CompressionFormat.pack_quantized + else: # w8a8 float and int + if len(weight_args) == 1: + if ( + weight_args[0].type == QuantizationType.FLOAT.value + and weight_args[0].num_bits == 8 + ): + return CompressionFormat.float_quantized + if weight_args[0].type == QuantizationType.INT.value: + return CompressionFormat.int_quantized - return CompressionFormat.naive_quantized - else: - # format will be inferred from config - return None + return CompressionFormat.naive_quantized def _get_unique_quant_args(model): From 7d5f7c9203878786987c62db61b1f53d041e7fb9 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 9 Jul 2025 22:18:03 -0400 Subject: [PATCH 2/3] add check for if there are no weight quantizations Signed-off-by: Kyle Sayers --- .../compression/quantization_format.py | 86 +++++++++---------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/src/llmcompressor/transformers/compression/quantization_format.py b/src/llmcompressor/transformers/compression/quantization_format.py index 3feac42e6..17b539431 100644 --- a/src/llmcompressor/transformers/compression/quantization_format.py +++ b/src/llmcompressor/transformers/compression/quantization_format.py @@ -48,57 +48,57 @@ def infer_quantization_format( """ if quantization_format is not None: return quantization_format - - if not save_compressed: - # format will be inferred from config - return None - + weight_args, input_args = _get_unique_quant_args(model) # no quantization format if no weights are quantized if len(weight_args) <= 0: return None - - is_24_structure = ( - SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR - ) - is_weight_only = len(input_args) == 0 and len(weight_args) > 0 - - if ( - weight_args[0].num_bits == 4 - and weight_args[0].type == QuantizationType.FLOAT.value - ): - return CompressionFormat.nvfp4_pack_quantized - - if is_weight_only: # w4a16 and w8a16 - is_valid_pack = all( - weight_arg.num_bits in [4, 8] - and weight_arg.type == QuantizationType.INT.value - for weight_arg in weight_args + + if save_compressed: + is_24_structure = ( + SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR ) - if not is_valid_pack: # packing only valid for int4 and int 8 - return CompressionFormat.naive_quantized - if is_24_structure: - for arg in weight_args: + is_weight_only = len(input_args) == 0 and len(weight_args) > 0 + + if ( + weight_args[0].num_bits == 4 + and weight_args[0].type == QuantizationType.FLOAT.value + ): + return CompressionFormat.nvfp4_pack_quantized + + if is_weight_only: # w4a16 and w8a16 + is_valid_pack = all( + weight_arg.num_bits in [4, 8] + and weight_arg.type == QuantizationType.INT.value + for weight_arg in weight_args + ) + if not is_valid_pack: # packing only valid for int4 and int 8 + return CompressionFormat.naive_quantized + if is_24_structure: + for arg in weight_args: + if ( + arg.strategy is not QuantizationStrategy.CHANNEL.value + and arg.strategy is not QuantizationStrategy.GROUP.value + ): + # marlin24 kernel only applicable for channel/group quantization + return CompressionFormat.pack_quantized + return CompressionFormat.marlin_24 + return CompressionFormat.pack_quantized + else: # w8a8 float and int + if len(weight_args) == 1: if ( - arg.strategy is not QuantizationStrategy.CHANNEL.value - and arg.strategy is not QuantizationStrategy.GROUP.value + weight_args[0].type == QuantizationType.FLOAT.value + and weight_args[0].num_bits == 8 ): - # marlin24 kernel only applicable for channel/group quantization - return CompressionFormat.pack_quantized - return CompressionFormat.marlin_24 - return CompressionFormat.pack_quantized - else: # w8a8 float and int - if len(weight_args) == 1: - if ( - weight_args[0].type == QuantizationType.FLOAT.value - and weight_args[0].num_bits == 8 - ): - return CompressionFormat.float_quantized - if weight_args[0].type == QuantizationType.INT.value: - return CompressionFormat.int_quantized - - return CompressionFormat.naive_quantized + return CompressionFormat.float_quantized + if weight_args[0].type == QuantizationType.INT.value: + return CompressionFormat.int_quantized + + return CompressionFormat.naive_quantized + else: + # format will be inferred from config + return None def _get_unique_quant_args(model): From cd28f402feb84d3801492d3bd907d1c331acb60d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Wed, 9 Jul 2025 22:45:39 -0400 Subject: [PATCH 3/3] style Signed-off-by: Kyle Sayers --- .../transformers/compression/quantization_format.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/llmcompressor/transformers/compression/quantization_format.py b/src/llmcompressor/transformers/compression/quantization_format.py index 17b539431..06edf3c59 100644 --- a/src/llmcompressor/transformers/compression/quantization_format.py +++ b/src/llmcompressor/transformers/compression/quantization_format.py @@ -3,9 +3,7 @@ from compressed_tensors import CompressionFormat from compressed_tensors.config import SparsityStructure from compressed_tensors.quantization import QuantizationStrategy, QuantizationType -from compressed_tensors.quantization.utils import ( - is_module_quantized, -) +from compressed_tensors.quantization.utils import is_module_quantized __all__ = ["infer_quantization_format"] @@ -50,8 +48,6 @@ def infer_quantization_format( return quantization_format weight_args, input_args = _get_unique_quant_args(model) - - # no quantization format if no weights are quantized if len(weight_args) <= 0: return None