Commit 3f1cf36

wip: writing sparse decompress_from_state_dict

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent cfb698c commit 3f1cf36

5 files changed: +147 −42 lines

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 19 additions & 8 deletions

@@ -398,7 +398,7 @@ def compress_model(self, model: Module):
                     show_progress=False,
                 )

-                # remove any exist parameters
+                # remove any existing parameters
                 for name, _ in list(module.named_parameters()):
                     delattr(module, name)

@@ -418,34 +418,45 @@ def decompress_model(self, model: Module):
             ignore=self.sparsity_config.ignore if self.sparsity_config else [],
         )

+        # because decompressors only generate new values (rather than new values
+        # plus unused passthrough values), we must explicitly pass the list of
+        # keys which are unused here but consumed by subsequent decompressors
+        params_to_ignore = None
+        if self.quantization_compressor is not None:
+            params_to_ignore = self.quantization_compressor.compression_param_names
+
         for prefix, module in model.named_modules():
             if prefix in module_to_scheme or prefix in sparse_compression_targets:
                 state_dict = module.state_dict(prefix=f"{prefix}.")
-                decompressed = dict()
                 # sparsity first
                 if prefix in sparse_compression_targets:
+                    # sparse_compression_targets are automatically inferred by this fn
                     generator = self.sparsity_compressor.decompress_from_state_dict(
                         state_dict,
-                        names_to_scheme=module_to_scheme,
+                        params_to_ignore=params_to_ignore,
                     )
-                    for _module_name, decompressed_data in generator:
+                    decompressed = dict()
+                    for _, decompressed_data in generator:
                         decompressed.update(decompressed_data)
+                    state_dict = decompressed

                 # quantization second
                 if prefix in module_to_scheme:
                     generator = self.quantization_compressor.decompress_from_state_dict(
                         state_dict,
                         names_to_scheme=module_to_scheme,
                     )
-                    for _module_name, decompressed_data in generator:
+                    decompressed = dict()
+                    for _, decompressed_data in generator:
                         decompressed.update(decompressed_data)
+                    state_dict = decompressed

-                # remove any exist parameters
+                # remove any existing parameters
                 for name, _ in list(module.named_parameters()):
                     delattr(module, name)

                 # replace with decompressed parameters
-                for name, value in decompressed.items():
+                for name, value in state_dict.items():
                     param = torch.nn.Parameter(value, requires_grad=False)
                     register_offload_parameter(module, name, param)

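The loop above establishes the chaining contract for in-memory decompression: each stage's generator yields only the values it produced, its output dict replaces state_dict before the next stage runs, and anything a later stage still needs (hence params_to_ignore) has to be forwarded explicitly. A condensed, self-contained sketch of that pattern, with the stage functions standing in for the two decompress_from_state_dict generators:

def chain_decompressors(state_dict, stages):
    # e.g. stages = (sparsity_stage, quantization_stage)
    for stage in stages:
        decompressed = {}
        for _, decompressed_data in stage(state_dict):
            decompressed.update(decompressed_data)
        state_dict = decompressed  # this stage's output feeds the next stage
    return state_dict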

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 9 additions & 9 deletions

@@ -197,18 +197,18 @@ def decompress_from_state_dict(
         weight_mappings = get_nested_mappings_from_state_dict(
             state_dict, self.compression_param_names
         )
-        for module_name in weight_mappings.keys():
+        for module_path in weight_mappings.keys():
             weight_data = {}
-            for param_name, param_value in weight_mappings[module_name].items():
+            for param_name, param_value in weight_mappings[module_path].items():
                 weight_data[param_name] = param_value

             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[module_name].weights
+                quant_args = names_to_scheme[module_path].weights
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
                 weight_data["weight"] = decompressed
-            yield module_name, weight_data
+            yield module_path, weight_data

     def _decompress_from_path(
         self,
@@ -219,16 +219,16 @@ def _decompress_from_path(
         weight_mappings = get_nested_weight_mappings(
             path_to_model, self.compression_param_names
         )
-        for module_name in weight_mappings.keys():
+        for module_path in weight_mappings.keys():
             weight_data = {}
-            for param_name, safe_path in weight_mappings[module_name].items():
-                full_name = merge_names(module_name, param_name)
+            for param_name, safe_path in weight_mappings[module_path].items():
+                full_name = merge_names(module_path, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
             if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[module_name].weights
+                quant_args = names_to_scheme[module_path].weights
                 decompressed = self.decompress_weight(
                     compressed_data=weight_data, quantization_args=quant_args
                 )
                 weight_data["weight"] = decompressed
-            yield module_name, weight_data
+            yield module_path, weight_data
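Both generators yield (module_path, weight_data) pairs, where weight_data maps bare parameter names to tensors. A minimal consumption sketch that flattens the output back into a fully-qualified state dict, assuming a compressor instance and a module_to_scheme mapping are in scope (the layer path in the comment is only an example):

decompressed_state_dict = {}
generator = compressor.decompress_from_state_dict(
    state_dict, names_to_scheme=module_to_scheme
)
for module_path, weight_data in generator:
    for param_name, value in weight_data.items():
        # e.g. merge_names("model.layers.0.self_attn.q_proj", "weight")
        decompressed_state_dict[merge_names(module_path, param_name)] = value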

src/compressed_tensors/compressors/sparse_compressors/base.py

Lines changed: 33 additions & 7 deletions

@@ -16,7 +16,11 @@
 from typing import Dict, Generator, Optional, Set, Tuple

 from compressed_tensors.compressors.base import BaseCompressor
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from compressed_tensors.utils import (
+    get_nested_mappings_from_state_dict,
+    get_nested_weight_mappings,
+    merge_names,
+)
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -129,15 +129,15 @@ def decompress(
             self.compression_param_names,
             return_unmatched_params=True,
         )
-        for weight_name in weight_mappings.keys():
+        for module_path in weight_mappings.keys():
             weight_data = {}
-            for param_name, safe_path in weight_mappings[weight_name].items():
-                full_name = merge_names(weight_name, param_name)
+            for param_name, safe_path in weight_mappings[module_path].items():
+                full_name = merge_names(module_path, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)

             decompressed = self.decompress_weight(weight_data)
-            yield merge_names(weight_name, "weight"), decompressed
+            yield merge_names(module_path, "weight"), decompressed

         for ignored_param_name, safe_path in ignored_params.items():
             should_skip = False
@@ -151,8 +151,35 @@ def decompress(
                     value = f.get_tensor(ignored_param_name)
                 yield ignored_param_name, value

-    def decompress_from_state_dict(self, state_dict, names_to_scheme):
-        exit(0)
+    def decompress_from_state_dict(
+        self,
+        state_dict: Dict[str, Tensor],
+        params_to_ignore: Optional[Tuple] = None,
+    ) -> Generator[Tuple[str, Dict[str, Tensor]], None, None]:
+        """
+        Mirrors the pattern of `decompress` above, but reads parameter values
+        from an in-memory state dict rather than from safetensors files on disk
+        """
+        weight_mappings, ignored_params = get_nested_mappings_from_state_dict(
+            state_dict, self.compression_param_names, return_unmatched_params=True
+        )
+
+        for module_path in weight_mappings.keys():
+            weight_data = {}
+            for param_name, param_value in weight_mappings[module_path].items():
+                weight_data[param_name] = param_value
+
+            yield module_path, self.decompress_weight(weight_data)
+
+        for ignored_param_name, ignored_param_value in ignored_params.items():
+            should_skip = False
+            if params_to_ignore is not None:
+                for param_to_skip in params_to_ignore:
+                    if param_to_skip in ignored_param_name:
+                        should_skip = True
+
+            if not should_skip:
+                yield ignored_param_name, ignored_param_value

     @staticmethod
     def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
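The unmatched-parameter filter in decompress_from_state_dict is a plain substring match: any key containing one of the params_to_ignore fragments is withheld from the yield stream, and everything else passes through untouched. An equivalent standalone check (the parameter names below are hypothetical):

def is_ignored(param_name, params_to_ignore):
    # mirrors the substring check in the ignored_params loop above
    return any(fragment in param_name for fragment in (params_to_ignore or ()))

assert is_ignored("layers.0.attn.weight_scale", ("weight_scale", "weight_zero_point"))
assert not is_ignored("layers.0.attn.bias", ("weight_scale", "weight_zero_point"))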

src/compressed_tensors/utils/safetensors_load.py

Lines changed: 11 additions & 2 deletions

@@ -35,6 +35,7 @@
     "is_quantization_param",
 ]

+NestedStateDictType = Dict[str, Dict[str, Tensor]]
 WeightMappingType = Dict[str, str]
 NestedWeightMappingType = Dict[str, WeightMappingType]

@@ -249,8 +250,10 @@ def get_nested_weight_mappings(


 def get_nested_mappings_from_state_dict(
-    state_dict, params_to_nest: Iterable[str]
-) -> NestedWeightMappingType:
+    state_dict: Dict[str, Tensor],
+    params_to_nest: Iterable[str],
+    return_unmatched_params: bool = False,
+) -> Union[NestedStateDictType, Tuple[NestedStateDictType, Dict[str, Tensor]]]:
     """
     Takes a state dict and returns a nested mapping from uncompressed
     parameterized layer names to the value of
@@ -269,13 +270,24 @@ def get_nested_mappings_from_state_dict(
     each layer's compression parameters.
+    :param return_unmatched_params: if True, also return a flat mapping of the
+        state dict keys which did not match any entry in params_to_nest
     """
     nested_weight_mappings = {}
+    unmatched_params = {}
+
     for key in state_dict.keys():
+        matched = False
         for param_name in params_to_nest:
             dense_param = match_param_name(key, param_name)
             if dense_param:
                 if dense_param not in nested_weight_mappings:
                     nested_weight_mappings[dense_param] = {}
                 nested_weight_mappings[dense_param][param_name] = state_dict[key]
+                matched = True
+        if return_unmatched_params and not matched:
+            unmatched_params[key] = state_dict[key]
+
+    if return_unmatched_params:
+        return nested_weight_mappings, unmatched_params
     return nested_weight_mappings
tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 75 additions & 16 deletions
@@ -16,6 +16,7 @@
 from copy import deepcopy
 from pathlib import Path

+from compressed_tensors.config.sparse_24_bitmask import Sparse24BitMaskConfig
 import pytest
 import torch
 import torch.nn as nn
@@ -370,31 +371,28 @@ def _get_combined_config(s_config, q_config):


 @pytest.mark.parametrize(
-    "model_stub,comp_stub,q_format,s_format",
+    "model_stub,q_format,s_config",
     [
         (
             "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-uncompressed",
-            "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed",
             "float-quantized",
             None,
         ),
         (
             "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-uncompressed",
-            "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
             None,
             "sparse-24-bitmask",
         ),
         (
             "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
-            "nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed",
             "float-quantized",
-            "sparse-24-bitmask",
+            Sparse24BitMaskConfig(targets=["Linear"]),
         ),
     ],
 )
-def test_compress_decompress_model(model_stub, comp_stub, q_format, s_format):
+def test_compress_decompress_model(model_stub, q_format, s_config, tmpdir):
     model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32)
-    compressor = ModelCompressor.from_pretrained_model(model, s_format, q_format)
+    compressor = ModelCompressor.from_pretrained_model(model, s_config, q_format)

     # compress model by eagerly compressing state dict
     true_compressed = dict(compressor.compress(model))
@@ -415,29 +413,88 @@ def test_compress_decompress_model(model_stub, comp_stub, q_format, s_format):
     # -- decompress -- #

     # reinstantiate compressor to mimic LLM Compressor flows
-    compressor = ModelCompressor.from_pretrained_model(model, s_format, q_format)
+    model.save_pretrained(tmpdir)
+    model = AutoModelForCausalLM.from_pretrained(tmpdir, torch_dtype=torch.float32)
+    compressor = ModelCompressor.from_pretrained_model(model, s_config, q_format)
+
+    true_decompressed_model = AutoModelForCausalLM.from_pretrained(
+        model_stub, torch_dtype=torch.float32
+    )
+    compressor.decompress(tmpdir, true_decompressed_model)
+    true_decompressed = dict(true_decompressed_model.state_dict())
+
+    # decompress model
+    compressor.decompress_model(model)
+    decompressed = dict(model.state_dict())
+
+    # equivalent to decompressing from disk
+    assert decompressed.keys() == true_decompressed.keys()
+    for key in decompressed.keys():
+        assert torch.allclose(
+            decompressed[key], true_decompressed[key], rtol=1e-3, atol=1e-5
+        ), f"{key}"
+    del true_decompressed
+
+
+@pytest.mark.parametrize(
+    "comp_stub,q_format,s_config",
+    [
+        # (
+        #     "nm-testing/llama2.c-stories42M-gsm8k-quantized-only-compressed",
+        #     "float-quantized",
+        #     None,
+        # ),
+        # (
+        #     "nm-testing/llama2.c-stories42M-gsm8k-sparse-only-compressed",
+        #     None,
+        #     "sparse-24-bitmask",
+        # ),
+        (
+            "nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed",
+            "float-quantized",
+            Sparse24BitMaskConfig(targets=["Linear"]),
+        ),
+    ],
+)
+def test_decompress_model(comp_stub, q_format, s_config):
+    # NOTE: transformers adds extra zero points if run_compressed=False or w/ sparsity
+    # https://github.com/huggingface/transformers/blob/main/src/transformers/quantizers/quantizer_compressed_tensors.py#L131-L133
+    # however, decompression does not add zero points in non-asymmetric cases
+    # in order to normalize for this effect in this test, we remove empty weight zps

-    # decompress model from disk # TODO try also using a model saved from prev step
     from transformers.utils.quantization_config import CompressedTensorsConfig

+    # decompress from disk
     true_decompressed_model = AutoModelForCausalLM.from_pretrained(
         comp_stub,
         quantization_config=CompressedTensorsConfig(run_compressed=False),
         torch_dtype=torch.float32,
     )
     true_decompressed = dict(true_decompressed_model.state_dict())
-    true_decompressed = {
-        name: value
-        for name, value in true_decompressed.items()
-        if not name.endswith("zero_point")
-    }  # ignore zero points
+    true_decompressed = remove_empty_weight_zero_points(true_decompressed)  # see above

-    # decompress model
+    # decompress from memory
+    model = AutoModelForCausalLM.from_pretrained(
+        comp_stub,
+        quantization_config=CompressedTensorsConfig(run_compressed=True),
+        torch_dtype=torch.float32,
+    )
+    compressor = ModelCompressor.from_pretrained_model(model, s_config, q_format)
     compressor.decompress_model(model)
     decompressed = dict(model.state_dict())
+    if "sparse" in str(s_config):
+        decompressed = remove_empty_weight_zero_points(decompressed)  # see above

     # equivalent to decompressing from disk
     assert decompressed.keys() == true_decompressed.keys()
     for key in decompressed.keys():
         assert torch.allclose(decompressed[key], true_decompressed[key]), f"{key}"
     del true_decompressed
+
+
+def remove_empty_weight_zero_points(state_dict):
+    return {
+        name: value
+        for name, value in state_dict.items()
+        if not (name.endswith("weight_zero_point") and torch.all(value == 0))
+    }
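For reference, remove_empty_weight_zero_points drops only zero points that are entirely zero (the symmetric case being normalized away above); genuinely asymmetric zero points survive the filter. A quick sketch of the expected behavior:

import torch

state_dict = {
    "layer.weight": torch.randn(4, 4),
    "layer.weight_zero_point": torch.zeros(4),  # all-zero -> dropped
    "other.weight_zero_point": torch.ones(4),   # nonzero -> kept
}
filtered = remove_empty_weight_zero_points(state_dict)
assert set(filtered) == {"layer.weight", "other.weight_zero_point"}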