Skip to content

Commit 7d5f7c9

Browse files
committed
Add a check for the case where no weights are quantized
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 6d022f2 commit 7d5f7c9

File tree

1 file changed

+43
-43
lines changed

1 file changed

+43
-43
lines changed

src/llmcompressor/transformers/compression/quantization_format.py

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -48,57 +48,57 @@ def infer_quantization_format(
4848
"""
4949
if quantization_format is not None:
5050
return quantization_format
51-
52-
if not save_compressed:
53-
# format will be inferred from config
54-
return None
55-
51+
5652
weight_args, input_args = _get_unique_quant_args(model)
5753

5854
# no quantization format if no weights are quantized
5955
if len(weight_args) <= 0:
6056
return None
61-
62-
is_24_structure = (
63-
SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
64-
)
65-
is_weight_only = len(input_args) == 0 and len(weight_args) > 0
66-
67-
if (
68-
weight_args[0].num_bits == 4
69-
and weight_args[0].type == QuantizationType.FLOAT.value
70-
):
71-
return CompressionFormat.nvfp4_pack_quantized
72-
73-
if is_weight_only: # w4a16 and w8a16
74-
is_valid_pack = all(
75-
weight_arg.num_bits in [4, 8]
76-
and weight_arg.type == QuantizationType.INT.value
77-
for weight_arg in weight_args
57+
58+
if save_compressed:
59+
is_24_structure = (
60+
SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
7861
)
79-
if not is_valid_pack: # packing only valid for int4 and int 8
80-
return CompressionFormat.naive_quantized
81-
if is_24_structure:
82-
for arg in weight_args:
62+
is_weight_only = len(input_args) == 0 and len(weight_args) > 0
63+
64+
if (
65+
weight_args[0].num_bits == 4
66+
and weight_args[0].type == QuantizationType.FLOAT.value
67+
):
68+
return CompressionFormat.nvfp4_pack_quantized
69+
70+
if is_weight_only: # w4a16 and w8a16
71+
is_valid_pack = all(
72+
weight_arg.num_bits in [4, 8]
73+
and weight_arg.type == QuantizationType.INT.value
74+
for weight_arg in weight_args
75+
)
76+
if not is_valid_pack: # packing only valid for int4 and int 8
77+
return CompressionFormat.naive_quantized
78+
if is_24_structure:
79+
for arg in weight_args:
80+
if (
81+
arg.strategy is not QuantizationStrategy.CHANNEL.value
82+
and arg.strategy is not QuantizationStrategy.GROUP.value
83+
):
84+
# marlin24 kernel only applicable for channel/group quantization
85+
return CompressionFormat.pack_quantized
86+
return CompressionFormat.marlin_24
87+
return CompressionFormat.pack_quantized
88+
else: # w8a8 float and int
89+
if len(weight_args) == 1:
8390
if (
84-
arg.strategy is not QuantizationStrategy.CHANNEL.value
85-
and arg.strategy is not QuantizationStrategy.GROUP.value
91+
weight_args[0].type == QuantizationType.FLOAT.value
92+
and weight_args[0].num_bits == 8
8693
):
87-
# marlin24 kernel only applicable for channel/group quantization
88-
return CompressionFormat.pack_quantized
89-
return CompressionFormat.marlin_24
90-
return CompressionFormat.pack_quantized
91-
else: # w8a8 float and int
92-
if len(weight_args) == 1:
93-
if (
94-
weight_args[0].type == QuantizationType.FLOAT.value
95-
and weight_args[0].num_bits == 8
96-
):
97-
return CompressionFormat.float_quantized
98-
if weight_args[0].type == QuantizationType.INT.value:
99-
return CompressionFormat.int_quantized
100-
101-
return CompressionFormat.naive_quantized
94+
return CompressionFormat.float_quantized
95+
if weight_args[0].type == QuantizationType.INT.value:
96+
return CompressionFormat.int_quantized
97+
98+
return CompressionFormat.naive_quantized
99+
else:
100+
# format will be inferred from config
101+
return None
102102

103103

104104
def _get_unique_quant_args(model):

0 commit comments

Comments
 (0)