Skip to content

Commit 22c09f3

Browse files
authored
Replace COMPRESSION_PARAM_NAMES with Abstract Property (#249)
* Add compression_param_names abstract property
* Review comments:
  - Get rid of COMPRESSION_PARAM_NAMES
  - Enforce implementation of compression_param_names
  - Make compression_param_names immutable by using tuple instead of list
1 parent 62c2bae commit 22c09f3

File tree

11 files changed

+80
-32
lines changed

11 files changed

+80
-32
lines changed

src/compressed_tensors/compressors/base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ def compression_param_info(
7777
"""
7878
raise NotImplementedError()
7979

80+
@property
81+
@abstractmethod
82+
def compression_param_names(self) -> Tuple[str]:
83+
"""
84+
Returns a tuple of compression parameter names introduced by
85+
the compressor during compression
86+
"""
87+
raise NotImplementedError()
88+
8089
@abstractmethod
8190
def compress(
8291
self,

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def decompress(
144144

145145
def _decompress_from_path(self, path_to_model, names_to_scheme, device):
146146
weight_mappings = get_nested_weight_mappings(
147-
path_to_model, self.COMPRESSION_PARAM_NAMES
147+
path_to_model, self.compression_param_names
148148
)
149149
for weight_name in weight_mappings.keys():
150150
weight_data = {}
@@ -161,7 +161,7 @@ def _decompress_from_path(self, path_to_model, names_to_scheme, device):
161161

162162
def _decompress_from_state_dict(self, state_dict, names_to_scheme):
163163
weight_mappings = get_nested_mappings_from_state_dict(
164-
state_dict, self.COMPRESSION_PARAM_NAMES
164+
state_dict, self.compression_param_names
165165
)
166166
for weight_name in weight_mappings.keys():
167167
weight_data = {}

src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,18 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
4141
type to the type specified by the layer's QuantizationArgs.
4242
"""
4343

44-
COMPRESSION_PARAM_NAMES = [
45-
"weight",
46-
"weight_scale",
47-
"weight_zero_point",
48-
"weight_g_idx",
49-
]
44+
@property
45+
def compression_param_names(self) -> Tuple[str]:
46+
"""
47+
Returns a tuple of compression parameter names introduced by
48+
the compressor during compression
49+
"""
50+
return (
51+
"weight",
52+
"weight_scale",
53+
"weight_zero_point",
54+
"weight_g_idx",
55+
)
5056

5157
def compression_param_info(
5258
self,

src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,19 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
3636
Compresses a quantized model by packing every eight 4-bit weights into an int32
3737
"""
3838

39-
COMPRESSION_PARAM_NAMES = [
40-
"weight_packed",
41-
"weight_scale",
42-
"weight_zero_point",
43-
"weight_g_idx",
44-
"weight_shape",
45-
]
39+
@property
40+
def compression_param_names(self) -> Tuple[str]:
41+
"""
42+
Returns a tuple of compression parameter names introduced by
43+
the compressor during compression
44+
"""
45+
return (
46+
"weight_packed",
47+
"weight_scale",
48+
"weight_zero_point",
49+
"weight_g_idx",
50+
"weight_shape",
51+
)
4652

4753
def compression_param_info(
4854
self,

src/compressed_tensors/compressors/sparse_compressors/base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
class BaseSparseCompressor(BaseCompressor):
3131
"""
3232
Base class representing a sparse compression algorithm. Each child class should
33-
implement compression_param_info, compress_weight and decompress_weight; child
34-
classes should also define COMPRESSION_PARAM_NAMES.
33+
implement compression_param_names, compress_weight and decompress_weight.
3534
3635
Compressors support compressing/decompressing a full module state dict or a single
3736
quantized PyTorch leaf module.
@@ -113,7 +112,7 @@ def decompress(
113112
"""
114113
weight_mappings, ignored_params = get_nested_weight_mappings(
115114
path_to_model_or_tensors,
116-
self.COMPRESSION_PARAM_NAMES,
115+
self.compression_param_names,
117116
return_unmatched_params=True,
118117
)
119118
for weight_name in weight_mappings.keys():

src/compressed_tensors/compressors/sparse_compressors/dense.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ class DenseCompressor(BaseCompressor):
2525
Identity compressor for dense models, returns the original state_dict
2626
"""
2727

28+
@property
29+
def compression_param_names(self) -> Tuple[str]:
30+
"""
31+
Returns a tuple of compression parameter names introduced by
32+
the compressor during compression
33+
"""
34+
return ()
35+
2836
def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
2937
return model_state
3038

src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,17 @@ class Sparse24BitMaskCompressor(BaseSparseCompressor):
4040
values tensor, with their locations stored in a 2d bitmask
4141
"""
4242

43-
COMPRESSION_PARAM_NAMES = [
44-
"shape",
45-
"compressed",
46-
"bitmask",
47-
]
43+
@property
44+
def compression_param_names(self) -> Tuple[str]:
45+
"""
46+
Returns a tuple of compression parameter names introduced by
47+
the compressor during compression
48+
"""
49+
return (
50+
"shape",
51+
"compressed",
52+
"bitmask",
53+
)
4854

4955
def compress_weight(self, name, value):
5056
bitmask_tensor = Sparse24BitMaskTensor.from_dense(

src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,13 @@ class BitmaskCompressor(BaseSparseCompressor):
3838
values tensor, with their locations stored in a 2d bitmask
3939
"""
4040

41-
COMPRESSION_PARAM_NAMES = ["shape", "compressed", "bitmask", "row_offsets"]
41+
@property
42+
def compression_param_names(self) -> Tuple[str]:
43+
"""
44+
Returns a tuple of compression parameter names introduced by
45+
the compressor during compression
46+
"""
47+
return ("shape", "compressed", "bitmask", "row_offsets")
4248

4349
def compress_weight(self, name, value):
4450
bitmask_tensor = BitmaskTensor.from_dense(value)

src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@ class Marlin24Compressor(BaseCompressor):
4242
Marlin24 kernel. Decompression is not implemented for this compressor.
4343
"""
4444

45-
COMPRESSION_PARAM_NAMES = ["weight_packed", "scale_packed", "meta"]
46-
4745
@staticmethod
4846
def validate_quant_compatability(
4947
model_quant_args: Dict[str, QuantizationArgs]
@@ -105,6 +103,14 @@ def validate_sparsity_structure(name: str, weight: Tensor) -> bool:
105103

106104
return True
107105

106+
@property
107+
def compression_param_names(self) -> Tuple[str]:
108+
"""
109+
Returns a tuple of compression parameter names introduced by
110+
the compressor during compression
111+
"""
112+
return ("weight_packed", "scale_packed", "meta")
113+
108114
def compress(
109115
self,
110116
model_state: Dict[str, Tensor],

src/compressed_tensors/utils/safetensors_load.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import os
1717
import re
1818
import struct
19-
from typing import Dict, List, Optional, Tuple, Union
19+
from typing import Dict, Iterable, Optional, Tuple, Union
2020

2121
from safetensors import safe_open
2222
from torch import Tensor
@@ -180,7 +180,9 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]:
180180

181181

182182
def get_nested_weight_mappings(
183-
model_path: str, params_to_nest: List[str], return_unmatched_params: bool = False
183+
model_path: str,
184+
params_to_nest: Iterable[str],
185+
return_unmatched_params: bool = False,
184186
) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]:
185187
"""
186188
Takes a path to a state dict saved in safetensors format and returns a nested
@@ -211,7 +213,7 @@ def get_nested_weight_mappings(
211213
212214
:param model_path: Path to the safetensors state dict, must contain either a
213215
single safetensors file or multiple files with an index.
214-
:param params_to_nest: List of parameter names to nest.
216+
:param params_to_nest: Iterable of parameter names to nest.
215217
:param return_unmatched_params: If True, return a second dictionary containing
216218
the remaining parameters that were not matched to the params_to_nest.
217219
:return:
@@ -247,7 +249,7 @@ def get_nested_weight_mappings(
247249

248250

249251
def get_nested_mappings_from_state_dict(
250-
state_dict, params_to_nest
252+
state_dict, params_to_nest: Iterable[str]
251253
) -> NestedWeightMappingType:
252254
"""
253255
Takes a state dict and returns a nested mapping from uncompressed
@@ -262,7 +264,7 @@ def get_nested_mappings_from_state_dict(
262264
}
263265
264266
:param state_dict: state dict of the model
265-
:param params_to_nest: List of parameter names to nest.
267+
:param params_to_nest: Iterable of parameter names to nest.
266268
:return: Nested mapping of parameterized layer names to the value of
267269
each layer's compression parameters.
268270
"""

0 commit comments

Comments
 (0)