@@ -48,57 +48,57 @@ def infer_quantization_format(
48
48
"""
49
49
if quantization_format is not None :
50
50
return quantization_format
51
-
52
- if not save_compressed :
53
- # format will be inferred from config
54
- return None
55
-
51
+
56
52
weight_args , input_args = _get_unique_quant_args (model )
57
53
58
54
# no quantization format if no weights are quantized
59
55
if len (weight_args ) <= 0 :
60
56
return None
61
-
62
- is_24_structure = (
63
- SparsityStructure (sparsity_structure ) == SparsityStructure .TWO_FOUR
64
- )
65
- is_weight_only = len (input_args ) == 0 and len (weight_args ) > 0
66
-
67
- if (
68
- weight_args [0 ].num_bits == 4
69
- and weight_args [0 ].type == QuantizationType .FLOAT .value
70
- ):
71
- return CompressionFormat .nvfp4_pack_quantized
72
-
73
- if is_weight_only : # w4a16 and w8a16
74
- is_valid_pack = all (
75
- weight_arg .num_bits in [4 , 8 ]
76
- and weight_arg .type == QuantizationType .INT .value
77
- for weight_arg in weight_args
57
+
58
+ if save_compressed :
59
+ is_24_structure = (
60
+ SparsityStructure (sparsity_structure ) == SparsityStructure .TWO_FOUR
78
61
)
79
- if not is_valid_pack : # packing only valid for int4 and int 8
80
- return CompressionFormat .naive_quantized
81
- if is_24_structure :
82
- for arg in weight_args :
62
+ is_weight_only = len (input_args ) == 0 and len (weight_args ) > 0
63
+
64
+ if (
65
+ weight_args [0 ].num_bits == 4
66
+ and weight_args [0 ].type == QuantizationType .FLOAT .value
67
+ ):
68
+ return CompressionFormat .nvfp4_pack_quantized
69
+
70
+ if is_weight_only : # w4a16 and w8a16
71
+ is_valid_pack = all (
72
+ weight_arg .num_bits in [4 , 8 ]
73
+ and weight_arg .type == QuantizationType .INT .value
74
+ for weight_arg in weight_args
75
+ )
76
+ if not is_valid_pack : # packing only valid for int4 and int 8
77
+ return CompressionFormat .naive_quantized
78
+ if is_24_structure :
79
+ for arg in weight_args :
80
+ if (
81
+ arg .strategy is not QuantizationStrategy .CHANNEL .value
82
+ and arg .strategy is not QuantizationStrategy .GROUP .value
83
+ ):
84
+ # marlin24 kernel only applicable for channel/group quantization
85
+ return CompressionFormat .pack_quantized
86
+ return CompressionFormat .marlin_24
87
+ return CompressionFormat .pack_quantized
88
+ else : # w8a8 float and int
89
+ if len (weight_args ) == 1 :
83
90
if (
84
- arg . strategy is not QuantizationStrategy . CHANNEL .value
85
- and arg . strategy is not QuantizationStrategy . GROUP . value
91
+ weight_args [ 0 ]. type == QuantizationType . FLOAT .value
92
+ and weight_args [ 0 ]. num_bits == 8
86
93
):
87
- # marlin24 kernel only applicable for channel/group quantization
88
- return CompressionFormat .pack_quantized
89
- return CompressionFormat .marlin_24
90
- return CompressionFormat .pack_quantized
91
- else : # w8a8 float and int
92
- if len (weight_args ) == 1 :
93
- if (
94
- weight_args [0 ].type == QuantizationType .FLOAT .value
95
- and weight_args [0 ].num_bits == 8
96
- ):
97
- return CompressionFormat .float_quantized
98
- if weight_args [0 ].type == QuantizationType .INT .value :
99
- return CompressionFormat .int_quantized
100
-
101
- return CompressionFormat .naive_quantized
94
+ return CompressionFormat .float_quantized
95
+ if weight_args [0 ].type == QuantizationType .INT .value :
96
+ return CompressionFormat .int_quantized
97
+
98
+ return CompressionFormat .naive_quantized
99
+ else :
100
+ # format will be inferred from config
101
+ return None
102
102
103
103
104
104
def _get_unique_quant_args (model ):
0 commit comments