 import torch.nn as nn
 from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.config import SparsityCompressionConfig
-from compressed_tensors.config.sparse_24_bitmask import Sparse24BitMaskConfig
-from compressed_tensors.linear.compressed_linear import CompressedLinear
-from compressed_tensors.quantization import QuantizationConfig, QuantizationStatus
+from compressed_tensors.quantization import QuantizationConfig
 from safetensors.torch import save_file
 from tests.testing_utils import induce_sparsity, requires_hf_quantizer
 from transformers import AutoModelForCausalLM
@@ -388,6 +386,11 @@ def _get_combined_config(s_config, q_config):
             "float-quantized",
             "sparse-24-bitmask",
         ),
+        (
+            "nm-testing/llama2.c-stories15M-ultrachat-mixed-uncompressed",
+            "pack-quantized",
+            None,
+        ),
     ],
 )
 def test_compress_model(model_stub, q_format, s_config, tmpdir):
@@ -405,6 +408,7 @@ def test_compress_model(model_stub, q_format, s_config, tmpdir):
     # equivalent to eagerly compressing state dict
     assert compressed.keys() == true_compressed.keys()
     for key in compressed.keys():
+        assert compressed[key].dtype == true_compressed[key].dtype
         assert torch.all(compressed[key] == true_compressed[key]), f"{key}"


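The new dtype assertion presumably guards against a subtlety of the equality check on the next line: elementwise == in PyTorch applies type promotion, so two tensors with identical values but different dtypes still compare equal, and torch.all() alone would let a dtype regression in the compressed output slip through. A minimal sketch (not from this PR) illustrating the gap:

import torch

# Same values, different dtypes: == promotes int8 to int32 before comparing,
# so the value check passes even though the dtypes disagree.
a = torch.tensor([1, 2, 3], dtype=torch.int8)
b = torch.tensor([1, 2, 3], dtype=torch.int32)

print(torch.all(a == b))   # tensor(True), despite the dtype mismatch
print(a.dtype == b.dtype)  # False; only the explicit dtype check catches this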
@@ -423,6 +427,10 @@ def test_compress_model(model_stub, q_format, s_config, tmpdir):
             "nm-testing/llama2.c-stories42M-gsm8k-stacked-uncompressed",
             "nm-testing/llama2.c-stories42M-gsm8k-stacked-compressed",
         ),
+        (
+            "nm-testing/llama2.c-stories15M-ultrachat-mixed-uncompressed",
+            "nm-testing/llama2.c-stories15M-ultrachat-mixed-compressed",
+        ),
     ],
 )
 def test_decompress_model(model_stub, comp_stub):
@@ -451,10 +459,17 @@ def test_decompress_model(model_stub, comp_stub):
     compressor.decompress_model(model)
     decompressed = dict(model.state_dict())

+    # remove keys not in model definition
+    # NOTE it would be better if compressors only returned keys to keep, rather than
+    # relying on the model structure + missing keys to catch and remove them later
+    model_keys = true_decompressed_model.state_dict().keys()
+    decompressed = {key: val for key, val in decompressed.items() if key in model_keys}
+
     # equivalent to decompressing from disk
     assert decompressed.keys() == true_decompressed.keys()
     for key in decompressed.keys():
-        assert torch.allclose(decompressed[key], true_decompressed[key])
+        assert decompressed[key].dtype == true_decompressed[key].dtype
+        assert torch.all(decompressed[key] == true_decompressed[key]), f"{key}"


 def remove_empty_weight_zero_points(state_dict):
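Two things happen in this last hunk. First, per the NOTE above, decompress_model can leave auxiliary entries in the state dict that the plain model definition does not declare, so the test prunes everything outside the model's own keys. Second, torch.allclose is tightened to an exact value-and-dtype comparison, presumably because decompression is expected to round-trip bit-exactly rather than approximately. A minimal, self-contained sketch of the pruning pattern, using a hypothetical weight_shape entry to stand in for leftover compressor metadata:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
decompressed = dict(model.state_dict())
# hypothetical leftover artifact a compressor might leave behind
decompressed["weight_shape"] = torch.tensor([4, 4])

# keep only keys the model definition declares, mirroring the test above
model_keys = model.state_dict().keys()
decompressed = {key: val for key, val in decompressed.items() if key in model_keys}

assert set(decompressed) == {"weight", "bias"}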