
Commit 6d022f2

add check for if there are no weight quantizations
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent f28a9d5 · commit 6d022f2

File tree: 1 file changed, +45 -44 lines


src/llmcompressor/transformers/compression/quantization_format.py

Lines changed: 45 additions & 44 deletions
@@ -4,7 +4,6 @@
 from compressed_tensors.config import SparsityStructure
 from compressed_tensors.quantization import QuantizationStrategy, QuantizationType
 from compressed_tensors.quantization.utils import (
-    is_model_quantized,
     is_module_quantized,
 )

@@ -47,57 +46,59 @@ def infer_quantization_format(
     :param save_compressed: used to infer a quantization format if None is provided
     :return compression format appropriate for model
     """
-    if not is_model_quantized(model):
-        return None
-
     if quantization_format is not None:
         return quantization_format
+
+    if not save_compressed:
+        # format will be inferred from config
+        return None
+
+    weight_args, input_args = _get_unique_quant_args(model)

-    if save_compressed:
-        weight_args, input_args = _get_unique_quant_args(model)
-        is_24_structure = (
-            SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
-        )
-        is_weight_only = len(input_args) == 0 and len(weight_args) > 0
+    # no quantization format if no weights are quantized
+    if len(weight_args) <= 0:
+        return None
+
+    is_24_structure = (
+        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
+    )
+    is_weight_only = len(input_args) == 0 and len(weight_args) > 0

-        if (
-            weight_args[0].num_bits == 4
-            and weight_args[0].type == QuantizationType.FLOAT.value
-        ):
-            return CompressionFormat.nvfp4_pack_quantized
+    if (
+        weight_args[0].num_bits == 4
+        and weight_args[0].type == QuantizationType.FLOAT.value
+    ):
+        return CompressionFormat.nvfp4_pack_quantized

-        if is_weight_only:  # w4a16 and w8a16
-            is_valid_pack = all(
-                weight_arg.num_bits in [4, 8]
-                and weight_arg.type == QuantizationType.INT.value
-                for weight_arg in weight_args
-            )
-            if not is_valid_pack:  # packing only valid for int4 and int 8
-                return CompressionFormat.naive_quantized
-            if is_24_structure:
-                for arg in weight_args:
-                    if (
-                        arg.strategy is not QuantizationStrategy.CHANNEL.value
-                        and arg.strategy is not QuantizationStrategy.GROUP.value
-                    ):
-                        # marlin24 kernel only applicable for channel/group quantization
-                        return CompressionFormat.pack_quantized
-                return CompressionFormat.marlin_24
-            return CompressionFormat.pack_quantized
-        else:  # w8a8 float and int
-            if len(weight_args) == 1:
-                if (
-                    weight_args[0].type == QuantizationType.FLOAT.value
-                    and weight_args[0].num_bits == 8
-                ):
-                    return CompressionFormat.float_quantized
-                if weight_args[0].type == QuantizationType.INT.value:
-                    return CompressionFormat.int_quantized
+    if is_weight_only:  # w4a16 and w8a16
+        is_valid_pack = all(
+            weight_arg.num_bits in [4, 8]
+            and weight_arg.type == QuantizationType.INT.value
+            for weight_arg in weight_args
+        )
+        if not is_valid_pack:  # packing only valid for int4 and int 8
+            return CompressionFormat.naive_quantized
+        if is_24_structure:
+            for arg in weight_args:
+                if (
+                    arg.strategy is not QuantizationStrategy.CHANNEL.value
+                    and arg.strategy is not QuantizationStrategy.GROUP.value
+                ):
+                    # marlin24 kernel only applicable for channel/group quantization
+                    return CompressionFormat.pack_quantized
+            return CompressionFormat.marlin_24
+        return CompressionFormat.pack_quantized
+    else:  # w8a8 float and int
+        if len(weight_args) == 1:
+            if (
+                weight_args[0].type == QuantizationType.FLOAT.value
+                and weight_args[0].num_bits == 8
+            ):
+                return CompressionFormat.float_quantized
+            if weight_args[0].type == QuantizationType.INT.value:
+                return CompressionFormat.int_quantized

-            return CompressionFormat.naive_quantized
-        else:
-            # format will be inferred from config
-            return None
+        return CompressionFormat.naive_quantized


 def _get_unique_quant_args(model):
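The practical effect of the new `len(weight_args) <= 0` guard: a model whose scheme quantizes only activations (so `weight_args` comes back empty) now returns `None` instead of indexing `weight_args[0]`, and it also makes the removed `is_model_quantized` pre-check redundant, since an unquantized model likewise yields no weight args. Below is a minimal, hypothetical usage sketch, not part of the commit; it assumes llmcompressor and torch are installed, that the keyword names match the parameters referenced in the diff, and uses a toy `nn.Sequential` model as a stand-in for any model without weight quantization.

# Hypothetical sketch: exercising the no-weight-quantization path.
import torch.nn as nn

from llmcompressor.transformers.compression.quantization_format import (
    infer_quantization_format,
)

# A toy model with no quantized weights; _get_unique_quant_args finds nothing,
# so the new `len(weight_args) <= 0` guard short-circuits to None.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

fmt = infer_quantization_format(
    model=model,
    quantization_format=None,  # nothing requested explicitly
    save_compressed=True,      # ask the helper to infer a format
    sparsity_structure=None,   # no 2:4 sparsity
)
print(fmt)  # expected: None, since no weight quantization args were found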
