
Commit ba48863

implement memory decompression
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 3f1cf36 commit ba48863

File tree: 6 files changed, 68 additions and 107 deletions

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 15 additions & 20 deletions
@@ -268,6 +268,8 @@ def __init__(
                 quantization_config.format, config=quantization_config
             )
 
+    # ----- used by hf quantizer ----- #
+
     def get_missing_module_keys(self, model: Module) -> List[str]:
         """
         Identifies the expected missing weight keys in the compressed state_dict.
@@ -369,7 +371,7 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
 
         return list(unexpected_keys)
 
-    # ----- model compression/decompression pathways ----- #
+    # ----- model memory compression/decompression pathways ----- #
 
     def compress_model(self, model: Module):
         module_to_scheme = map_module_to_scheme(model)
@@ -418,13 +420,6 @@ def decompress_model(self, model: Module):
             ignore=self.sparsity_config.ignore if self.sparsity_config else [],
         )
 
-        # because decompressors are implemented to only generate new values (rather than
-        # generating new values and unused values), we must explicitly pass a list of
-        # keys to yield which are unused (but used in subsequent decompressors)
-        params_to_ignore = None
-        if self.quantization_compressor is not None:
-            params_to_ignore = self.quantization_compressor.compression_param_names
-
        for prefix, module in model.named_modules():
            if prefix in module_to_scheme or prefix in sparse_compression_targets:
                state_dict = module.state_dict(prefix=f"{prefix}.")
@@ -433,37 +428,37 @@ def decompress_model(self, model: Module):
                    # sparse_compression_targets are automatically inferred by this fn
                    generator = self.sparsity_compressor.decompress_from_state_dict(
                        state_dict,
-                        params_to_ignore=params_to_ignore,
                    )
-                    decompressed = dict()
-                    for _, decompressed_data in generator:
-                        decompressed.update(decompressed_data)
-                    state_dict = decompressed
+                    # generates (param_path, param_val)
+                    # of compressed and unused params
+                    state_dict = {key: value for key, value in generator}
 
                # quantization second
                if prefix in module_to_scheme:
                    generator = self.quantization_compressor.decompress_from_state_dict(
-                        state_dict,  # asdf
+                        state_dict,
                        names_to_scheme=module_to_scheme,
                    )
-                    decompressed = dict()
-                    for _, decompressed_data in generator:
-                        decompressed.update(decompressed_data)
-                    state_dict = decompressed
+                    # generates (mod_path, {param_name, param_val})
+                    # of compressed params only (ignores unused params)
+                    state_dict = {
+                        merge_names(module_path, param_name): param_value
+                        for module_path, compressed_data in generator
+                        for param_name, param_value in compressed_data.items()
+                    }
 
                # remove any existing parameters
                for name, _ in list(module.named_parameters()):
                    delattr(module, name)
 
                # replace with decompressed parameters
                for name, value in state_dict.items():
+                    name = name.removeprefix(f"{prefix}.")
                    param = torch.nn.Parameter(value, requires_grad=False)
                    register_offload_parameter(module, name, param)
 
                module.quantization_status = QuantizationStatus.FROZEN
 
-    # apparently we only have logic for decompressing from a file...
-
    # ----- state dict compression pathways ----- #
 
    def compress(
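For orientation, here is a minimal sketch (not part of the commit) of how the new in-memory pathway is exercised: compress_model swaps each targeted module's parameters for their compressed counterparts, and decompress_model restores dense, dequantized parameters without a round trip through disk. The model stub and format strings below are illustrative values borrowed from the test suite.

# Illustrative sketch only; stub name and formats are borrowed from the tests.
import torch
from compressed_tensors.compressors import ModelCompressor
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
    torch_dtype=torch.float32,
)
# (sparsity config, quantization format), passed positionally as in the tests
compressor = ModelCompressor.from_pretrained_model(
    model, "sparse-24-bitmask", "float-quantized"
)

compressor.compress_model(model)    # module params replaced with compressed tensors in memory
compressor.decompress_model(model)  # params restored to dense, dequantized tensors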

src/compressed_tensors/compressors/sparse_compressors/base.py

Lines changed: 15 additions & 13 deletions
@@ -16,7 +16,11 @@
 from typing import Dict, Generator, Optional, Set, Tuple
 
 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names, get_nested_mappings_from_state_dict
+from compressed_tensors.utils import (
+    get_nested_mappings_from_state_dict,
+    get_nested_weight_mappings,
+    merge_names,
+)
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -154,10 +158,14 @@ def decompress(
     def decompress_from_state_dict(
         self,
         state_dict: Dict[str, Tensor],
-        params_to_skip_load: Optional[Tuple] = None,
     ) -> Generator[Tuple[str, Dict[str, Tensor]], None, None]:
         """
-        Implemented to copy the pattern of
+        Unlike `self.decompress`, this function does not need to explicitly skip params
+        via params_to_skip_load because it is more convenient for its only caller
+        (ModelCompressor.decompress_model) to retrieve all unused param keys
+
+        :param state_dict: state dict containing parameters to decompress
+        :return: Generator of (param_path, param_val)
         """
         weight_mappings, ignored_params = get_nested_mappings_from_state_dict(
             state_dict, self.compression_param_names, return_unmatched_params=True
@@ -168,17 +176,11 @@ def decompress_from_state_dict(
             for param_name, param_value in weight_mappings[module_path].items():
                 weight_data[param_name] = param_value
 
-            yield module_path, self.decompress_weight(weight_data)
-
-        for ignored_param_name, safe_path in ignored_params.items():
-            should_skip = False
-            if params_to_skip_load is not None:
-                for param_to_skip in params_to_skip_load:
-                    if param_to_skip in ignored_param_name:
-                        should_skip = True
+            decompressed = self.decompress_weight(weight_data)
+            yield merge_names(module_path, "weight"), decompressed
 
-            if not should_skip:
-                yield ignored_param_name, state_dict[ignored_param_name]
+        for ignored_param_path, ignored_param_value in ignored_params.items():
+            yield ignored_param_path, ignored_param_value
 
     @staticmethod
     def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
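A hedged sketch of how a caller such as ModelCompressor.decompress_model consumes this generator, per the new docstring: decompressed weights come back under the module's "weight" path, and unmatched params are passed through unchanged for later decompressors. The names `sparsity_compressor` and `state_dict` are assumed to already exist.

# Sketch: rebuild a flat state dict from the (param_path, param_val) generator.
# `sparsity_compressor` and `state_dict` are assumed to already exist.
decompressed_state_dict = {}
for param_path, param_val in sparsity_compressor.decompress_from_state_dict(state_dict):
    # compressed weights arrive decompressed under "<module>.weight";
    # params this compressor does not own (e.g. quantization scales/zero points)
    # are yielded through unchanged for the quantization decompressor
    decompressed_state_dict[param_path] = param_val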

src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

Lines changed: 2 additions & 6 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Dict, List, Tuple, Union
+from typing import Dict, Generator, List, Tuple, Union
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
@@ -202,11 +202,7 @@ def sparse24_bitmask_decompress(
     decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype)
     decompressed_tensor = decompressed_tensor.to(values.device)
     values = values.flatten()
-    if decompressed_tensor.dtype == FP8_DTYPE:
-        decompressed_tensor[bytemasks_unpacked] = values
-        decompressed_tensor = decompressed_tensor.cuda()
-    else:
-        decompressed_tensor[bytemasks_unpacked] = values
+    decompressed_tensor[bytemasks_unpacked] = values
     return decompressed_tensor
 
 
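With the FP8-specific `.cuda()` branch removed, sparse24_bitmask_decompress uses a single mask-scatter path for every dtype. A toy illustration of that pattern follows; the values and mask shapes are made up for the example.

# Toy example of the mask-scatter used above; shapes and values are illustrative.
import torch

values = torch.tensor([1.0, 2.0, 3.0, 4.0])            # stored nonzero values
mask = torch.tensor([[True, False, True, False],
                     [False, True, False, True]])       # unpacked 2:4 bytemask
dense = torch.zeros(mask.shape, dtype=values.dtype, device=values.device)
dense[mask] = values.flatten()                           # scatter nonzeros back
# dense == [[1., 0., 2., 0.],
#           [0., 3., 0., 4.]]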

src/compressed_tensors/utils/safetensors_load.py

Lines changed: 10 additions & 7 deletions
@@ -250,7 +250,9 @@ def get_nested_weight_mappings(
 
 
 def get_nested_mappings_from_state_dict(
-    state_dict, params_to_nest: Iterable[str], return_unmatched_params: bool = False,
+    state_dict,
+    params_to_nest: Iterable[str],
+    return_unmatched_params: bool = False,
 ) -> Union[NestedStateDictType, Tuple[NestedStateDictType, Dict[str, Tensor]]]:
     """
     Takes a state dict and returns a nested mapping from uncompressed
@@ -271,14 +273,15 @@ def get_nested_mappings_from_state_dict(
     """
     nested_weight_mappings = {}
     unmatched_params = {}
-
+
     for key in state_dict.keys():
+        matched = False
         for param_name in params_to_nest:
-            dense_param = match_param_name(key, param_name)
-            if dense_param:
-                if dense_param not in nested_weight_mappings:
-                    nested_weight_mappings[dense_param] = {}
-                nested_weight_mappings[dense_param][param_name] = state_dict[key]
+            module_path = match_param_name(key, param_name)
+            if module_path:
+                if module_path not in nested_weight_mappings:
+                    nested_weight_mappings[module_path] = {}
+                nested_weight_mappings[module_path][param_name] = state_dict[key]
                 matched = True
         if return_unmatched_params and not matched:
             unmatched_params[key] = state_dict[key]
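For reference, a small sketch of the data shape get_nested_mappings_from_state_dict returns when return_unmatched_params=True. The parameter names "weight_packed" and "scale" below are hypothetical stand-ins; the real names come from the compressor's compression_param_names.

# Hypothetical example; "weight_packed"/"scale" stand in for real
# compression param names, which depend on the compressor in use.
import torch
from compressed_tensors.utils import get_nested_mappings_from_state_dict

state_dict = {
    "model.layers.0.weight_packed": torch.zeros(8, 4),
    "model.layers.0.scale": torch.ones(8, 1),
    "model.layers.0.bias": torch.zeros(8),  # does not match params_to_nest
}
nested, unmatched = get_nested_mappings_from_state_dict(
    state_dict, ["weight_packed", "scale"], return_unmatched_params=True
)
# nested    -> {"model.layers.0": {"weight_packed": <tensor>, "scale": <tensor>}}
# unmatched -> {"model.layers.0.bias": <tensor>}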

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 24 additions & 61 deletions
@@ -16,12 +16,12 @@
 from copy import deepcopy
 from pathlib import Path
 
-from compressed_tensors.config.sparse_24_bitmask import Sparse24BitMaskConfig
 import pytest
 import torch
 import torch.nn as nn
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.config import SparsityCompressionConfig
+from compressed_tensors.config.sparse_24_bitmask import Sparse24BitMaskConfig
 from compressed_tensors.linear.compressed_linear import CompressedLinear
 from compressed_tensors.quantization import QuantizationConfig, QuantizationStatus
 from safetensors.torch import save_file
@@ -386,11 +386,11 @@ def _get_combined_config(s_config, q_config):
         (
             "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
             "float-quantized",
-            Sparse24BitMaskConfig(targets=["Linear"]),
+            "sparse-24-bitmask",
         ),
     ],
 )
-def test_compress_decompress_model(model_stub, q_format, s_config, tmpdir):
+def test_compress_model(model_stub, q_format, s_config, tmpdir):
     model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32)
     compressor = ModelCompressor.from_pretrained_model(model, s_config, q_format)
 
@@ -407,64 +407,32 @@ def test_compress_decompress_model(model_stub, q_format, s_config, tmpdir):
     for key in compressed.keys():
         assert torch.all(compressed[key] == true_compressed[key]), f"{key}"
 
-    del compressed
-    del true_compressed
-
-    # -- decompress -- #
-
-    # reinstantiate compressor to mimic LLM Compressor flows
-    model.save_pretrained(tmpdir)
-    model = AutoModelForCausalLM.from_pretrained(tmpdir, torch_dtype=torch.float32)
-    compressor = ModelCompressor.from_pretrained_model(model, s_config, q_format)
-
-    true_decompressed_model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32)
-    compressor.decompress(tmpdir, true_decompressed_model)
-    true_decompressed = dict(true_decompressed_model.state_dict())
-
-    # decompress model
-    compressor.decompress_model(model)
-    decompressed = dict(model.state_dict())
-
-    # equivalent to decompressing from disk
-    assert decompressed.keys() == true_decompressed.keys()
-    for key in decompressed.keys():
-        mask = ~torch.isclose(decompressed[key], true_decompressed[key], rtol=1e-3, atol=1e-5)
-        print("Mismatched indices:", mask.nonzero(as_tuple=True))
-        print("a values:", decompressed[key][mask])
-        print("b values:", true_decompressed[key][mask])
-        assert torch.allclose(decompressed[key], true_decompressed[key], rtol=1e-3, atol=1e-5), f"{key}"
-    del true_decompressed
-
 
 @pytest.mark.parametrize(
-    "comp_stub,q_format,s_config",
+    "model_stub,comp_stub",
     [
-        # (
-        #     "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed",
-        #     "float-quantized",
-        #     None,
-        # ),
-        # (
-        #     "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
-        #     None,
-        #     "sparse-24-bitmask",
-        # ),
         (
+            "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
+            "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed",
+        ),
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed",
+            "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
+        ),
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
             "nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed",
-            "float-quantized",
-            Sparse24BitMaskConfig(targets=["Linear"]),
         ),
     ],
 )
-def test_decompress_model(comp_stub, q_format, s_config):
+def test_decompress_model(model_stub, comp_stub):
+    from transformers.utils.quantization_config import CompressedTensorsConfig
+
+    # decompress from disk
     # NOTE: transformers adds extra zero points if run_compressed=False or w/ sparsity
     # https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_compressed_tensors.py#L131-L133
     # however, decompression does not add zero points in non-asymmetric cases
     # in order to normalize for this effect in this test, we remove empty weight zps
-
-    from transformers.utils.quantization_config import CompressedTensorsConfig
-
-    # decompress from disk
     true_decompressed_model = AutoModelForCausalLM.from_pretrained(
         comp_stub,
         quantization_config=CompressedTensorsConfig(run_compressed=False),
@@ -474,24 +442,19 @@ def test_decompress_model(comp_stub, q_format, s_config):
     true_decompressed = remove_empty_weight_zero_points(true_decompressed)  # see above
 
     # decompress from memory
-    model = AutoModelForCausalLM.from_pretrained(
-        comp_stub,
-        quantization_config=CompressedTensorsConfig(run_compressed=True),
-        torch_dtype=torch.float32,
-    )
-    compressor = ModelCompressor.from_pretrained_model(model, s_config, q_format)
+    # NOTE there is no other way to load a compressed model into memory, since
+    # there is no way to turn off decompression for sparse models
+    # https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_compressed_tensors.py#L133
+    model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32)
+    compressor = ModelCompressor.from_pretrained(comp_stub)
+    compressor.compress_model(model)
     compressor.decompress_model(model)
     decompressed = dict(model.state_dict())
-    if "sparse" in str(s_config):
-        decompressed = remove_empty_weight_zero_points(decompressed)  # see above
 
     # equivalent to decompressing from disk
-    breakpoint()
     assert decompressed.keys() == true_decompressed.keys()
     for key in decompressed.keys():
-        if not torch.allclose(decompressed[key], true_decompressed[key]):
-            breakpoint()
-    del true_decompressed
+        assert torch.allclose(decompressed[key], true_decompressed[key])
 
 
 def remove_empty_weight_zero_points(state_dict):

tests/test_compressors/sparse_compressors/test_sparse_24_bitmask.py

Lines changed: 2 additions & 0 deletions
@@ -47,6 +47,8 @@ def _validate_shard_shapes(sharded_values, sharded_bitmask, expected_shapes):
 
 def validate_compression(dense_matrix, decompressed_tensor):
     """Validate that the decompressed tensor matches the original dense matrix."""
+    if decompressed_tensor.dtype == FP8_DTYPE:
+        decompressed_tensor = decompressed_tensor.to("cuda")
     dense_matrix = dense_matrix.to(decompressed_tensor.device)
     assert dense_matrix.dtype == decompressed_tensor.dtype, "Dtype mismatch"
     assert dense_matrix.shape == decompressed_tensor.shape, "Shape mismatch"
