Commit c84b5b4

Enable module state_dict compression, simplify compression logic (#307)
* use map_module_to_scheme, _should_save_zp
* remove unused import
* remove unused imports
* rename to _skip_zp
* type hint nit
* safetensors function signature is weird
* make apply_quantization_config return module to scheme, as stated

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 6148fef commit c84b5b4
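
At a high level, this change has the compressors consume full QuantizationScheme objects keyed by module name instead of bare QuantizationArgs (or weight/input-activation tuples). A rough sketch of the new mapping shape follows; the module name and the hand-built scheme are illustrative assumptions, not taken from this commit:

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# Illustrative only: a 4-bit symmetric weight scheme for Linear layers
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, symmetric=True),
)

# Before #307: Dict[str, QuantizationArgs] or Dict[str, (weight_args, input_args)]
# After  #307: Dict[str, QuantizationScheme], keyed by (FSDP-normalized) module name
names_to_scheme = {"model.layers.0.self_attn.q_proj": scheme}
weight_args = names_to_scheme["model.layers.0.self_attn.q_proj"].weights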

File tree: 9 files changed (+173, -169 lines)

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 18 additions & 36 deletions
@@ -19,7 +19,7 @@
 import re
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union
 
 import compressed_tensors
 import torch
@@ -36,12 +36,12 @@
 from compressed_tensors.quantization import (
     DEFAULT_QUANTIZATION_METHOD,
     QuantizationConfig,
+    QuantizationScheme,
     QuantizationStatus,
     apply_quantization_config,
     load_pretrained_quantization_parameters,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
-from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -64,7 +64,7 @@
 from transformers.file_utils import CONFIG_NAME
 
 
-__all__ = ["ModelCompressor", "map_modules_to_quant_args"]
+__all__ = ["ModelCompressor", "map_module_to_scheme"]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -372,20 +372,17 @@ def compress(
         :param state_dict: optional uncompressed state_dict to insert into model
         :return: compressed state dict
         """
+
         if state_dict is None:
             state_dict = model.state_dict()
 
-        compressed_state_dict = state_dict
-
-        quantized_modules_to_args: Dict[
-            str, QuantizationArgs
-        ] = map_modules_to_quant_args(model)
-
         if self.quantization_compressor is not None:
-            compressed_state_dict = self.quantization_compressor.compress(
-                state_dict, names_to_scheme=quantized_modules_to_args
+            module_to_scheme = map_module_to_scheme(model)
+            state_dict = self.quantization_compressor.compress(
+                state_dict, names_to_scheme=module_to_scheme
             )
 
+        # TODO: consider sparse compression to also be compression
         if self.quantization_config.format != CompressionFormat.dense.value:
             self.quantization_config.quantization_status = (
                 QuantizationStatus.COMPRESSED
@@ -397,8 +394,8 @@ def compress(
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
-            compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict,
+            state_dict = self.sparsity_compressor.compress(
+                state_dict,
                 compression_targets=sparse_compression_targets,
             )
 
@@ -407,7 +404,7 @@ def compress(
         # https://github.com/huggingface/transformers/pull/30488
         transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
 
-        return compressed_state_dict
+        return state_dict
 
     def decompress(self, model_path: str, model: Module):
         """
@@ -605,30 +602,15 @@ def _replace_weights(self, dense_weight_generator, model: Module):
             update_parameter_data(module, param_data, param_name)
 
 
-def map_modules_to_quant_args(
-    model: Module,
-) -> Dict[str, Union[QuantizationArgs, Tuple[QuantizationArgs, QuantizationArgs]]]:
+def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
-    Given a pytorch model, map out the submodule name (usually linear layers)
-    to the weight QuantizationArgs. If running input activation quantization, will also
-    map to the input QuantizationArgs in a tuple.
-
-    :param model: pytorch model
+    Returns a dictionary which maps quantized module names to their quantization schemes
     """
-    quantized_modules_to_args = {}
-    for name, submodule in iter_named_leaf_modules(model):
-        if is_module_quantized(submodule):
-            if submodule.quantization_scheme.weights is not None:
-                name = fix_fsdp_module_name(name)
-                quantized_modules_to_args[name] = submodule.quantization_scheme.weights
-            if submodule.quantization_scheme.input_activations is not None:
-                weight_args = quantized_modules_to_args.get(name)
-                quantized_modules_to_args[name] = (
-                    weight_args,
-                    submodule.quantization_scheme.input_activations,
-                )
-
-    return quantized_modules_to_args
+    return {
+        fix_fsdp_module_name(name): module.quantization_scheme
+        for name, module in iter_named_leaf_modules(model)
+        if is_module_quantized(module)
+    }
 
 
 # HACK: Override the dtype_byte_size function in transformers to support float8 types
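
For orientation, here is a minimal sketch (not part of the commit) of how the new map_module_to_scheme helper behaves on a quantized toy model. The model, config, and module names below are illustrative assumptions; only map_module_to_scheme itself comes from this diff:

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)
from compressed_tensors.compressors.model_compressors.model_compressor import (
    map_module_to_scheme,
)

# Toy model and config (assumptions for illustration only)
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 8))
config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=8, symmetric=True),
        )
    }
)
apply_quantization_config(model, config)

# One entry per quantized leaf module, each value the module's full QuantizationScheme
module_to_scheme = map_module_to_scheme(model)
for name, scheme in module_to_scheme.items():
    print(name, scheme.weights.num_bits)  # e.g. "0 8" and "1 8" for the two Linear layers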

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 84 additions & 77 deletions
@@ -14,15 +14,16 @@
 
 import logging
 from pathlib import Path
-from typing import Any, Dict, Generator, Optional, Tuple, Union
+from typing import Any, Dict, Generator, Tuple, Union
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
 from compressed_tensors.utils import (
     get_nested_mappings_from_state_dict,
     get_nested_weight_mappings,
     merge_names,
+    remove_suffix,
 )
 from safetensors import safe_open
 from torch import Tensor
@@ -69,7 +70,7 @@ class BaseQuantizationCompressor(BaseCompressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        names_to_scheme: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationScheme],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -81,87 +82,87 @@ def compress(
         :return: compressed state dict
         """
         compressed_dict = {}
-        weight_suffix = ".weight"
-        input_zp_suffix = ".input_zero_point"
-        weight_zp_suffix = ".weight_zero_point"
-        _LOGGER.debug(
-            f"Compressing model with {len(model_state)} parameterized layers..."
-        )
+        save_device = "cpu"
+
+        uncompressed_names = list(model_state.keys())
+        for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
+            value = model_state[name]
+
+            # compress weights
+            if name.endswith("weight"):
+                prefix = remove_suffix(name, "weight")
+
+                # gather qparams
+                scale = model_state.get(prefix + "weight_scale", None)
+                g_idx = model_state.get(prefix + "weight_g_idx", None)
+                zp = model_state.get(prefix + "weight_zero_point", None)
+
+                # is scale does not exist, then weight cannot be compressed
+                if scale is None:
+                    compressed_dict[name] = value.to(save_device)
+                    continue
+
+                # compress values on cpu (memory movement too expensive)
+                module_path = prefix[:-1] if prefix.endswith(".") else prefix
+                quant_args = names_to_scheme[module_path].weights
+                compressed_values = self.compress_weight(
+                    weight=value,
+                    scale=scale,
+                    zero_point=zp,
+                    g_idx=g_idx,
+                    quantization_args=quant_args,
+                    device="cpu",
+                )
+
+                # update state dict
+                for key, value in compressed_values.items():
+                    compressed_dict[prefix + key] = value.to(save_device)
 
-        for name, value in tqdm(model_state.items(), desc="Quantized Compression"):
-            # check if the parameter we're compressing is the weight zp
-            # or the input zp
-            is_weight_zp = name.endswith(weight_zp_suffix)
-            is_input_zp = name.endswith(input_zp_suffix)
-
-            # if we're saving the weight zp, fetch weight quant args
-            if is_weight_zp:
-                quant_args_zp = names_to_scheme.get(name[: -(len(weight_zp_suffix))])
-                if isinstance(quant_args_zp, tuple):
-                    # If tuple, first value is weight args, second is input args
-                    quant_args_zp = quant_args_zp[0]
-
-            # if we're saving the input zp, fetch input quant args
-            if is_input_zp:
-                input_args_zp = names_to_scheme.get(name[: -(len(input_zp_suffix))])
-                if isinstance(input_args_zp, tuple):
-                    # If tuple, first value is weight args, second is input args
-                    input_args_zp = input_args_zp[-1]
-
-            if name.endswith(weight_suffix):
-                prefix = name[: -(len(weight_suffix))]
-                scale = model_state.get(merge_names(prefix, "weight_scale"), None)
-                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
-                g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
-                if scale is not None:
-                    # weight is quantized, compress it
-                    if isinstance(names_to_scheme[prefix], tuple):
-                        quant_args = names_to_scheme[prefix][0]
-                    else:
-                        quant_args = names_to_scheme[prefix]
-
-                    compressed_data = self.compress_weight(
-                        weight=value,
-                        scale=scale,
-                        zero_point=zp,
-                        g_idx=g_idx,
-                        quantization_args=quant_args,
-                        device="cpu",
-                    )
-                    for key, value in compressed_data.items():
-                        compressed_dict[merge_names(prefix, key)] = value
-                else:
-                    compressed_dict[name] = value.to("cpu")
-            # only save zp if asym and not packed zp
-            elif is_weight_zp and (
-                quant_args_zp.symmetric or self._check_if_zp_pack_quantized(quant_args)
-            ):
-                continue
-            # only save if asym
-            elif is_input_zp and input_args_zp.symmetric:
-                continue
-            elif name.endswith("g_idx") and torch.any(value <= -1):
-                continue
             else:
-                compressed_dict[name] = value.to("cpu")
+                # omit saving zero points for symmetric or packed quantization
+                if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
+                    continue
+
+                # omit saving for g_idx if uninitialized
+                # TODO: does this case actually occur?
+                elif name.endswith("g_idx") and torch.any(value <= -1):
+                    continue
+
+                compressed_dict[name] = value.to(save_device)
 
         return compressed_dict
 
-    def _check_if_zp_pack_quantized(self, quant_args):
+    def _skip_zp(
+        self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
+    ) -> bool:
         from compressed_tensors.compressors import PackedQuantizationCompressor
 
-        if isinstance(self, PackedQuantizationCompressor):
-            if not quant_args.symmetric and quant_args.strategy in [
-                QuantizationStrategy.GROUP.value,
-                QuantizationStrategy.CHANNEL.value,
-            ]:
-                return True
-        return False
+        module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
+        scheme = names_to_scheme[module_name]
+
+        if zp_name == "weight_zero_point":
+            args = scheme.weights
+        if zp_name == "input_zero_point":
+            args = scheme.input_activations
+        if zp_name == "output_zero_point":
+            args = scheme.output_activations
+
+        symmetric = args.symmetric
+        packable_strategies = [
+            QuantizationStrategy.GROUP.value,
+            QuantizationStrategy.CHANNEL.value,
+        ]
+        packed = (
+            isinstance(self, PackedQuantizationCompressor)
+            and args.strategy in packable_strategies
+        )
+
+        return symmetric or packed
 
     def decompress(
         self,
         path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
-        names_to_scheme: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationScheme],
         device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
@@ -170,8 +171,9 @@ def decompress(
             dense state dict
         :param path_to_model_or_tensors: path to compressed safetensors model (directory
             with one or more safetensors files) or compressed tensors file
-        :param names_to_scheme: quantization args for each quantized weight
-        :param device: optional device to load intermediate weights into
+        :param names_to_scheme: quantization scheme for each quantized weight
+        :param device: optional device to load intermediate weights into (must be `str`,
+            not `torch.device`)
         :return: compressed state dict
         """
         if isinstance(path_to_model_or_tensors, (str, Path)):
@@ -184,7 +186,12 @@ def decompress(
                 path_to_model_or_tensors, names_to_scheme
             )
 
-    def _decompress_from_path(self, path_to_model, names_to_scheme, device):
+    def _decompress_from_path(
+        self,
+        path_to_model: Union[str, Path, Dict[str, Any]],
+        names_to_scheme: Dict[str, QuantizationScheme],
+        device: str,
+    ):
         weight_mappings = get_nested_weight_mappings(
             path_to_model, self.compression_param_names
         )
@@ -195,7 +202,7 @@ def _decompress_from_path(self, path_to_model, names_to_scheme, device):
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[weight_name]
+                quant_args = names_to_scheme[weight_name].weights
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
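
For reference, the zero-point filtering introduced above reduces to: drop the zero point when quantization is symmetric, or when the compressor packs zero points into the weights (packed compressor with a group or channel strategy). A standalone sketch of that predicate follows; the function name and example values are illustrative, not from the commit:

from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy

def should_skip_zero_point(args: QuantizationArgs, is_packed_compressor: bool) -> bool:
    # Restates the _skip_zp decision: zero points are omitted when they carry no
    # information (symmetric) or when the packed format folds them into the weights.
    packable = args.strategy in (
        QuantizationStrategy.GROUP.value,
        QuantizationStrategy.CHANNEL.value,
    )
    return args.symmetric or (is_packed_compressor and packable)

# e.g. symmetric per-channel weights never save a zero point
args = QuantizationArgs(num_bits=4, symmetric=True, strategy="channel")
print(should_skip_zero_point(args, is_packed_compressor=False))  # True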
