 import torch.nn as nn
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.config import SparsityCompressionConfig
-from compressed_tensors.config.sparse_24_bitmask import Sparse24BitMaskConfig
-from compressed_tensors.linear.compressed_linear import CompressedLinear
-from compressed_tensors.quantization import QuantizationConfig, QuantizationStatus
+from compressed_tensors.quantization import QuantizationConfig
 from safetensors.torch import save_file
 from tests.testing_utils import induce_sparsity, requires_hf_quantizer
 from transformers import AutoModelForCausalLM
@@ -388,6 +386,11 @@ def _get_combined_config(s_config, q_config):
388386 "float-quantized" ,
389387 "sparse-24-bitmask" ,
390388 ),
389+ (
390+ "nm-testing/llama2.c-stories15M-ultrachat-mixed-uncompressed" ,
391+ "pack-quantized" ,
392+ None ,
393+ ),
391394 ],
392395)
393396def test_compress_model (model_stub , q_format , s_config , tmpdir ):
@@ -405,6 +408,7 @@ def test_compress_model(model_stub, q_format, s_config, tmpdir):
     # equivalent to eagerly compressing state dict
     assert compressed.keys() == true_compressed.keys()
     for key in compressed.keys():
+        assert compressed[key].dtype == true_compressed[key].dtype
         assert torch.all(compressed[key] == true_compressed[key]), f"{key}"
 
 
@@ -423,6 +427,10 @@ def test_compress_model(model_stub, q_format, s_config, tmpdir):
423427 "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed" ,
424428 "nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed" ,
425429 ),
430+ (
431+ "nm-testing/llama2.c-stories15M-ultrachat-mixed-uncompressed" ,
432+ "nm-testing/llama2.c-stories15M-ultrachat-mixed-compressed" ,
433+ ),
426434 ],
427435)
428436def test_decompress_model (model_stub , comp_stub ):
@@ -451,10 +459,17 @@ def test_decompress_model(model_stub, comp_stub):
     compressor.decompress_model(model)
     decompressed = dict(model.state_dict())
 
+    # remove keys not in model definition
+    # NOTE it would be better if compressors only returned keys to keep, rather than
+    # relying on the model structure + missing keys to catch and remove them later
+    model_keys = true_decompressed_model.state_dict().keys()
+    decompressed = {key: val for key, val in decompressed.items() if key in model_keys}
+
     # equivalent to decompressing from disk
     assert decompressed.keys() == true_decompressed.keys()
     for key in decompressed.keys():
-        assert torch.allclose(decompressed[key], true_decompressed[key])
+        assert decompressed[key].dtype == true_decompressed[key].dtype
+        assert torch.all(decompressed[key] == true_decompressed[key]), f"{key}"
 
 
 def remove_empty_weight_zero_points(state_dict):
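
The key-filtering step added in the last hunk can be read on its own as a small helper. Below is a minimal sketch under the diff's assumptions: a model that has been decompressed in place by ModelCompressor.decompress_model, and an uncompressed reference model to define the expected keys. The helper name filter_to_model_keys is illustrative, not part of compressed_tensors.

def filter_to_model_keys(decompressed_state: dict, reference_model) -> dict:
    # Keep only state-dict entries that exist in the uncompressed model
    # definition, dropping any auxiliary compression parameters the
    # decompressor leaves behind in the in-place-decompressed state dict.
    model_keys = reference_model.state_dict().keys()
    return {key: val for key, val in decompressed_state.items() if key in model_keys}

As the NOTE in the diff points out, this is a workaround: having compressors report exactly which keys to keep would remove the need to reconcile against the model structure after the fact.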