
Commit 3524710

add mixed tests
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent ff3323b commit 3524710

File tree

3 files changed: +28, -6 lines


src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 2 additions & 2 deletions
@@ -278,7 +278,6 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
         This function determines which weight keys are missing based on the
         applied compression techniques.
 
-
         :param model: The PyTorch model to check for missing keys.
         :return: A list of missing keys expected in the compressed state_dict.
         """
@@ -459,7 +458,8 @@ def decompress_model(self, model: Module):
             names_to_scheme=module_to_scheme,
         )
         # generates (mod_path, {param_name, param_val})
-        # of compressed params only (ignores unused params)
+        # of compressed params and used params, but not unused params
+        # some used params are removed by get_unexpected_file_keys
         state_dict = {
             merge_names(module_path, param_name): param_value
             for module_path, compressed_data in generator
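The comment change above describes what the decompression generator yields: pairs of (module path, dict of that module's params), which the comprehension flattens into a single state dict keyed by dotted parameter names. The following is a minimal standalone sketch of that flattening pattern, not the library's implementation; the local merge_names here is a stand-in for the real helper in compressed_tensors and is assumed to join names with a dot.

from typing import Dict, Generator, Tuple

import torch


def merge_names(module_path: str, param_name: str) -> str:
    # assumed behavior: ("model.layer0", "weight") -> "model.layer0.weight"
    return f"{module_path}.{param_name}"


def flatten(
    generator: Generator[Tuple[str, Dict[str, torch.Tensor]], None, None],
) -> Dict[str, torch.Tensor]:
    # each yielded entry is (mod_path, {param_name: param_val});
    # flatten it into a single state dict with dotted keys
    return {
        merge_names(module_path, param_name): param_value
        for module_path, params in generator
        for param_name, param_value in params.items()
    }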

src/compressed_tensors/compressors/sparse_compressors/dense.py

Lines changed: 7 additions & 0 deletions
@@ -40,3 +40,10 @@ def decompress(
         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
     ) -> Generator[Tuple[str, Tensor], None, None]:
         return iter([])
+
+    def decompress_from_state_dict(
+        self,
+        state_dict: Dict[str, Tensor],
+    ) -> Generator[Tuple[str, Dict[str, Tensor]], None, None]:
+        for key, value in state_dict.items():
+            yield key, value
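Since the dense compressor represents uncompressed weights, its new decompress_from_state_dict is a pure pass-through: every entry is yielded unchanged. A small self-contained usage sketch (the tensor contents here are illustrative only):

from typing import Dict, Generator, Tuple

import torch


def decompress_from_state_dict(
    state_dict: Dict[str, torch.Tensor],
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    # dense (no-op) behavior: yield each state-dict entry as-is
    for key, value in state_dict.items():
        yield key, value


state_dict = {"layer.weight": torch.ones(2, 2)}
out = dict(decompress_from_state_dict(state_dict))
assert out.keys() == state_dict.keys()
assert torch.equal(out["layer.weight"], state_dict["layer.weight"])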

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 19 additions & 4 deletions
@@ -21,9 +21,7 @@
 import torch.nn as nn
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.config import SparsityCompressionConfig
-from compressed_tensors.config.sparse_24_bitmask import Sparse24BitMaskConfig
-from compressed_tensors.linear.compressed_linear import CompressedLinear
-from compressed_tensors.quantization import QuantizationConfig, QuantizationStatus
+from compressed_tensors.quantization import QuantizationConfig
 from safetensors.torch import save_file
 from tests.testing_utils import induce_sparsity, requires_hf_quantizer
 from transformers import AutoModelForCausalLM
@@ -388,6 +386,11 @@ def _get_combined_config(s_config, q_config):
             "float-quantized",
             "sparse-24-bitmask",
         ),
+        (
+            "nm-testing/llama2.c-stories15M-ultrachat-mixed-uncompressed",
+            "pack-quantized",
+            None,
+        ),
     ],
 )
 def test_compress_model(model_stub, q_format, s_config, tmpdir):
@@ -405,6 +408,7 @@ def test_compress_model(model_stub, q_format, s_config, tmpdir):
     # equivalent to eagerly compressing state dict
     assert compressed.keys() == true_compressed.keys()
     for key in compressed.keys():
+        assert compressed[key].dtype == true_compressed[key].dtype
         assert torch.all(compressed[key] == true_compressed[key]), f"{key}"
 
 
@@ -423,6 +427,10 @@ def test_compress_model(model_stub, q_format, s_config, tmpdir):
         "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
         "nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed",
     ),
+    (
+        "nm-testing/llama2.c-stories15M-ultrachat-mixed-uncompressed",
+        "nm-testing/llama2.c-stories15M-ultrachat-mixed-compressed",
+    ),
 ],
 )
 def test_decompress_model(model_stub, comp_stub):
@@ -451,10 +459,17 @@ def test_decompress_model(model_stub, comp_stub):
     compressor.decompress_model(model)
     decompressed = dict(model.state_dict())
 
+    # remove keys not in model definition
+    # NOTE it would be better if compressors only returned keys to keep, rather than
+    # relying on the model structure + missing keys to catch and remove them later
+    model_keys = true_decompressed_model.state_dict().keys()
+    decompressed = {key: val for key, val in decompressed.items() if key in model_keys}
+
     # equivalent to decompressing from disk
     assert decompressed.keys() == true_decompressed.keys()
     for key in decompressed.keys():
-        assert torch.allclose(decompressed[key], true_decompressed[key])
+        assert decompressed[key].dtype == true_decompressed[key].dtype
+        assert torch.all(decompressed[key] == true_decompressed[key]), f"{key}"
 
 
 def remove_empty_weight_zero_points(state_dict):
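The last hunk tightens the round-trip check: torch.allclose tolerates small numeric drift, while the new assertions demand bit-exact values plus explicitly matching dtypes. A standalone illustration of the pattern (tensor values chosen for the example only):

import torch

decompressed = {"w": torch.tensor([1.0, 2.0], dtype=torch.bfloat16)}
true_decompressed = {"w": torch.tensor([1.0, 2.0], dtype=torch.bfloat16)}

# same key set, then exact dtype and exact element-wise equality per key
assert decompressed.keys() == true_decompressed.keys()
for key in decompressed.keys():
    assert decompressed[key].dtype == true_decompressed[key].dtype
    assert torch.all(decompressed[key] == true_decompressed[key]), f"{key}"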
