@@ -47,54 +47,54 @@ def infer_quantization_format(
47
47
if quantization_format is not None :
48
48
return quantization_format
49
49
50
+ if not save_compressed :
51
+ # format will be inferred from config
52
+ return None
53
+
50
54
weight_args , input_args = _get_unique_quant_args (model )
51
55
if len (weight_args ) <= 0 :
52
56
return None
53
57
54
- if save_compressed :
55
- is_24_structure = (
56
- SparsityStructure (sparsity_structure ) == SparsityStructure .TWO_FOUR
58
+ is_24_structure = (
59
+ SparsityStructure (sparsity_structure ) == SparsityStructure .TWO_FOUR
60
+ )
61
+ is_weight_only = len (input_args ) == 0 and len (weight_args ) > 0
62
+
63
+ if (
64
+ weight_args [0 ].num_bits == 4
65
+ and weight_args [0 ].type == QuantizationType .FLOAT .value
66
+ ):
67
+ return CompressionFormat .nvfp4_pack_quantized
68
+
69
+ if is_weight_only : # w4a16 and w8a16
70
+ is_valid_pack = all (
71
+ weight_arg .num_bits in [4 , 8 ]
72
+ and weight_arg .type == QuantizationType .INT .value
73
+ for weight_arg in weight_args
57
74
)
58
- is_weight_only = len (input_args ) == 0 and len (weight_args ) > 0
59
-
60
- if (
61
- weight_args [0 ].num_bits == 4
62
- and weight_args [0 ].type == QuantizationType .FLOAT .value
63
- ):
64
- return CompressionFormat .nvfp4_pack_quantized
65
-
66
- if is_weight_only : # w4a16 and w8a16
67
- is_valid_pack = all (
68
- weight_arg .num_bits in [4 , 8 ]
69
- and weight_arg .type == QuantizationType .INT .value
70
- for weight_arg in weight_args
71
- )
72
- if not is_valid_pack : # packing only valid for int4 and int 8
73
- return CompressionFormat .naive_quantized
74
- if is_24_structure :
75
- for arg in weight_args :
76
- if (
77
- arg .strategy is not QuantizationStrategy .CHANNEL .value
78
- and arg .strategy is not QuantizationStrategy .GROUP .value
79
- ):
80
- # marlin24 kernel only applicable for channel/group quantization
81
- return CompressionFormat .pack_quantized
82
- return CompressionFormat .marlin_24
83
- return CompressionFormat .pack_quantized
84
- else : # w8a8 float and int
85
- if len (weight_args ) == 1 :
75
+ if not is_valid_pack : # packing only valid for int4 and int 8
76
+ return CompressionFormat .naive_quantized
77
+ if is_24_structure :
78
+ for arg in weight_args :
86
79
if (
87
- weight_args [ 0 ]. type == QuantizationType . FLOAT .value
88
- and weight_args [ 0 ]. num_bits == 8
80
+ arg . strategy is not QuantizationStrategy . CHANNEL .value
81
+ and arg . strategy is not QuantizationStrategy . GROUP . value
89
82
):
90
- return CompressionFormat .float_quantized
91
- if weight_args [0 ].type == QuantizationType .INT .value :
92
- return CompressionFormat .int_quantized
83
+ # marlin24 kernel only applicable for channel/group quantization
84
+ return CompressionFormat .pack_quantized
85
+ return CompressionFormat .marlin_24
86
+ return CompressionFormat .pack_quantized
87
+ else : # w8a8 float and int
88
+ if len (weight_args ) == 1 :
89
+ if (
90
+ weight_args [0 ].type == QuantizationType .FLOAT .value
91
+ and weight_args [0 ].num_bits == 8
92
+ ):
93
+ return CompressionFormat .float_quantized
94
+ if weight_args [0 ].type == QuantizationType .INT .value :
95
+ return CompressionFormat .int_quantized
93
96
94
- return CompressionFormat .naive_quantized
95
- else :
96
- # format will be inferred from config
97
- return None
97
+ return CompressionFormat .naive_quantized
98
98
99
99
100
100
def _get_unique_quant_args (model ):
0 commit comments