Skip to content

Commit 1e676f7

Browse files
committed
guard earlier
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent cd28f40 commit 1e676f7

File tree

1 file changed

+40
-40
lines changed

1 file changed

+40
-40
lines changed

src/llmcompressor/transformers/compression/quantization_format.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -47,54 +47,54 @@ def infer_quantization_format(
4747
if quantization_format is not None:
4848
return quantization_format
4949

50+
if not save_compressed:
51+
# format will be inferred from config
52+
return None
53+
5054
weight_args, input_args = _get_unique_quant_args(model)
5155
if len(weight_args) <= 0:
5256
return None
5357

54-
if save_compressed:
55-
is_24_structure = (
56-
SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
58+
is_24_structure = (
59+
SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
60+
)
61+
is_weight_only = len(input_args) == 0 and len(weight_args) > 0
62+
63+
if (
64+
weight_args[0].num_bits == 4
65+
and weight_args[0].type == QuantizationType.FLOAT.value
66+
):
67+
return CompressionFormat.nvfp4_pack_quantized
68+
69+
if is_weight_only: # w4a16 and w8a16
70+
is_valid_pack = all(
71+
weight_arg.num_bits in [4, 8]
72+
and weight_arg.type == QuantizationType.INT.value
73+
for weight_arg in weight_args
5774
)
58-
is_weight_only = len(input_args) == 0 and len(weight_args) > 0
59-
60-
if (
61-
weight_args[0].num_bits == 4
62-
and weight_args[0].type == QuantizationType.FLOAT.value
63-
):
64-
return CompressionFormat.nvfp4_pack_quantized
65-
66-
if is_weight_only: # w4a16 and w8a16
67-
is_valid_pack = all(
68-
weight_arg.num_bits in [4, 8]
69-
and weight_arg.type == QuantizationType.INT.value
70-
for weight_arg in weight_args
71-
)
72-
if not is_valid_pack: # packing only valid for int4 and int 8
73-
return CompressionFormat.naive_quantized
74-
if is_24_structure:
75-
for arg in weight_args:
76-
if (
77-
arg.strategy is not QuantizationStrategy.CHANNEL.value
78-
and arg.strategy is not QuantizationStrategy.GROUP.value
79-
):
80-
# marlin24 kernel only applicable for channel/group quantization
81-
return CompressionFormat.pack_quantized
82-
return CompressionFormat.marlin_24
83-
return CompressionFormat.pack_quantized
84-
else: # w8a8 float and int
85-
if len(weight_args) == 1:
75+
if not is_valid_pack: # packing only valid for int4 and int 8
76+
return CompressionFormat.naive_quantized
77+
if is_24_structure:
78+
for arg in weight_args:
8679
if (
87-
weight_args[0].type == QuantizationType.FLOAT.value
88-
and weight_args[0].num_bits == 8
80+
arg.strategy is not QuantizationStrategy.CHANNEL.value
81+
and arg.strategy is not QuantizationStrategy.GROUP.value
8982
):
90-
return CompressionFormat.float_quantized
91-
if weight_args[0].type == QuantizationType.INT.value:
92-
return CompressionFormat.int_quantized
83+
# marlin24 kernel only applicable for channel/group quantization
84+
return CompressionFormat.pack_quantized
85+
return CompressionFormat.marlin_24
86+
return CompressionFormat.pack_quantized
87+
else: # w8a8 float and int
88+
if len(weight_args) == 1:
89+
if (
90+
weight_args[0].type == QuantizationType.FLOAT.value
91+
and weight_args[0].num_bits == 8
92+
):
93+
return CompressionFormat.float_quantized
94+
if weight_args[0].type == QuantizationType.INT.value:
95+
return CompressionFormat.int_quantized
9396

94-
return CompressionFormat.naive_quantized
95-
else:
96-
# format will be inferred from config
97-
return None
97+
return CompressionFormat.naive_quantized
9898

9999

100100
def _get_unique_quant_args(model):

0 commit comments

Comments
 (0)