Commit 97bda13

use map_module_to_scheme, _should_save_zp
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 16e6435 commit 97bda13

8 files changed: +185 -165 lines changed


src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 32 additions & 36 deletions
@@ -19,7 +19,18 @@
 import re
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import compressed_tensors
 import torch
@@ -33,15 +44,16 @@
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors import DenseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
+from compressed_tensors.linear.compressed_linear import CompressedLinear
 from compressed_tensors.quantization import (
     DEFAULT_QUANTIZATION_METHOD,
     QuantizationConfig,
+    QuantizationScheme,
     QuantizationStatus,
     apply_quantization_config,
     load_pretrained_quantization_parameters,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
-from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -50,21 +62,23 @@
     get_safetensors_folder,
     has_offloaded_params,
     merge_names,
+    module_replace_dfs,
     register_offload_parameter,
     update_parameter_data,
 )
 from compressed_tensors.utils.helpers import (
     fix_fsdp_module_name,
     is_compressed_tensors_config,
 )
+from compressed_tensors.utils.offload import update_offload_parameter
 from torch import Tensor
 from torch.nn import Module
 from tqdm import tqdm
 from transformers import AutoConfig
 from transformers.file_utils import CONFIG_NAME
 
 
-__all__ = ["ModelCompressor", "map_modules_to_quant_args"]
+__all__ = ["ModelCompressor", "map_module_to_scheme"]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -372,20 +386,17 @@ def compress(
         :param state_dict: optional uncompressed state_dict to insert into model
         :return: compressed state dict
         """
+
         if state_dict is None:
             state_dict = model.state_dict()
 
-        compressed_state_dict = state_dict
-
-        quantized_modules_to_args: Dict[
-            str, QuantizationArgs
-        ] = map_modules_to_quant_args(model)
-
         if self.quantization_compressor is not None:
-            compressed_state_dict = self.quantization_compressor.compress(
-                state_dict, names_to_scheme=quantized_modules_to_args
+            module_to_scheme = map_module_to_scheme(model)
+            state_dict = self.quantization_compressor.compress(
+                state_dict, names_to_scheme=module_to_scheme
             )
 
+            # TODO: consider sparse compression to also be compression
             if self.quantization_config.format != CompressionFormat.dense.value:
                 self.quantization_config.quantization_status = (
                     QuantizationStatus.COMPRESSED
@@ -397,8 +408,8 @@ def compress(
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
-            compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict,
+            state_dict = self.sparsity_compressor.compress(
+                state_dict,
                 compression_targets=sparse_compression_targets,
             )
 
@@ -407,7 +418,7 @@ def compress(
         # https://github.com/huggingface/transformers/pull/30488
         transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
 
-        return compressed_state_dict
+        return state_dict
 
     def decompress(self, model_path: str, model: Module):
         """
@@ -605,30 +616,15 @@ def _replace_weights(self, dense_weight_generator, model: Module):
             update_parameter_data(module, param_data, param_name)
 
 
-def map_modules_to_quant_args(
-    model: Module,
-) -> Dict[str, Union[QuantizationArgs, Tuple[QuantizationArgs, QuantizationArgs]]]:
+def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
-    Given a pytorch model, map out the submodule name (usually linear layers)
-    to the weight QuantizationArgs. If running input activation quantization, will also
-    map to the input QuantizationArgs in a tuple.
-
-    :param model: pytorch model
+    Returns a dictionary which maps quantized module names to their quantization schemes
     """
-    quantized_modules_to_args = {}
-    for name, submodule in iter_named_leaf_modules(model):
-        if is_module_quantized(submodule):
-            if submodule.quantization_scheme.weights is not None:
-                name = fix_fsdp_module_name(name)
-                quantized_modules_to_args[name] = submodule.quantization_scheme.weights
-            if submodule.quantization_scheme.input_activations is not None:
-                weight_args = quantized_modules_to_args.get(name)
-                quantized_modules_to_args[name] = (
-                    weight_args,
-                    submodule.quantization_scheme.input_activations,
-                )
-
-    return quantized_modules_to_args
+    return {
+        fix_fsdp_module_name(name): module.quantization_scheme
+        for name, module in iter_named_leaf_modules(model)
+        if is_module_quantized(module)
+    }
 
 
 # HACK: Override the dtype_byte_size function in transformers to support float8 types
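For orientation (not part of the diff): a minimal sketch of how the new mapping is meant to be consumed downstream. It assumes `model` is a torch module that has already had a quantization config applied by compressed-tensors, and the layer name shown is a hypothetical example.

# Sketch only: leaf modules are assumed to carry a `quantization_scheme` attribute
# (i.e. apply_quantization_config has already run on `model`).
from compressed_tensors.quantization import QuantizationScheme

module_to_scheme = map_module_to_scheme(model)  # Dict[str, QuantizationScheme]

# "model.layers.0.self_attn.q_proj" is a hypothetical key; real keys depend on the model
scheme: QuantizationScheme = module_to_scheme["model.layers.0.self_attn.q_proj"]
weight_args = scheme.weights              # weight QuantizationArgs, or None
input_args = scheme.input_activations     # input-activation QuantizationArgs, or None

Because each value is now a full QuantizationScheme rather than a QuantizationArgs-or-tuple, the compressor below simply reads `names_to_scheme[module_path].weights` instead of unpacking a tuple.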

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 84 additions & 76 deletions
@@ -18,11 +18,12 @@
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
 from compressed_tensors.utils import (
     get_nested_mappings_from_state_dict,
     get_nested_weight_mappings,
     merge_names,
+    remove_suffix,
 )
 from safetensors import safe_open
 from torch import Tensor
@@ -69,7 +70,7 @@ class BaseQuantizationCompressor(BaseCompressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        names_to_scheme: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationScheme],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -81,96 +82,98 @@ def compress(
         :return: compressed state dict
         """
         compressed_dict = {}
-        weight_suffix = ".weight"
-        input_zp_suffix = ".input_zero_point"
-        weight_zp_suffix = ".weight_zero_point"
-        _LOGGER.debug(
-            f"Compressing model with {len(model_state)} parameterized layers..."
-        )
+        save_device = "cpu"
+
+        uncompressed_names = list(model_state.keys())
+        for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
+            value = model_state[name]
+
+            # compress weights
+            if name.endswith("weight"):
+                prefix = remove_suffix(name, "weight")
+
+                # gather qparams
+                scale = model_state.get(prefix + "weight_scale", None)
+                g_idx = model_state.get(prefix + "weight_g_idx", None)
+                zp = model_state.get(prefix + "weight_zero_point", None)
+
+                # if scale does not exist, then weight cannot be compressed
+                if scale is None:
+                    compressed_dict[name] = value.to(save_device)
+                    continue
+
+                # compress values on cpu (memory movement too expensive)
+                module_path = prefix[:-1] if prefix.endswith(".") else prefix
+                quant_args = names_to_scheme[module_path].weights
+                compressed_values = self.compress_weight(
+                    weight=value,
+                    scale=scale,
+                    zero_point=zp,
+                    g_idx=g_idx,
+                    quantization_args=quant_args,
+                    device="cpu",
+                )
+
+                # update state dict
+                for key, value in compressed_values.items():
+                    compressed_dict[prefix + key] = value.to(save_device)
 
-        for name, value in tqdm(model_state.items(), desc="Quantized Compression"):
-            # check if the parameter we're compressing is the weight zp
-            # or the input zp
-            is_weight_zp = name.endswith(weight_zp_suffix)
-            is_input_zp = name.endswith(input_zp_suffix)
-
-            # if we're saving the weight zp, fetch weight quant args
-            if is_weight_zp:
-                quant_args_zp = names_to_scheme.get(name[: -(len(weight_zp_suffix))])
-                if isinstance(quant_args_zp, tuple):
-                    # If tuple, first value is weight args, second is input args
-                    quant_args_zp = quant_args_zp[0]
-
-            # if we're saving the input zp, fetch input quant args
-            if is_input_zp:
-                input_args_zp = names_to_scheme.get(name[: -(len(input_zp_suffix))])
-                if isinstance(input_args_zp, tuple):
-                    # If tuple, first value is weight args, second is input args
-                    input_args_zp = input_args_zp[-1]
-
-            if name.endswith(weight_suffix):
-                prefix = name[: -(len(weight_suffix))]
-                scale = model_state.get(merge_names(prefix, "weight_scale"), None)
-                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
-                g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
-                if scale is not None:
-                    # weight is quantized, compress it
-                    if isinstance(names_to_scheme[prefix], tuple):
-                        quant_args = names_to_scheme[prefix][0]
-                    else:
-                        quant_args = names_to_scheme[prefix]
-
-                    compressed_data = self.compress_weight(
-                        weight=value,
-                        scale=scale,
-                        zero_point=zp,
-                        g_idx=g_idx,
-                        quantization_args=quant_args,
-                        device="cpu",
-                    )
-                    for key, value in compressed_data.items():
-                        compressed_dict[merge_names(prefix, key)] = value
-                else:
-                    compressed_dict[name] = value.to("cpu")
-            # only save zp if asym and not packed zp
-            elif is_weight_zp and (
-                quant_args_zp.symmetric or self._check_if_zp_pack_quantized(quant_args)
-            ):
-                continue
-            # only save if asym
-            elif is_input_zp and input_args_zp.symmetric:
-                continue
-            elif name.endswith("g_idx") and torch.any(value <= -1):
-                continue
             else:
-                compressed_dict[name] = value.to("cpu")
+                # omit saving zero points for symmetric quantization
+                if name.endswith("zero_point") and not self._should_save_zp(
+                    name, names_to_scheme
+                ):
+                    continue
+
+                # omit saving for g_idx if uninitialized
+                # TODO: does this case actually occur?
+                elif name.endswith("g_idx") and torch.any(value <= -1):
+                    continue
+
+                compressed_dict[name] = value.to(save_device)
 
         return compressed_dict
 
-    def _check_if_zp_pack_quantized(self, quant_args):
+    def _should_save_zp(
+        self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
+    ) -> bool:
         from compressed_tensors.compressors import PackedQuantizationCompressor
 
-        if isinstance(self, PackedQuantizationCompressor):
-            if not quant_args.symmetric and quant_args.strategy in [
-                QuantizationStrategy.GROUP.value,
-                QuantizationStrategy.CHANNEL.value,
-            ]:
-                return True
-        return False
+        module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
+        scheme = names_to_scheme[module_name]
+
+        if zp_name == "weight_zero_point":
+            args = scheme.weights
+        if zp_name == "input_zero_point":
+            args = scheme.input_activations
+        if zp_name == "output_zero_point":
+            args = scheme.output_activations
+
+        symmetric = args.symmetric
+        packable_strats = [
+            QuantizationStrategy.GROUP.value,
+            QuantizationStrategy.CHANNEL.value,
+        ]
+        packed = (
+            isinstance(self, PackedQuantizationCompressor)
+            and args.strategy in packable_strats
+        )
+
+        return not symmetric and not packed
 
     def decompress(
         self,
         path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
-        names_to_scheme: Dict[str, QuantizationArgs],
-        device: str = "cpu",
+        names_to_scheme: Dict[str, QuantizationScheme],
+        device: torch.device = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
         and returns a generator for sequentially decompressing back to a
         dense state dict
         :param path_to_model_or_tensors: path to compressed safetensors model (directory
             with one or more safetensors files) or compressed tensors file
-        :param names_to_scheme: quantization args for each quantized weight
+        :param names_to_scheme: quantization scheme for each quantized weight
         :param device: optional device to load intermediate weights into
         :return: compressed state dict
         """
@@ -184,7 +187,12 @@ def decompress(
             path_to_model_or_tensors, names_to_scheme
         )
 
-    def _decompress_from_path(self, path_to_model, names_to_scheme, device):
+    def _decompress_from_path(
+        self,
+        path_to_model: Union[str, Path, Dict[str, Any]],
+        names_to_scheme: Dict[str, QuantizationScheme],
+        device: torch.device,
+    ):
         weight_mappings = get_nested_weight_mappings(
             path_to_model, self.compression_param_names
         )
@@ -195,7 +203,7 @@ def _decompress_from_path(self, path_to_model, names_to_scheme, device):
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[weight_name]
+                quant_args = names_to_scheme[weight_name].weights
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
0 commit comments

Comments
 (0)