|
4 | 4 | from compressed_tensors.config import SparsityStructure
|
5 | 5 | from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
|
6 | 6 | from compressed_tensors.quantization.utils import (
|
7 |
| - is_model_quantized, |
8 | 7 | is_module_quantized,
|
9 | 8 | )
|
10 | 9 |
|
@@ -47,57 +46,59 @@ def infer_quantization_format(
|
47 | 46 | :param save_compressed: used to infer a quantization format if None is provided
|
48 | 47 | :return compression format appropriate for model
|
49 | 48 | """
|
50 |
| - if not is_model_quantized(model): |
51 |
| - return None |
52 |
| - |
53 | 49 | if quantization_format is not None:
|
54 | 50 | return quantization_format
|
| 51 | + |
| 52 | + if not save_compressed: |
| 53 | + # format will be inferred from config |
| 54 | + return None |
| 55 | + |
| 56 | + weight_args, input_args = _get_unique_quant_args(model) |
55 | 57 |
|
56 |
| - if save_compressed: |
57 |
| - weight_args, input_args = _get_unique_quant_args(model) |
58 |
| - is_24_structure = ( |
59 |
| - SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR |
60 |
| - ) |
61 |
| - is_weight_only = len(input_args) == 0 and len(weight_args) > 0 |
| 58 | + # no quantization format if no weights are quantized |
| 59 | + if len(weight_args) <= 0: |
| 60 | + return None |
| 61 | + |
| 62 | + is_24_structure = ( |
| 63 | + SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR |
| 64 | + ) |
| 65 | + is_weight_only = len(input_args) == 0 and len(weight_args) > 0 |
62 | 66 |
|
63 |
| - if ( |
64 |
| - weight_args[0].num_bits == 4 |
65 |
| - and weight_args[0].type == QuantizationType.FLOAT.value |
66 |
| - ): |
67 |
| - return CompressionFormat.nvfp4_pack_quantized |
| 67 | + if ( |
| 68 | + weight_args[0].num_bits == 4 |
| 69 | + and weight_args[0].type == QuantizationType.FLOAT.value |
| 70 | + ): |
| 71 | + return CompressionFormat.nvfp4_pack_quantized |
68 | 72 |
|
69 |
| - if is_weight_only: # w4a16 and w8a16 |
70 |
| - is_valid_pack = all( |
71 |
| - weight_arg.num_bits in [4, 8] |
72 |
| - and weight_arg.type == QuantizationType.INT.value |
73 |
| - for weight_arg in weight_args |
74 |
| - ) |
75 |
| - if not is_valid_pack: # packing only valid for int4 and int8 |
76 |
| - return CompressionFormat.naive_quantized |
77 |
| - if is_24_structure: |
78 |
| - for arg in weight_args: |
79 |
| - if ( |
80 |
| - arg.strategy is not QuantizationStrategy.CHANNEL.value |
81 |
| - and arg.strategy is not QuantizationStrategy.GROUP.value |
82 |
| - ): |
83 |
| - # marlin24 kernel only applicable for channel/group quantization |
84 |
| - return CompressionFormat.pack_quantized |
85 |
| - return CompressionFormat.marlin_24 |
86 |
| - return CompressionFormat.pack_quantized |
87 |
| - else: # w8a8 float and int |
88 |
| - if len(weight_args) == 1: |
| 73 | + if is_weight_only: # w4a16 and w8a16 |
| 74 | + is_valid_pack = all( |
| 75 | + weight_arg.num_bits in [4, 8] |
| 76 | + and weight_arg.type == QuantizationType.INT.value |
| 77 | + for weight_arg in weight_args |
| 78 | + ) |
| 79 | + if not is_valid_pack: # packing only valid for int4 and int8 |
| 80 | + return CompressionFormat.naive_quantized |
| 81 | + if is_24_structure: |
| 82 | + for arg in weight_args: |
89 | 83 | if (
|
90 |
| - weight_args[0].type == QuantizationType.FLOAT.value |
91 |
| - and weight_args[0].num_bits == 8 |
| 84 | + arg.strategy is not QuantizationStrategy.CHANNEL.value |
| 85 | + and arg.strategy is not QuantizationStrategy.GROUP.value |
92 | 86 | ):
|
93 |
| - return CompressionFormat.float_quantized |
94 |
| - if weight_args[0].type == QuantizationType.INT.value: |
95 |
| - return CompressionFormat.int_quantized |
| 87 | + # marlin24 kernel only applicable for channel/group quantization |
| 88 | + return CompressionFormat.pack_quantized |
| 89 | + return CompressionFormat.marlin_24 |
| 90 | + return CompressionFormat.pack_quantized |
| 91 | + else: # w8a8 float and int |
| 92 | + if len(weight_args) == 1: |
| 93 | + if ( |
| 94 | + weight_args[0].type == QuantizationType.FLOAT.value |
| 95 | + and weight_args[0].num_bits == 8 |
| 96 | + ): |
| 97 | + return CompressionFormat.float_quantized |
| 98 | + if weight_args[0].type == QuantizationType.INT.value: |
| 99 | + return CompressionFormat.int_quantized |
96 | 100 |
|
97 |
| - return CompressionFormat.naive_quantized |
98 |
| - else: |
99 |
| - # format will be inferred from config |
100 |
| - return None |
| 101 | + return CompressionFormat.naive_quantized |
101 | 102 |
|
102 | 103 |
|
103 | 104 | def _get_unique_quant_args(model):
|
|
0 commit comments