Commit f192f68

[Performance] Add memory compression and decompression pathways (#301)
* Implement memory compression and decompression
* perform ops on cpu, move back to module device
* add mixed tests

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent c1645e3 commit f192f68

File tree: 11 files changed, +337 -35 lines changed

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 131 additions & 3 deletions
@@ -47,6 +47,9 @@
     iter_named_leaf_modules,
 )
 from compressed_tensors.utils import (
+    align_module_device,
+    delete_offload_parameter,
+    get_execution_device,
     get_safetensors_folder,
     has_offloaded_params,
     merge_names,
@@ -98,6 +101,9 @@ class ModelCompressor:
     :param quantization_config: config specifying quantization compression parameters
     """

+    sparsity_config: Optional[SparsityCompressionConfig] = None
+    quantization_config: Optional[QuantizationConfig] = None
+
     @classmethod
     def from_pretrained(
         cls,
@@ -261,6 +267,8 @@ def __init__(
                 quantization_config.format, config=quantization_config
             )

+    # ----- used by hf quantizer ----- #
+
     def get_missing_module_keys(self, model: Module) -> List[str]:
         """
         Identifies the expected missing weight keys in the compressed state_dict.
@@ -270,7 +278,6 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
         This function determines which weight keys are missing based on the
         applied compression techniques.

-
         :param model: The PyTorch model to check for missing keys.
         :return: A list of missing keys expected in the compressed state_dict.
         """
@@ -362,8 +369,124 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:

         return list(unexpected_keys)

+    # ----- model memory compression/decompression pathways ----- #
+
+    def compress_model(self, model: Module):
+        """
+        Compress a model in memory. Because the model structure is modified in place,
+        this method is more memory-efficient than `self.compress`
+
+        :param model: model containing parameters to compress
+        """
+        module_to_scheme = map_module_to_scheme(model)
+        sparse_compression_targets: Set[str] = expand_target_names(
+            model=model,
+            targets=self.sparsity_config.targets if self.sparsity_config else [],
+            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+        )
+
+        for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
+            if prefix in module_to_scheme or prefix in sparse_compression_targets:
+                # in the future, support compression on the same device
+                with align_module_device(module, execution_device="cpu"):
+                    state_dict = module.state_dict(prefix=f"{prefix}.")
+
+                # quantization first
+                if prefix in module_to_scheme:
+                    state_dict = self.quantization_compressor.compress(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                        show_progress=False,
+                    )
+
+                # sparsity second
+                if prefix in sparse_compression_targets:
+                    state_dict = self.sparsity_compressor.compress(
+                        state_dict,
+                        compression_targets=sparse_compression_targets,
+                        show_progress=False,
+                    )
+
+                # remove any existing parameters
+                device = get_execution_device(module)
+                for name, _ in list(module.named_parameters()):
+                    delattr(module, name)
+
+                # replace with compressed parameters
+                for name, value in state_dict.items():
+                    name = name.removeprefix(f"{prefix}.")
+                    value = value.to(device)
+                    param = torch.nn.Parameter(value, requires_grad=False)
+                    register_offload_parameter(module, name, param)
+
+                module.quantization_status = QuantizationStatus.COMPRESSED
+
+    def decompress_model(self, model: Module):
+        """
+        Decompress a model in memory. Because the model structure is modified in place,
+        this method does not require loading compression parameters from disk
+
+        :param model: model containing parameters to decompress
+        """
+        module_to_scheme = map_module_to_scheme(model)
+        sparse_compression_targets: Set[str] = expand_target_names(
+            model=model,
+            targets=self.sparsity_config.targets if self.sparsity_config else [],
+            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+        )
+
+        for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
+            if prefix in module_to_scheme or prefix in sparse_compression_targets:
+                # in the future, support decompression on the same device
+                with align_module_device(module, execution_device="cpu"):
+                    state_dict = module.state_dict(prefix=f"{prefix}.")
+
+                # sparsity first
+                if prefix in sparse_compression_targets:
+                    # sparse_compression_targets are automatically inferred by this fn
+                    generator = self.sparsity_compressor.decompress_from_state_dict(
+                        state_dict,
+                    )
+                    # generates (param_path, param_val)
+                    # of compressed and unused params
+                    state_dict = {key: value for key, value in generator}
+
+                # quantization second
+                if prefix in module_to_scheme:
+                    generator = self.quantization_compressor.decompress_from_state_dict(
+                        state_dict,
+                        names_to_scheme=module_to_scheme,
+                    )
+                    # generates (mod_path, {param_name: param_val})
+                    # of compressed params and used params, but not unused params
+                    # some used params are removed by get_unexpected_file_keys
+                    state_dict = {
+                        merge_names(module_path, param_name): param_value
+                        for module_path, compressed_data in generator
+                        for param_name, param_value in compressed_data.items()
+                    }
+
+                # remove any existing parameters
+                device = get_execution_device(module)
+                for name, _ in list(module.named_parameters()):
+                    delete_offload_parameter(module, name)
+
+                # replace with decompressed parameters
+                for name, value in state_dict.items():
+                    name = name.removeprefix(f"{prefix}.")
+                    value = value.to(device)
+                    param = torch.nn.Parameter(value, requires_grad=False)
+                    register_offload_parameter(module, name, param)
+
+                module.quantization_status = QuantizationStatus.FROZEN
+
+    # ----- state dict compression pathways ----- #
+
     def compress(
-        self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
+        self,
+        model: Module,
+        state_dict: Optional[Dict[str, Tensor]] = None,
+        show_progress: bool = False,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict or model with sparsity and/or quantization
@@ -379,7 +502,9 @@ def compress(
         if self.quantization_compressor is not None:
             module_to_scheme = map_module_to_scheme(model)
             state_dict = self.quantization_compressor.compress(
-                state_dict, names_to_scheme=module_to_scheme
+                state_dict,
+                names_to_scheme=module_to_scheme,
+                show_progress=show_progress,
             )

         # TODO: consider sparse compression to also be compression
@@ -397,6 +522,7 @@ def compress(
             state_dict = self.sparsity_compressor.compress(
                 state_dict,
                 compression_targets=sparse_compression_targets,
+                show_progress=show_progress,
             )

         # HACK: Override the dtype_byte_size function in transformers to
@@ -406,6 +532,8 @@ def compress(

         return state_dict

+    # ----- disk decompression pathways ----- #
+
     def decompress(self, model_path: str, model: Module):
         """
         Overwrites the weights in model with weights decompressed from model_path
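Taken together, `compress_model` and `decompress_model` give a round trip that swaps parameters in place instead of materializing a second full state dict. A minimal usage sketch, assuming a checkpoint whose config already carries quantization and/or sparsity settings; the model id and variable names below are illustrative, not from this PR:

from transformers import AutoModelForCausalLM
from compressed_tensors.compressors import ModelCompressor

# Hypothetical checkpoint id; any model with a compression config would do.
model = AutoModelForCausalLM.from_pretrained("org/model")
compressor = ModelCompressor.from_pretrained("org/model")

# Replace each targeted module's dense parameters with compressed ones, in place.
compressor.compress_model(model)

# Later: restore dense parameters without re-reading tensors from disk.
compressor.decompress_model(model)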

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 14 additions & 7 deletions
@@ -23,7 +23,6 @@
     get_nested_mappings_from_state_dict,
     get_nested_weight_mappings,
     merge_names,
-    remove_suffix,
 )
 from safetensors import safe_open
 from torch import Tensor
@@ -71,6 +70,7 @@ def compress(
         self,
         model_state: Dict[str, Tensor],
         names_to_scheme: Dict[str, QuantizationScheme],
+        show_progress: bool = False,
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -79,18 +79,21 @@ def compress(
         :param model_state: state dict of uncompressed model
         :param names_to_scheme: quantization args for each quantized weight, needed for
             quantize function to calculate bit depth
+        :param show_progress: whether to show tqdm progress
         :return: compressed state dict
         """
+        uncompressed_names = list(model_state.keys())
         compressed_dict = {}
        save_device = "cpu"

-        uncompressed_names = list(model_state.keys())
-        for name in tqdm(uncompressed_names, desc="Compressing with quantization"):
+        # compress values
+        desc = "Compressing with quantization"
+        for name in tqdm(uncompressed_names, desc=desc, disable=(not show_progress)):
             value = model_state[name]

             # compress weights
             if name.endswith("weight"):
-                prefix = remove_suffix(name, "weight")
+                prefix = name.removesuffix("weight")

                 # gather qparams
                 scale = model_state.get(prefix + "weight_scale", None)
@@ -182,7 +185,7 @@ def decompress(
             )

         else:
-            yield from self._decompress_from_state_dict(
+            yield from self.decompress_from_state_dict(
                 path_to_model_or_tensors, names_to_scheme
             )

@@ -209,7 +212,11 @@ def _decompress_from_path(
             weight_data["weight"] = decompressed
             yield module_path, weight_data

-    def _decompress_from_state_dict(self, state_dict, names_to_scheme):
+    def decompress_from_state_dict(
+        self,
+        state_dict: Dict[str, torch.Tensor],
+        names_to_scheme: Dict[str, QuantizationScheme],
+    ) -> Generator[Tuple[str, Dict[str, torch.Tensor]], None, None]:
         weight_mappings = get_nested_mappings_from_state_dict(
             state_dict, self.compression_param_names
         )
@@ -219,7 +226,7 @@ def _decompress_from_state_dict(self, state_dict, names_to_scheme):
                 weight_data[param_name] = param_value

             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[module_path]
+                quant_args = names_to_scheme[module_path].weights
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
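With `_decompress_from_state_dict` promoted to the public `decompress_from_state_dict`, callers get a generator of `(module_path, {param_name: tensor})` pairs. A sketch of flattening that output back into a flat state dict, mirroring how `ModelCompressor.decompress_model` consumes it; `compressor`, `state_dict`, and `module_to_scheme` are assumed to be in scope:

from compressed_tensors.utils import merge_names

# Rebuild flat "module.param" keys from the nested generator output.
flat_state_dict = {
    merge_names(module_path, param_name): param_value
    for module_path, weight_data in compressor.decompress_from_state_dict(
        state_dict, names_to_scheme=module_to_scheme
    )
    for param_name, param_value in weight_data.items()
}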

src/compressed_tensors/compressors/sparse_compressors/base.py

Lines changed: 44 additions & 6 deletions
@@ -16,7 +16,11 @@
 from typing import Dict, Generator, Optional, Set, Tuple

 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from compressed_tensors.utils import (
+    get_nested_mappings_from_state_dict,
+    get_nested_weight_mappings,
+    merge_names,
+)
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -63,6 +67,7 @@ def compress(
         self,
         model_state: Dict[str, Tensor],
         compression_targets: Optional[Set[str]] = None,
+        show_progress: bool = False,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict using bitmask compression
@@ -76,7 +81,11 @@ def compress(
         _LOGGER.debug(
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
-        for name, value in tqdm(model_state.items(), desc="Compressing model"):
+        for name, value in tqdm(
+            model_state.items(),
+            desc="Compressing with sparsity",
+            disable=(not show_progress),
+        ):
             if not self.should_compress(name, compression_targets):
                 compressed_dict[name] = value
                 continue
@@ -124,15 +133,15 @@ def decompress(
             self.compression_param_names,
             return_unmatched_params=True,
         )
-        for weight_name in weight_mappings.keys():
+        for module_path in weight_mappings.keys():
             weight_data = {}
-            for param_name, safe_path in weight_mappings[weight_name].items():
-                full_name = merge_names(weight_name, param_name)
+            for param_name, safe_path in weight_mappings[module_path].items():
+                full_name = merge_names(module_path, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)

             decompressed = self.decompress_weight(weight_data)
-            yield merge_names(weight_name, "weight"), decompressed
+            yield merge_names(module_path, "weight"), decompressed

         for ignored_param_name, safe_path in ignored_params.items():
             should_skip = False
@@ -146,6 +155,35 @@ def decompress(
                 value = f.get_tensor(ignored_param_name)
                 yield ignored_param_name, value

+    def decompress_from_state_dict(
+        self,
+        state_dict: Dict[str, Tensor],
+    ) -> Generator[Tuple[str, Tensor], None, None]:
+        """
+        Decompress the state dict of a module (or model)
+
+        Unlike `self.decompress`, this function does not need to explicitly skip params
+        via params_to_skip_load because it is more convenient for its only caller
+        (ModelCompressor.decompress_model) to retrieve all unused param keys
+
+        :param state_dict: state dict containing parameters to decompress
+        :return: generator of (param_path, param_val)
+        """
+        weight_mappings, ignored_params = get_nested_mappings_from_state_dict(
+            state_dict, self.compression_param_names, return_unmatched_params=True
+        )
+
+        for module_path in weight_mappings.keys():
+            weight_data = {}
+            for param_name, param_value in weight_mappings[module_path].items():
+                weight_data[param_name] = param_value
+
+            decompressed = self.decompress_weight(weight_data)
+            yield merge_names(module_path, "weight"), decompressed
+
+        for ignored_param_path, ignored_param_value in ignored_params.items():
+            yield ignored_param_path, ignored_param_value
+
     @staticmethod
     def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
         """

src/compressed_tensors/compressors/sparse_compressors/dense.py

Lines changed: 7 additions & 0 deletions
@@ -40,3 +40,10 @@ def decompress(
         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
     ) -> Generator[Tuple[str, Tensor], None, None]:
         return iter([])
+
+    def decompress_from_state_dict(
+        self,
+        state_dict: Dict[str, Tensor],
+    ) -> Generator[Tuple[str, Tensor], None, None]:
+        for key, value in state_dict.items():
+            yield key, value
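For the dense (no-compression) format the new method is an identity pass-through, which keeps the `ModelCompressor.decompress_model` call site uniform across compressors. A tiny sketch of the expected behavior, assuming a `DenseCompressor` instance named `compressor`:

import torch

sd = {"layer.weight": torch.zeros(2, 2)}
restored = dict(compressor.decompress_from_state_dict(sd))
assert restored.keys() == sd.keys()  # keys and tensors pass through unchanged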

src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

Lines changed: 2 additions & 6 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import Dict, List, Tuple, Union
+from typing import Dict, Generator, List, Tuple, Union

 import torch
 from compressed_tensors.compressors.base import BaseCompressor
@@ -202,11 +202,7 @@ def sparse24_bitmask_decompress(
     decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype)
     decompressed_tensor = decompressed_tensor.to(values.device)
     values = values.flatten()
-    if decompressed_tensor.dtype == FP8_DTYPE:
-        decompressed_tensor[bytemasks_unpacked] = values
-        decompressed_tensor = decompressed_tensor.cuda()
-    else:
-        decompressed_tensor[bytemasks_unpacked] = values
+    decompressed_tensor[bytemasks_unpacked] = values
     return decompressed_tensor
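Both deleted branches performed the same masked scatter; the FP8 branch additionally forced the result onto CUDA, which the in-memory pathways above cannot assume. The collapsed version keeps one code path and leaves the result on the device of `values`. A standalone sketch of that scatter, using toy tensors rather than anything from this PR:

import torch

values = torch.tensor([1.0, 2.0, 3.0])              # flattened non-zero values
mask = torch.tensor([[True, False], [True, True]])  # unpacked bytemask
dense = torch.zeros(mask.shape, dtype=values.dtype, device=values.device)
dense[mask] = values  # fill masked positions in row-major order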