
Commit dfef94d

Merge remote-tracking branch 'origin' into kylesayrs/reduce-quantized-compression-memory

2 parents: b5374ae + 16e6435

File tree: 12 files changed, +247 -97 lines

src/compressed_tensors/compressors/base.py

Lines changed: 6 additions & 1 deletion
@@ -19,6 +19,7 @@
 from compressed_tensors.config import SparsityCompressionConfig
 from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
 from compressed_tensors.registry import RegistryMixin
+from compressed_tensors.utils import has_offloaded_params
 from torch import Tensor
 from torch.nn import Module
@@ -169,6 +170,10 @@ def decompress_module(self, module: Module):
         :param module: PyTorch module to decompress
         :return: tensor of the decompressed weight, or None if module is not quantized
         """
+
+        params_device = next(module.parameters()).device
+        device = "cpu" if has_offloaded_params(module) else params_device
+
         if not hasattr(module, "quantization_scheme"):
             return None  # module is not quantized
         quantization_scheme = module.quantization_scheme
@@ -182,7 +187,7 @@ def decompress_module(self, module: Module):

         return self.decompress_weight(
             compressed_data=compressed_data, quantization_args=quantization_args
-        )
+        ).to(device)

     def decompress_weight(
         self, compressed_data: Dict[str, Tensor], **kwargs
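A note on the device handling introduced here: decompress_module now resolves a target device up front, falling back to CPU when the module's parameters are offloaded, since offloaded modules expose their parameters on the meta device and the decompressed weight cannot be materialized there. Below is a minimal, self-contained sketch of the pattern; has_offloaded_params is a simplified stand-in assuming accelerate's _hf_hook convention, not the actual library implementation.

import torch


def has_offloaded_params(module: torch.nn.Module) -> bool:
    # Simplified stand-in: accelerate attaches an `_hf_hook` to modules whose
    # parameters are offloaded; the real helper in compressed_tensors.utils
    # also checks the hook type. This is an assumption for illustration.
    hook = getattr(module, "_hf_hook", None)
    return hook is not None and getattr(hook, "offload", False)


def resolve_decompress_device(module: torch.nn.Module) -> torch.device:
    # Offloaded modules report their parameters on the meta device, so the
    # decompressed weight must land on CPU instead of the reported device.
    if has_offloaded_params(module):
        return torch.device("cpu")
    return next(module.parameters()).device


linear = torch.nn.Linear(8, 8)
print(resolve_decompress_device(linear))  # the module's param device, e.g. cpu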

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 89 additions & 7 deletions
@@ -42,6 +42,7 @@
     SPARSITY_CONFIG_NAME,
 )
 from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.sparse_compressors import DenseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.linear.compressed_linear import CompressedLinear
 from compressed_tensors.quantization import (
@@ -50,7 +51,7 @@
     QuantizationScheme,
     QuantizationStatus,
     apply_quantization_config,
-    load_pretrained_quantization,
+    load_pretrained_quantization_parameters,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.utils import (
@@ -59,8 +60,10 @@
 )
 from compressed_tensors.utils import (
     get_safetensors_folder,
+    has_offloaded_params,
     merge_names,
     module_replace_dfs,
+    register_offload_parameter,
     update_parameter_data,
 )
 from compressed_tensors.utils.helpers import (
@@ -448,6 +451,13 @@ def decompress(self, model_path: str, model: Module):

         :param model_path: path to compressed weights
         :param model: pytorch model to load decompressed weights into
+
+        Note: decompress makes use of both _replace_sparsity_weights and
+        _replace_weights. The split reflects subtle differences between the
+        sparsity and quantization compressors: quantization compressors return
+        not just the decompressed weight but also its quantization parameters
+        (e.g. scales, zero_point), whereas sparsity compressors return only
+        the decompressed weight.
+
         """
         model_path = get_safetensors_folder(model_path)
         sparse_decompressed = False
@@ -456,9 +466,16 @@ def decompress(self, model_path: str, model: Module):
             self.sparsity_compressor is not None
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
+            params_to_ignore = None
+            if self.quantization_compressor is not None:
+                params_to_ignore = self.quantization_compressor.compression_param_names
             # Sparse decompression is applied on the model_path
-            dense_gen = self.sparsity_compressor.decompress(model_path)
-            self._replace_weights(dense_gen, model)
+            # The compressor will also try to load quantization parameters;
+            # params_to_skip_load keeps them from being loaded here
+            dense_gen = self.sparsity_compressor.decompress(
+                model_path, params_to_skip_load=params_to_ignore
+            )
+            self._replace_sparsity_weights(dense_gen, model)
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
             sparse_decompressed = True
@@ -467,13 +484,27 @@ def decompress(self, model_path: str, model: Module):
             # quantization during apply_quantization_config. This ensures
             # that the dtypes of the weights are not unintentionally updated.
             # The status is restored after quantization params are loaded.
+
             with override_quantization_status(
                 self.quantization_config, QuantizationStatus.FROZEN
             ):
+
                 names_to_scheme = apply_quantization_config(
                     model, self.quantization_config
                 )
-                load_pretrained_quantization(model, model_path)
+                # Load activation scales/zp or any other quantization parameters.
+                # Conditionally load the weight quantization parameters if we
+                # have a dense compressor or if a sparsity compressor has
+                # already been applied
+                load_pretrained_quantization_parameters(
+                    model,
+                    model_path,
+                    # TODO: all weight quantization params will be moved to the
+                    # compressor in a follow-up, including initialization
+                    load_weight_quantization=(
+                        sparse_decompressed
+                        or isinstance(self.quantization_compressor, DenseCompressor)
+                    ),
+                )

             model_path_or_state_dict = (
                 model.state_dict() if sparse_decompressed else model_path
@@ -482,6 +513,8 @@ def decompress(self, model_path: str, model: Module):
             dense_gen = self.quantization_compressor.decompress(
                 model_path_or_state_dict, names_to_scheme=names_to_scheme
             )
+            # TODO: all weight quantization params will be moved to the compressor
+            # to prevent duplicate parameter updates in update_parameter_data
             self._replace_weights(dense_gen, model)

             def freeze_quantization_status(module):
@@ -537,7 +570,7 @@ def update_config(self, save_directory: str):
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)

-    def _replace_weights(self, dense_weight_generator, model: Module):
+    def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
         """
         Replace the weights of the model with the
         provided dense weights.
@@ -552,11 +585,60 @@ def _replace_weights(self, dense_weight_generator, model: Module):
         :param model: The model whose weights are to be updated.
         """
         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
+
             split_name = name.split(".")
             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
             module = operator.attrgetter(prefix)(model)
-            if hasattr(module, param_name):
-                update_parameter_data(module, data, param_name)
+
+            params_device = next(module.parameters()).device
+            device = "cpu" if has_offloaded_params(module) else params_device
+            delattr(module, param_name)
+            requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
+            param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
+            register_offload_parameter(module, param_name, param)
+
+    def _replace_weights(self, dense_weight_generator, model: Module):
+        """
+        Replace the weights of the model with the
+        provided dense weights.
+
+        This method iterates over the dense_weight_generator and
+        updates the corresponding weights in the model. If a parameter
+        name does not exist in the model, it will be skipped.
+
+        :param dense_weight_generator (generator): A generator that yields
+            tuples of (name, data), where 'name' is the parameter name and
+            'data' is the updated param data
+        :param model: The model whose weights are to be updated.
+        """
+
+        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
+            module = operator.attrgetter(name)(model)
+
+            params_device = next(module.parameters()).device
+            device = "cpu" if has_offloaded_params(module) else params_device
+
+            for param_name, param_data in data.items():
+                if hasattr(module, param_name):
+                    # If compressed, will have an incorrect dtype for transformers >4.49
+                    # TODO: we could also skip initialization of scales/zp during
+                    # decompression in init, to be consistent with loading, which
+                    # happens later as well. However, update_data does a good
+                    # shape check - should be moved to the compressor
+                    if param_name == "weight":
+                        delattr(module, param_name)
+                        requires_grad = param_data.dtype in (
+                            torch.float16,
+                            torch.float32,
+                            torch.bfloat16,
+                        )
+                        param = torch.nn.Parameter(
+                            param_data.to(device), requires_grad=requires_grad
+                        )
+                        register_offload_parameter(module, param_name, param)
+                    else:
+                        # Should already be registered to the correct device
+                        # for scales/zero-points
+                        update_parameter_data(module, param_data, param_name)


 def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
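As the new docstring note explains, the two replace methods exist because the sparsity and quantization decompressors yield differently shaped items. A toy sketch of the two generator contracts and how each replace path consumes them (illustrative names, not the library API):

import torch


def sparsity_gen():
    # Sparsity decompressors yield (fully-qualified param name, dense tensor).
    yield "layer.weight", torch.randn(4, 4)


def quantization_gen():
    # Quantization decompressors yield (module name, dict of params),
    # including the quantization parameters alongside the weight.
    yield "layer", {
        "weight": torch.randn(4, 4),
        "weight_scale": torch.tensor([0.1]),
        "weight_zero_point": torch.tensor([0]),
    }


model = torch.nn.ModuleDict({"layer": torch.nn.Linear(4, 4, bias=False)})

# _replace_sparsity_weights-style consumption: split off the param name.
for name, data in sparsity_gen():
    prefix, _, param_name = name.rpartition(".")
    module = model[prefix]
    setattr(module, param_name, torch.nn.Parameter(data))

# _replace_weights-style consumption: the name addresses the module itself,
# and each entry in the dict is handled per parameter.
for name, data in quantization_gen():
    module = model[name]
    for param_name, param_data in data.items():
        if param_name == "weight":
            module.weight = torch.nn.Parameter(param_data)
        # scales/zero-points would be updated in place via update_parameter_data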

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 35 additions & 38 deletions
@@ -14,7 +14,7 @@

 import logging
 from pathlib import Path
-from typing import Any, Dict, Generator, Tuple, Union
+from typing import Any, Dict, Generator, Optional, Tuple, Union

 import torch
 from compressed_tensors.compressors.base import BaseCompressor
@@ -121,29 +121,46 @@ def compress(

             else:
                 # omit saving zero points for symmetric quantization
-                if name.endswith("zero_point") and _is_symmetric(name, names_to_scheme):
+                if name.endswith("zero_point") and not self._should_save_zp(
+                    name, names_to_scheme
+                ):
                     continue

                 # omit saving for g_idx if uninitialized
                 # TODO: does this case actually occur?
                 elif name.endswith("g_idx") and torch.any(value <= -1):
                     continue

-                else:
-                    compressed_dict[name] = value.to(save_device)
+                compressed_dict[name] = value.to(save_device)

         return compressed_dict

-    def _check_if_zp_pack_quantized(self, quant_args):
+    def _should_save_zp(
+        self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
+    ) -> bool:
         from compressed_tensors.compressors import PackedQuantizationCompressor

-        if isinstance(self, PackedQuantizationCompressor):
-            if not quant_args.symmetric and quant_args.strategy in [
-                QuantizationStrategy.GROUP.value,
-                QuantizationStrategy.CHANNEL.value,
-            ]:
-                return True
-        return False
+        module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
+        scheme = names_to_scheme[module_name]
+
+        if zp_name == "weight_zero_point":
+            args = scheme.weights
+        if zp_name == "input_zero_point":
+            args = scheme.input_activations
+        if zp_name == "output_zero_point":
+            args = scheme.output_activations
+
+        symmetric = args.symmetric
+        packable_strats = [
+            QuantizationStrategy.GROUP.value,
+            QuantizationStrategy.CHANNEL.value,
+        ]
+        packed = (
+            isinstance(self, PackedQuantizationCompressor)
+            and args.strategy in packable_strats
+        )
+
+        return not symmetric and not packed

     def decompress(
         self,
@@ -191,13 +208,10 @@ def _decompress_from_path(
             decompressed = self.decompress_weight(
                 compressed_data=weight_data, quantization_args=quant_args
             )
-            yield merge_names(weight_name, "weight"), decompressed
+            weight_data["weight"] = decompressed
+            yield weight_name, weight_data

-    def _decompress_from_state_dict(
-        self,
-        state_dict: Dict[str, torch.Tensor],
-        names_to_scheme: Dict[str, QuantizationScheme],
-    ):
+    def _decompress_from_state_dict(self, state_dict, names_to_scheme):
         weight_mappings = get_nested_mappings_from_state_dict(
             state_dict, self.compression_param_names
         )
@@ -207,26 +221,9 @@ def _decompress_from_state_dict(
                 weight_data[param_name] = param_value

             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[weight_name].weights
+                quant_args = names_to_scheme[weight_name]
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
-                yield merge_names(weight_name, "weight"), decompressed
-
-
-def _is_symmetric(name: str, names_to_scheme: Dict[str, QuantizationScheme]) -> bool:
-    try:
-        weight_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
-    except:
-        breakpoint()
-    scheme = names_to_scheme[weight_name]
-
-    if zp_name == "weight_zero_point":
-        quant_args = scheme.weights
-    if zp_name == "input_zero_point":
-        quant_args = scheme.input_activations
-    if zp_name == "output_zero_point":
-        quant_args = scheme.output_activations
-
-    assert quant_args is not None
-    return quant_args.symmetric
+                weight_data["weight"] = decompressed
+                yield weight_name, weight_data
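The refactor of _is_symmetric and _check_if_zp_pack_quantized into _should_save_zp merges two rules: a zero point is persisted only when quantization is asymmetric, and not when a packed compressor already folds it into a group or channel layout. A condensed sketch of that decision, with is_packed_compressor standing in for the isinstance check against PackedQuantizationCompressor:

def should_save_zp(symmetric: bool, strategy: str, is_packed_compressor: bool) -> bool:
    # Group/channel strategies can pack the zero point into the weight layout.
    packable_strategies = {"group", "channel"}
    packed = is_packed_compressor and strategy in packable_strategies
    return not symmetric and not packed


# Symmetric schemes never store a zero point:
assert not should_save_zp(symmetric=True, strategy="group", is_packed_compressor=False)
# Asymmetric + packed group quantization packs the zp instead of saving it:
assert not should_save_zp(symmetric=False, strategy="group", is_packed_compressor=True)
# Asymmetric + unpacked (e.g. tensor strategy) saves the zero point:
assert should_save_zp(symmetric=False, strategy="tensor", is_packed_compressor=False)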

src/compressed_tensors/compressors/sparse_compressors/base.py

Lines changed: 21 additions & 4 deletions
@@ -98,7 +98,11 @@ def compress(
         return compressed_dict

     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
+        self,
+        path_to_model_or_tensors: str,
+        device: str = "cpu",
+        params_to_skip_load: Optional[Tuple] = None,
+        **kwargs,
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a bitmask compressed state dict located
@@ -108,6 +112,11 @@ def decompress(
         :param model_path: path to compressed safetensors model (directory with
             one or more safetensors files) or compressed tensors file
         :param device: device to load decompressed weights onto
+        :param params_to_skip_load: a list of non-sparsity parameters (e.g.
+            quantization parameters) that we want to skip loading. As the
+            sparsity compressor does not handle quantized decompression, this
+            should contain any quantization parameters when decompressing
+            stacked compressors; those parameters are handled by the
+            quantization decompressor instead
         :return: iterator for generating decompressed weights
         """
         weight_mappings, ignored_params = get_nested_weight_mappings(
@@ -121,13 +130,21 @@ def decompress(
                 full_name = merge_names(weight_name, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
+
             decompressed = self.decompress_weight(weight_data)
             yield merge_names(weight_name, "weight"), decompressed

         for ignored_param_name, safe_path in ignored_params.items():
-            with safe_open(safe_path, framework="pt", device=device) as f:
-                value = f.get_tensor(ignored_param_name)
-            yield ignored_param_name, value
+            should_skip = False
+            if params_to_skip_load is not None:
+                for param_to_skip in params_to_skip_load:
+                    if param_to_skip in ignored_param_name:
+                        should_skip = True
+
+            if not should_skip:
+                with safe_open(safe_path, framework="pt", device=device) as f:
+                    value = f.get_tensor(ignored_param_name)
+                yield ignored_param_name, value

     @staticmethod
     def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
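The skip logic added above is a plain substring match over the ignored parameter names. A minimal self-contained sketch under that assumption (parameter names below are illustrative):

from typing import Iterable, List, Optional, Tuple


def filter_ignored_params(
    ignored_params: Iterable[str],
    params_to_skip_load: Optional[Tuple[str, ...]] = None,
) -> List[str]:
    kept = []
    for name in ignored_params:
        # Substring match, mirroring `if param_to_skip in ignored_param_name`
        # in the decompress loop above.
        if params_to_skip_load and any(skip in name for skip in params_to_skip_load):
            continue
        kept.append(name)
    return kept


params = ["model.layers.0.self_attn.q_proj.weight_scale", "model.norm.weight"]
print(filter_ignored_params(params, ("weight_scale", "weight_zero_point")))
# -> ['model.norm.weight']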
