
Commit 9524c7f

update marlin test, marlin uses scheme
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 6017f05 commit 9524c7f

File tree

2 files changed: +27 additions, -15 deletions


src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py

Lines changed: 23 additions & 11 deletions
@@ -19,7 +19,11 @@
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationScheme,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.lifecycle.forward import quantize
 from compressed_tensors.utils import (
     get_permutations_24,
@@ -44,19 +48,25 @@ class Marlin24Compressor(BaseCompressor):
 
     @staticmethod
     def validate_quant_compatability(
-        model_quant_args: Dict[str, QuantizationArgs]
+        names_to_scheme: Dict[str, QuantizationScheme]
     ) -> bool:
         """
         Checks if every quantized module in the model is compatible with Marlin24
         compression. Quantization must be channel or group strategy with group_size
         of 128. Only symmetric quantization is supported
 
-        :param model_quant_args: dictionary of mapping module names to their
-            quantization configuration
+        :param names_to_scheme: dictionary of mapping module names to their
+            quantization schemes
         :return: True if all modules are compatible with Marlin24 compression, raises
             a ValueError otherwise
         """
-        for name, quant_args in model_quant_args.items():
+        for name, scheme in names_to_scheme.items():
+            quant_args = scheme.weights
+            if quant_args is None:
+                raise ValueError(
+                    "Marlin24 Compressor is only valid for weight quantization schemes"
+                )
+
             strategy = quant_args.strategy
             group_size = quant_args.group_size
             symmetric = quant_args.symmetric
@@ -114,16 +124,16 @@ def compression_param_names(self) -> Tuple[str]:
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        names_to_scheme: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationScheme],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
         Compresses a quantized state_dict with 2:4 sparsity structure for inference
         with the Marlin24 kernel
 
         :param model_state: state dict of uncompressed model
-        :param names_to_scheme: quantization args for each quantized weight, needed for
-            quantize function to calculate bit depth
+        :param names_to_scheme: quantization scheme for each quantized weight, needed
+            for quantize function to calculate bit depth
         :return: compressed state dict
         """
         self.validate_quant_compatability(names_to_scheme)
@@ -146,7 +156,7 @@ def compress(
             value = value.to(torch.float16)
 
             # quantize weight, keeping it as a float16 for now
-            quant_args = names_to_scheme[prefix]
+            quant_args = names_to_scheme[prefix].weights
             value = quantize(
                 x=value, scale=scale, zero_point=zp, args=quant_args
             )
@@ -215,7 +225,7 @@ def pack_weight_24(
     weight: Tensor,
     quantization_args: QuantizationArgs,
     tile: int = 16,
-):
+) -> torch.Tensor:
     size_k = weight.shape[0]
     size_n = weight.shape[1]
     num_bits = quantization_args.num_bits
@@ -236,7 +246,9 @@ def pack_weight_24(
     return q_packed
 
 
-def pack_scales_24(scales, quantization_args, w_shape):
+def pack_scales_24(
+    scales: torch.Tensor, quantization_args: QuantizationArgs, w_shape: torch.Size
+) -> torch.Tensor:
     size_k = w_shape[0]
     size_n = w_shape[1]
     num_bits = quantization_args.num_bits
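
For reference, a minimal sketch of how the updated signature is exercised; the module name and scheme values below are illustrative examples, not taken from this commit:

from compressed_tensors.compressors import Marlin24Compressor
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    QuantizationStrategy,
)

# Assumed example: a weight-only scheme that satisfies the Marlin24 checks
# (symmetric, group strategy, group_size=128)
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(
        num_bits=4,
        symmetric=True,
        strategy=QuantizationStrategy.GROUP,
        group_size=128,
    ),
)

# validate_quant_compatability now receives schemes and reads scheme.weights;
# a scheme with weights=None raises a ValueError
Marlin24Compressor.validate_quant_compatability({"model.layers.0.q_proj": scheme})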

tests/test_compressors/sparse_quantized_compressors/test_marlin_24.py

Lines changed: 4 additions & 4 deletions
@@ -19,7 +19,7 @@
 from compressed_tensors.compressors import (
     BaseCompressor,
     Marlin24Compressor,
-    map_modules_to_quant_args,
+    map_module_to_scheme,
 )
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization import (
@@ -92,9 +92,9 @@ def test_marlin24_format(
     assert f"{NOT_QUANT_NAME}.weight_scale" not in state_dict
     assert f"{QUANT_NAME}.weight_scale" in state_dict
 
-    model_to_quant_args = map_modules_to_quant_args(model)
+    module_to_scheme = map_module_to_scheme(model)
     compressor = Marlin24Compressor()
-    compressor.validate_quant_compatability(model_to_quant_args)
+    compressor.validate_quant_compatability(module_to_scheme)
     compressor.validate_sparsity_structure(
         QUANT_NAME, state_dict[f"{QUANT_NAME}.weight"]
     )
@@ -104,7 +104,7 @@ def test_marlin24_format(
     )
 
     compressor = Marlin24Compressor()
-    compressed_state_dict = compressor.compress(state_dict, model_to_quant_args)
+    compressed_state_dict = compressor.compress(state_dict, module_to_scheme)
 
     assert len(compressed_state_dict) == 4
     assert torch.equal(
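
Putting the pieces together, a rough sketch of the caller-side flow after this change; the `model` here is assumed to already be quantized with a Marlin24-compatible config, and map_module_to_scheme is the helper the updated test imports in place of map_modules_to_quant_args:

from compressed_tensors.compressors import Marlin24Compressor, map_module_to_scheme

# Dict[str, QuantizationScheme], keyed by module name (assumes a pre-quantized model)
module_to_scheme = map_module_to_scheme(model)

compressor = Marlin24Compressor()
compressor.validate_quant_compatability(module_to_scheme)

# compress() now looks up the weight args via names_to_scheme[prefix].weights
compressed_state_dict = compressor.compress(model.state_dict(), module_to_scheme)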
