Commit 1068c84

Revert "Enable module state_dict compression, simplify compression logic (#302)" (#306)
This reverts commit 4438d08.
1 parent 7477534 · commit 1068c84

File tree: 8 files changed, +165 -169 lines

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 36 additions & 18 deletions
@@ -19,7 +19,7 @@
 import re
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
 
 import compressed_tensors
 import torch
@@ -36,12 +36,12 @@
 from compressed_tensors.quantization import (
     DEFAULT_QUANTIZATION_METHOD,
     QuantizationConfig,
-    QuantizationScheme,
     QuantizationStatus,
     apply_quantization_config,
     load_pretrained_quantization_parameters,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -64,7 +64,7 @@
 from transformers.file_utils import CONFIG_NAME
 
 
-__all__ = ["ModelCompressor", "map_module_to_scheme"]
+__all__ = ["ModelCompressor", "map_modules_to_quant_args"]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -372,17 +372,20 @@ def compress(
         :param state_dict: optional uncompressed state_dict to insert into model
         :return: compressed state dict
         """
-
         if state_dict is None:
             state_dict = model.state_dict()
 
+        compressed_state_dict = state_dict
+
+        quantized_modules_to_args: Dict[
+            str, QuantizationArgs
+        ] = map_modules_to_quant_args(model)
+
         if self.quantization_compressor is not None:
-            module_to_scheme = map_module_to_scheme(model)
-            state_dict = self.quantization_compressor.compress(
-                state_dict, names_to_scheme=module_to_scheme
+            compressed_state_dict = self.quantization_compressor.compress(
+                state_dict, names_to_scheme=quantized_modules_to_args
             )
 
-            # TODO: consider sparse compression to also be compression
             if self.quantization_config.format != CompressionFormat.dense.value:
                 self.quantization_config.quantization_status = (
                     QuantizationStatus.COMPRESSED
@@ -394,8 +397,8 @@ def compress(
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
-            state_dict = self.sparsity_compressor.compress(
-                state_dict,
+            compressed_state_dict = self.sparsity_compressor.compress(
+                compressed_state_dict,
                 compression_targets=sparse_compression_targets,
             )
 
@@ -404,7 +407,7 @@ def compress(
         # https://github.com/huggingface/transformers/pull/30488
         transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
 
-        return state_dict
+        return compressed_state_dict
 
     def decompress(self, model_path: str, model: Module):
         """
@@ -602,15 +605,30 @@ def _replace_weights(self, dense_weight_generator, model: Module):
                 update_parameter_data(module, param_data, param_name)
 
 
-def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
+def map_modules_to_quant_args(
+    model: Module,
+) -> Dict[str, Union[QuantizationArgs, Tuple[QuantizationArgs, QuantizationArgs]]]:
     """
-    Returns a dictionary which maps quantized module names to their quantization schemes
+    Given a pytorch model, map out the submodule name (usually linear layers)
+    to the weight QuantizationArgs. If running input activation quantization, will also
+    map to the input QuantizationArgs in a tuple.
+
+    :param model: pytorch model
     """
-    return {
-        fix_fsdp_module_name(name): module.quantization_scheme
-        for name, module in iter_named_leaf_modules(model)
-        if is_module_quantized(module)
-    }
+    quantized_modules_to_args = {}
+    for name, submodule in iter_named_leaf_modules(model):
+        if is_module_quantized(submodule):
+            if submodule.quantization_scheme.weights is not None:
+                name = fix_fsdp_module_name(name)
+                quantized_modules_to_args[name] = submodule.quantization_scheme.weights
+            if submodule.quantization_scheme.input_activations is not None:
+                weight_args = quantized_modules_to_args.get(name)
+                quantized_modules_to_args[name] = (
+                    weight_args,
+                    submodule.quantization_scheme.input_activations,
+                )
+
+    return quantized_modules_to_args
 
 
 # HACK: Override the dtype_byte_size function in transformers to support float8 types
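For context on the restored helper: map_modules_to_quant_args mixes two value shapes in one dict, either a bare weight QuantizationArgs or a (weight_args, input_activation_args) tuple when input activations are also quantized, so callers have to branch on the type. A minimal sketch of how a consumer might normalize that mapping (split_weight_and_input_args is a hypothetical name, not part of this commit):

# Hypothetical helper, not in this commit: normalize the dict returned by
# map_modules_to_quant_args into separate weight / input-activation mappings.
def split_weight_and_input_args(quantized_modules_to_args):
    weight_args, input_args = {}, {}
    for name, args in quantized_modules_to_args.items():
        if isinstance(args, tuple):
            # tuple order per the docstring: (weight args, input activation args)
            weight_args[name], input_args[name] = args
        else:
            weight_args[name] = args
    return weight_args, input_args

This mixed shape is also why the names_to_scheme lookups in the compressor below check isinstance(..., tuple) before pulling out the weight args.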

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 77 additions & 84 deletions
@@ -14,16 +14,15 @@
 
 import logging
 from pathlib import Path
-from typing import Any, Dict, Generator, Tuple, Union
+from typing import Any, Dict, Generator, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
+from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
 from compressed_tensors.utils import (
     get_nested_mappings_from_state_dict,
     get_nested_weight_mappings,
     merge_names,
-    remove_suffix,
 )
 from safetensors import safe_open
 from torch import Tensor
@@ -70,7 +69,7 @@ class BaseQuantizationCompressor(BaseCompressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        names_to_scheme: Dict[str, QuantizationScheme],
+        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -82,87 +81,87 @@ def compress(
         :return: compressed state dict
         """
         compressed_dict = {}
-        save_device = "cpu"
-
-        uncompressed_names = list(model_state.keys())
-        for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
-            value = model_state[name]
-
-            # compress weights
-            if name.endswith("weight"):
-                prefix = remove_suffix(name, "weight")
-
-                # gather qparams
-                scale = model_state.get(prefix + "weight_scale", None)
-                g_idx = model_state.get(prefix + "weight_g_idx", None)
-                zp = model_state.get(prefix + "weight_zero_point", None)
-
-                # is scale does not exist, then weight cannot be compressed
-                if scale is None:
-                    compressed_dict[name] = value.to(save_device)
-                    continue
-
-                # compress values on cpu (memory movement too expensive)
-                module_path = prefix[:-1] if prefix.endswith(".") else prefix
-                quant_args = names_to_scheme[module_path].weights
-                compressed_values = self.compress_weight(
-                    weight=value,
-                    scale=scale,
-                    zero_point=zp,
-                    g_idx=g_idx,
-                    quantization_args=quant_args,
-                    device="cpu",
-                )
-
-                # update state dict
-                for key, value in compressed_values.items():
-                    compressed_dict[prefix + key] = value.to(save_device)
+        weight_suffix = ".weight"
+        input_zp_suffix = ".input_zero_point"
+        weight_zp_suffix = ".weight_zero_point"
+        _LOGGER.debug(
+            f"Compressing model with {len(model_state)} parameterized layers..."
+        )
 
+        for name, value in tqdm(model_state.items(), desc="Quantized Compression"):
+            # check if the parameter we're compressing is the weight zp
+            # or the input zp
+            is_weight_zp = name.endswith(weight_zp_suffix)
+            is_input_zp = name.endswith(input_zp_suffix)
+
+            # if we're saving the weight zp, fetch weight quant args
+            if is_weight_zp:
+                quant_args_zp = names_to_scheme.get(name[: -(len(weight_zp_suffix))])
+                if isinstance(quant_args_zp, tuple):
+                    # If tuple, first value is weight args, second is input args
+                    quant_args_zp = quant_args_zp[0]
+
+            # if we're saving the input zp, fetch input quant args
+            if is_input_zp:
+                input_args_zp = names_to_scheme.get(name[: -(len(input_zp_suffix))])
+                if isinstance(input_args_zp, tuple):
+                    # If tuple, first value is weight args, second is input args
+                    input_args_zp = input_args_zp[-1]
+
+            if name.endswith(weight_suffix):
+                prefix = name[: -(len(weight_suffix))]
+                scale = model_state.get(merge_names(prefix, "weight_scale"), None)
+                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+                g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
+                if scale is not None:
+                    # weight is quantized, compress it
+                    if isinstance(names_to_scheme[prefix], tuple):
+                        quant_args = names_to_scheme[prefix][0]
+                    else:
+                        quant_args = names_to_scheme[prefix]
+
+                    compressed_data = self.compress_weight(
+                        weight=value,
+                        scale=scale,
+                        zero_point=zp,
+                        g_idx=g_idx,
+                        quantization_args=quant_args,
+                        device="cpu",
+                    )
+                    for key, value in compressed_data.items():
+                        compressed_dict[merge_names(prefix, key)] = value
+                else:
+                    compressed_dict[name] = value.to("cpu")
+            # only save zp if asym and not packed zp
+            elif is_weight_zp and (
+                quant_args_zp.symmetric or self._check_if_zp_pack_quantized(quant_args)
+            ):
+                continue
+            # only save if asym
+            elif is_input_zp and input_args_zp.symmetric:
+                continue
+            elif name.endswith("g_idx") and torch.any(value <= -1):
+                continue
             else:
-                # omit saving zero points for symmetric or packed quantization
-                if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
-                    continue
-
-                # omit saving for g_idx if uninitialized
-                # TODO: does this case actually occur?
-                elif name.endswith("g_idx") and torch.any(value <= -1):
-                    continue
-
-                compressed_dict[name] = value.to(save_device)
+                compressed_dict[name] = value.to("cpu")
 
         return compressed_dict
 
-    def _skip_zp(
-        self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
-    ) -> bool:
+    def _check_if_zp_pack_quantized(self, quant_args):
         from compressed_tensors.compressors import PackedQuantizationCompressor
 
-        module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
-        scheme = names_to_scheme[module_name]
-
-        if zp_name == "weight_zero_point":
-            args = scheme.weights
-        if zp_name == "input_zero_point":
-            args = scheme.input_activations
-        if zp_name == "output_zero_point":
-            args = scheme.output_activations
-
-        symmetric = args.symmetric
-        packable_strategies = [
-            QuantizationStrategy.GROUP.value,
-            QuantizationStrategy.CHANNEL.value,
-        ]
-        packed = (
-            isinstance(self, PackedQuantizationCompressor)
-            and args.strategy in packable_strategies
-        )
-
-        return symmetric or packed
+        if isinstance(self, PackedQuantizationCompressor):
+            if not quant_args.symmetric and quant_args.strategy in [
+                QuantizationStrategy.GROUP.value,
+                QuantizationStrategy.CHANNEL.value,
+            ]:
+                return True
+        return False
 
     def decompress(
         self,
         path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
-        names_to_scheme: Dict[str, QuantizationScheme],
+        names_to_scheme: Dict[str, QuantizationArgs],
         device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
@@ -171,9 +170,8 @@ def decompress(
             dense state dict
         :param path_to_model_or_tensors: path to compressed safetensors model (directory
             with one or more safetensors files) or compressed tensors file
-        :param names_to_scheme: quantization scheme for each quantized weight
-        :param device: optional device to load intermediate weights into (must be `str`,
-            not `torch.device`)
+        :param names_to_scheme: quantization args for each quantized weight
+        :param device: optional device to load intermediate weights into
         :return: compressed state dict
         """
         if isinstance(path_to_model_or_tensors, (str, Path)):
@@ -186,12 +184,7 @@
                 path_to_model_or_tensors, names_to_scheme
             )
 
-    def _decompress_from_path(
-        self,
-        path_to_model: Union[str, Path, Dict[str, Any]],
-        names_to_scheme: Dict[str, QuantizationScheme],
-        device: str,
-    ):
+    def _decompress_from_path(self, path_to_model, names_to_scheme, device):
         weight_mappings = get_nested_weight_mappings(
             path_to_model, self.compression_param_names
         )
@@ -202,7 +195,7 @@ def _decompress_from_path(
             with safe_open(safe_path, framework="pt", device=device) as f:
                 weight_data[param_name] = f.get_tensor(full_name)
             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[weight_name].weights
+                quant_args = names_to_scheme[weight_name]
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
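For reference, the zero-point handling this revert restores in compress() reduces to: save a zero point only for asymmetric quantization, and skip it when a PackedQuantizationCompressor will pack group- or channel-wise values anyway (the _check_if_zp_pack_quantized path). A standalone restatement of that rule follows; the function name is hypothetical and it assumes QuantizationStrategy.GROUP.value / CHANNEL.value are the strings "group" and "channel":

# Hypothetical restatement of the restored zero-point rule; not part of this commit.
def should_save_zero_point(symmetric: bool, strategy: str,
                           is_packed_compressor: bool) -> bool:
    if symmetric:
        # symmetric quantization needs no stored zero point
        return False
    if is_packed_compressor and strategy in ("group", "channel"):
        # packed compressors pack-quantize the zero point instead of saving it directly
        return False
    return True

Symmetric schemes never need a stored zero point, which is why both the removed _skip_zp and the restored branch drop them before writing the state dict.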
