 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationScheme,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.lifecycle.forward import quantize
 from compressed_tensors.utils import (
     get_permutations_24,
@@ -44,19 +48,25 @@ class Marlin24Compressor(BaseCompressor):

     @staticmethod
     def validate_quant_compatability(
-        model_quant_args: Dict[str, QuantizationArgs]
+        names_to_scheme: Dict[str, QuantizationScheme]
     ) -> bool:
         """
         Checks if every quantized module in the model is compatible with Marlin24
         compression. Quantization must be channel or group strategy with group_size
         of 128. Only symmetric quantization is supported

-        :param model_quant_args: dictionary of mapping module names to their
-            quantization configuration
+        :param names_to_scheme: dictionary of mapping module names to their
+            quantization schemes
         :return: True if all modules are compatible with Marlin24 compression, raises
             a ValueError otherwise
         """
-        for name, quant_args in model_quant_args.items():
+        for name, scheme in names_to_scheme.items():
+            quant_args = scheme.weights
+            if quant_args is None:
+                raise ValueError(
+                    "Marlin24 Compressor is only valid for weight quantization schemes"
+                )
+
             strategy = quant_args.strategy
             group_size = quant_args.group_size
             symmetric = quant_args.symmetric
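For reference, a minimal sketch of how the new signature could be exercised, assuming `Marlin24Compressor` is importable from `compressed_tensors.compressors`; the module name, targets, and argument values below are illustrative only, not taken from this PR:

```python
from compressed_tensors.compressors import Marlin24Compressor
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# Hypothetical mapping: each quantized module gets a scheme whose `weights`
# args satisfy the Marlin24 constraints (group strategy, group_size=128, symmetric).
names_to_scheme = {
    "model.layers.0.self_attn.q_proj": QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(
            num_bits=4,
            strategy="group",
            group_size=128,
            symmetric=True,
        ),
    ),
}

# Returns True when every scheme is compatible; raises ValueError otherwise,
# including the new case where a scheme carries no weight quantization args.
Marlin24Compressor.validate_quant_compatability(names_to_scheme)
```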
@@ -114,16 +124,16 @@ def compression_param_names(self) -> Tuple[str]:
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        names_to_scheme: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationScheme],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
         Compresses a quantized state_dict with 2:4 sparsity structure for inference
         with the Marlin24 kernel

         :param model_state: state dict of uncompressed model
-        :param names_to_scheme: quantization args for each quantized weight, needed for
-            quantize function to calculate bit depth
+        :param names_to_scheme: quantization scheme for each quantized weight, needed
+            for quantize function to calculate bit depth
         :return: compressed state dict
         """
         self.validate_quant_compatability(names_to_scheme)
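A rough usage sketch for the updated `compress` signature; the `model` object and the `names_to_scheme` mapping from the previous example are assumptions for illustration, not part of this PR:

```python
# Hypothetical call site: compress a calibrated, uncompressed state dict
# into the Marlin24 packed format using the per-module schemes.
compressor = Marlin24Compressor()
compressed_state_dict = compressor.compress(
    model_state=model.state_dict(),
    names_to_scheme=names_to_scheme,
)
```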
@@ -146,7 +156,7 @@ def compress(
             value = value.to(torch.float16)

             # quantize weight, keeping it as a float16 for now
-            quant_args = names_to_scheme[prefix]
+            quant_args = names_to_scheme[prefix].weights
             value = quantize(
                 x=value, scale=scale, zero_point=zp, args=quant_args
             )
@@ -215,7 +225,7 @@ def pack_weight_24(
     weight: Tensor,
     quantization_args: QuantizationArgs,
     tile: int = 16,
-):
+) -> torch.Tensor:
     size_k = weight.shape[0]
     size_n = weight.shape[1]
     num_bits = quantization_args.num_bits
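As background for the new `-> torch.Tensor` return annotation: the function returns weights packed into 32-bit words. Below is a generic nibble-packing sketch showing why `32 // num_bits` values fit per word; it is an illustration only and omits the tile and permutation reordering the real Marlin24 packing applies:

```python
import torch

def pack_int4_rows(q: torch.Tensor) -> torch.Tensor:
    """Pack unsigned 4-bit integers along dim 0 into int32 words (8 per word).

    Generic illustration only -- not the Marlin24 layout.
    """
    pack_factor = 32 // 4
    assert q.shape[0] % pack_factor == 0, "rows must divide evenly into words"
    q = q.to(torch.int32)
    packed = torch.zeros(q.shape[0] // pack_factor, q.shape[1], dtype=torch.int32)
    for i in range(pack_factor):
        # OR each nibble into its bit position within the 32-bit word
        packed |= q[i::pack_factor] << (4 * i)
    return packed

# Example: a 16x4 matrix of 4-bit values packs into a 2x4 int32 tensor
print(pack_int4_rows(torch.randint(0, 16, (16, 4))).shape)  # torch.Size([2, 4])
```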
@@ -236,7 +246,9 @@ def pack_weight_24(
     return q_packed


-def pack_scales_24(scales, quantization_args, w_shape):
+def pack_scales_24(
+    scales: torch.Tensor, quantization_args: QuantizationArgs, w_shape: torch.Size
+) -> torch.Tensor:
     size_k = w_shape[0]
     size_n = w_shape[1]
     num_bits = quantization_args.num_bits