
Commit b2cad7e

Implement memory compression and decompression

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 59f02b5 commit b2cad7e

File tree

10 files changed, +317 -45 lines

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 122 additions & 2 deletions
@@ -47,6 +47,7 @@
     iter_named_leaf_modules,
 )
 from compressed_tensors.utils import (
+    get_execution_device,
     get_safetensors_folder,
     has_offloaded_params,
     merge_names,
@@ -98,6 +99,9 @@ class ModelCompressor:
     :param quantization_config: config specifying quantization compression parameters
     """

+    sparsity_config: Optional[SparsityCompressionConfig] = None
+    quantization_config: Optional[QuantizationConfig] = None
+
     @classmethod
     def from_pretrained(
         cls,
@@ -261,6 +265,8 @@ def __init__(
                 quantization_config.format, config=quantization_config
             )

+    # ----- used by hf quantizer ----- #
+
     def get_missing_module_keys(self, model: Module) -> List[str]:
         """
         Identifies the expected missing weight keys in the compressed state_dict.
@@ -362,8 +368,117 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:

         return list(unexpected_keys)

+    # ----- model memory compression/decompression pathways ----- #
+
+    def compress_model(self, model: Module):
+        """
+        Compress a model in memory. Because the model structure is modified in place,
+        this method is more memory-efficient than `self.compress`.
+
+        :param model: model containing parameters to compress
+        """
+        module_to_scheme = map_module_to_scheme(model)
+        sparse_compression_targets: Set[str] = expand_target_names(
+            model=model,
+            targets=self.sparsity_config.targets if self.sparsity_config else [],
+            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+        )
+
+        for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
+            if prefix in module_to_scheme or prefix in sparse_compression_targets:
+                state_dict = module.state_dict(prefix=f"{prefix}.")
+                # quantization first
+                if prefix in module_to_scheme:
+                    state_dict = self.quantization_compressor.compress(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                        show_progress=False,
+                    )
+
+                # sparsity second
+                if prefix in sparse_compression_targets:
+                    state_dict = self.sparsity_compressor.compress(
+                        state_dict,
+                        compression_targets=sparse_compression_targets,
+                        show_progress=False,
+                    )
+
+                # remove any existing parameters
+                device = get_execution_device(module)
+                for name, _ in list(module.named_parameters()):
+                    delattr(module, name)
+
+                # replace with compressed parameters
+                for name, value in state_dict.items():
+                    name = name.removeprefix(f"{prefix}.")
+                    value = value.to(device)
+                    param = torch.nn.Parameter(value, requires_grad=False)
+                    register_offload_parameter(module, name, param)
+
+                module.quantization_status = QuantizationStatus.COMPRESSED
+
+    def decompress_model(self, model: Module):
+        """
+        Decompress a model in memory. Because the model structure is modified in place,
+        this method does not require loading compression parameters from disk.
+
+        :param model: model containing parameters to decompress
+        """
+        module_to_scheme = map_module_to_scheme(model)
+        sparse_compression_targets: Set[str] = expand_target_names(
+            model=model,
+            targets=self.sparsity_config.targets if self.sparsity_config else [],
+            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+        )
+
+        for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
+            if prefix in module_to_scheme or prefix in sparse_compression_targets:
+                state_dict = module.state_dict(prefix=f"{prefix}.")
+                # sparsity first
+                if prefix in sparse_compression_targets:
+                    # sparse_compression_targets are automatically inferred by this fn
+                    generator = self.sparsity_compressor.decompress_from_state_dict(
+                        state_dict,
+                    )
+                    # generates (param_path, param_val)
+                    # of compressed and unused params
+                    state_dict = {key: value for key, value in generator}
+
+                # quantization second
+                if prefix in module_to_scheme:
+                    generator = self.quantization_compressor.decompress_from_state_dict(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                    )
+                    # generates (mod_path, {param_name: param_val})
+                    # of compressed params only (ignores unused params)
+                    state_dict = {
+                        merge_names(module_path, param_name): param_value
+                        for module_path, compressed_data in generator
+                        for param_name, param_value in compressed_data.items()
+                    }
+
+                # remove any existing parameters
+                device = get_execution_device(module)
+                for name, _ in list(module.named_parameters()):
+                    delattr(module, name)
+
+                # replace with decompressed parameters
+                for name, value in state_dict.items():
+                    name = name.removeprefix(f"{prefix}.")
+                    value = value.to(device)
+                    param = torch.nn.Parameter(value, requires_grad=False)
+                    register_offload_parameter(module, name, param)
+
+                module.quantization_status = QuantizationStatus.FROZEN
+
+    # ----- state dict compression pathways ----- #
+
     def compress(
-        self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
+        self,
+        model: Module,
+        state_dict: Optional[Dict[str, Tensor]] = None,
+        show_progress: bool = False,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict or model with sparsity and/or quantization
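For orientation, here is a minimal usage sketch of the two new in-memory pathways. It is not part of the commit: the checkpoint path is hypothetical, the import path is abbreviated, and it assumes the checkpoint carries a compressed-tensors config so that from_pretrained returns a compressor rather than None.

    from transformers import AutoModelForCausalLM
    from compressed_tensors.compressors import ModelCompressor

    model = AutoModelForCausalLM.from_pretrained("path/to/model")   # hypothetical path
    compressor = ModelCompressor.from_pretrained("path/to/model")   # None if no compression config

    # replace each targeted module's parameters with their compressed form, in place
    compressor.compress_model(model)

    # later: restore dense parameters without re-reading compressed tensors from disk
    compressor.decompress_model(model)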
@@ -379,7 +494,9 @@ def compress(
         if self.quantization_compressor is not None:
             module_to_scheme = map_module_to_scheme(model)
             state_dict = self.quantization_compressor.compress(
-                state_dict, names_to_scheme=module_to_scheme
+                state_dict,
+                names_to_scheme=module_to_scheme,
+                show_progress=show_progress,
             )

         # TODO: consider sparse compression to also be compression
@@ -397,6 +514,7 @@ def compress(
             state_dict = self.sparsity_compressor.compress(
                 state_dict,
                 compression_targets=sparse_compression_targets,
+                show_progress=show_progress,
             )

         # HACK: Override the dtype_byte_size function in transformers to
@@ -406,6 +524,8 @@ def compress(

         return state_dict

+    # ----- disk decompression pathways ----- #
+
     def decompress(self, model_path: str, model: Module):
         """
         Overwrites the weights in model with weights decompressed from model_path
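As a usage note, the show_progress flag threaded through compress() above defaults to False, so existing call sites are unchanged; opting in is a one-line change (a sketch reusing the hypothetical compressor and model from the earlier sketch):

    state_dict = compressor.compress(model, show_progress=True)   # tqdm bars for quantization and sparsity passes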

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 14 additions & 7 deletions
@@ -23,7 +23,6 @@
     get_nested_mappings_from_state_dict,
     get_nested_weight_mappings,
     merge_names,
-    remove_suffix,
 )
 from safetensors import safe_open
 from torch import Tensor

@@ -71,6 +70,7 @@ def compress(
         self,
         model_state: Dict[str, Tensor],
         names_to_scheme: Dict[str, QuantizationScheme],
+        show_progress: bool = False,
         **kwargs,
     ) -> Dict[str, Tensor]:
         """

@@ -79,18 +79,21 @@ def compress(
         :param model_state: state dict of uncompressed model
         :param names_to_scheme: quantization args for each quantized weight, needed for
             quantize function to calculate bit depth
+        :param show_progress: whether to show tqdm progress
         :return: compressed state dict
         """
+        uncompressed_names = list(model_state.keys())
         compressed_dict = {}
         save_device = "cpu"

-        uncompressed_names = list(model_state.keys())
-        for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
+        # compress values
+        desc = "Compressing with quantization"
+        for name in tqdm(uncompressed_names, desc=desc, disable=(not show_progress)):
             value = model_state[name]

             # compress weights
             if name.endswith("weight"):
-                prefix = remove_suffix(name, "weight")
+                prefix = name.removesuffix("weight")

                 # gather qparams
                 scale = model_state.get(prefix + "weight_scale", None)
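One detail worth noting: str.removesuffix (which replaces the removed remove_suffix helper and requires Python 3.9+) strips only the literal suffix, so the trailing dot stays in prefix and prefix + "weight_scale" still resolves to the right key. A quick illustration with a made-up parameter name:

    name = "model.layers.0.self_attn.q_proj.weight"    # hypothetical key
    prefix = name.removesuffix("weight")               # "model.layers.0.self_attn.q_proj."
    assert prefix + "weight_scale" == "model.layers.0.self_attn.q_proj.weight_scale"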
@@ -182,7 +185,7 @@ def decompress(
             )

         else:
-            yield from self._decompress_from_state_dict(
+            yield from self.decompress_from_state_dict(
                 path_to_model_or_tensors, names_to_scheme
             )

@@ -209,7 +212,11 @@ def _decompress_from_path(
                 weight_data["weight"] = decompressed
                 yield module_path, weight_data

-    def _decompress_from_state_dict(self, state_dict, names_to_scheme):
+    def decompress_from_state_dict(
+        self,
+        state_dict: Dict[str, torch.Tensor],
+        names_to_scheme: Dict[str, QuantizationScheme],
+    ) -> Generator[Tuple[str, Dict[str, torch.Tensor]], None, None]:
         weight_mappings = get_nested_mappings_from_state_dict(
             state_dict, self.compression_param_names
         )

@@ -219,7 +226,7 @@ def _decompress_from_state_dict(self, state_dict, names_to_scheme):
                 weight_data[param_name] = param_value

             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[module_path]
+                quant_args = names_to_scheme[module_path].weights
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
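The now-public generator yields one (module_path, {param_name: tensor}) pair per matched module, which is exactly the shape ModelCompressor.decompress_model flattens back into a state dict above. A small consumption sketch (names are hypothetical; quantization_compressor is any quantized compressor instance):

    for module_path, weight_data in quantization_compressor.decompress_from_state_dict(
        state_dict, names_to_scheme=module_to_scheme
    ):
        dense_weight = weight_data["weight"]   # decompressed weight for this module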

src/compressed_tensors/compressors/sparse_compressors/base.py

Lines changed: 44 additions & 6 deletions
@@ -16,7 +16,11 @@
 from typing import Dict, Generator, Optional, Set, Tuple

 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from compressed_tensors.utils import (
+    get_nested_mappings_from_state_dict,
+    get_nested_weight_mappings,
+    merge_names,
+)
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm

@@ -63,6 +67,7 @@ def compress(
         self,
         model_state: Dict[str, Tensor],
         compression_targets: Optional[Set[str]] = None,
+        show_progress: bool = False,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict using bitmask compression

@@ -76,7 +81,11 @@ def compress(
         _LOGGER.debug(
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
-        for name, value in tqdm(model_state.items(), desc="Compressing model"):
+        for name, value in tqdm(
+            model_state.items(),
+            desc="Compressing with sparsity",
+            disable=(not show_progress),
+        ):
             if not self.should_compress(name, compression_targets):
                 compressed_dict[name] = value
                 continue

@@ -124,15 +133,15 @@ def decompress(
             self.compression_param_names,
             return_unmatched_params=True,
         )
-        for weight_name in weight_mappings.keys():
+        for module_path in weight_mappings.keys():
             weight_data = {}
-            for param_name, safe_path in weight_mappings[weight_name].items():
-                full_name = merge_names(weight_name, param_name)
+            for param_name, safe_path in weight_mappings[module_path].items():
+                full_name = merge_names(module_path, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)

             decompressed = self.decompress_weight(weight_data)
-            yield merge_names(weight_name, "weight"), decompressed
+            yield merge_names(module_path, "weight"), decompressed

         for ignored_param_name, safe_path in ignored_params.items():
             should_skip = False

@@ -146,6 +155,35 @@ def decompress(
                     value = f.get_tensor(ignored_param_name)
                     yield ignored_param_name, value

+    def decompress_from_state_dict(
+        self,
+        state_dict: Dict[str, Tensor],
+    ) -> Generator[Tuple[str, Tensor], None, None]:
+        """
+        Decompress the state dict of a module (or model)
+
+        Unlike `self.decompress`, this function does not need to explicitly skip params
+        via params_to_skip_load because it is more convenient for its only caller
+        (ModelCompressor.decompress_model) to retrieve all unused param keys
+
+        :param state_dict: state dict containing parameters to decompress
+        :return: Generator of (param_path, param_val)
+        """
+        weight_mappings, ignored_params = get_nested_mappings_from_state_dict(
+            state_dict, self.compression_param_names, return_unmatched_params=True
+        )
+
+        for module_path in weight_mappings.keys():
+            weight_data = {}
+            for param_name, param_value in weight_mappings[module_path].items():
+                weight_data[param_name] = param_value
+
+            decompressed = self.decompress_weight(weight_data)
+            yield merge_names(module_path, "weight"), decompressed
+
+        for ignored_param_path, ignored_param_value in ignored_params.items():
+            yield ignored_param_path, ignored_param_value
+
     @staticmethod
     def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
         """

src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

Lines changed: 2 additions & 6 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import Dict, List, Tuple, Union
+from typing import Dict, Generator, List, Tuple, Union

 import torch
 from compressed_tensors.compressors.base import BaseCompressor

@@ -202,11 +202,7 @@ def sparse24_bitmask_decompress(
     decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype)
     decompressed_tensor = decompressed_tensor.to(values.device)
     values = values.flatten()
-    if decompressed_tensor.dtype == FP8_DTYPE:
-        decompressed_tensor[bytemasks_unpacked] = values
-        decompressed_tensor = decompressed_tensor.cuda()
-    else:
-        decompressed_tensor[bytemasks_unpacked] = values
+    decompressed_tensor[bytemasks_unpacked] = values
     return decompressed_tensor
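With the FP8 special case (and its stray .cuda() call) removed, decompression is a single mask-scatter for every dtype. A self-contained sketch of the same idea with toy shapes:

    import torch

    values = torch.tensor([1.0, 2.0, 3.0, 4.0])            # kept (non-zero) values
    mask = torch.tensor([[True, False, True, False],
                         [False, True, False, True]])      # unpacked bitmask

    dense = torch.zeros(mask.shape, dtype=values.dtype, device=values.device)
    dense[mask] = values    # scatter values back into their original positions
    # dense == [[1., 0., 2., 0.],
    #           [0., 3., 0., 4.]]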

src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py

Lines changed: 5 additions & 1 deletion
@@ -125,6 +125,7 @@ def compress(
         self,
         model_state: Dict[str, Tensor],
         names_to_scheme: Dict[str, QuantizationScheme],
+        show_progress: bool = False,
         **kwargs,
     ) -> Dict[str, Tensor]:
         """

@@ -134,6 +135,7 @@ def compress(
         :param model_state: state dict of uncompressed model
         :param names_to_scheme: quantization scheme for each quantized weight, needed
             for quantize function to calculate bit depth
+        :param show_progress: whether to show tqdm progress
         :return: compressed state dict
         """
         self.validate_quant_compatability(names_to_scheme)

@@ -144,7 +146,9 @@ def compress(
             f"Compressing model with {len(model_state)} parameterized layers..."
         )

-        for name, value in tqdm(model_state.items(), desc="Compressing model"):
+        for name, value in tqdm(
+            model_state.items(), desc="Compressing model", disable=(not show_progress)
+        ):
             if name.endswith(weight_suffix):
                 prefix = name[: -(len(weight_suffix))]
                 scale = model_state.get(merge_names(prefix, "weight_scale"), None)
