Skip to content

Commit fda4889

Browse files
committed
wip: correct, but dtype mismatch
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent dbc104d commit fda4889

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

src/compressed_tensors/compressors/quantized_compressors/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def decompress_from_state_dict(
203203
weight_data[param_name] = param_value
204204

205205
if "weight_scale" in weight_data:
206-
quant_args = names_to_scheme[module_name]
206+
quant_args = names_to_scheme[module_name].weights
207207
decompressed = self.decompress_weight(
208208
compressed_data=weight_data, quantization_args=quant_args
209209
)

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -418,21 +418,28 @@ def test_compress_decompress_model(model_stub, comp_stub, q_format, s_format):
418418
compressor = ModelCompressor.from_pretrained_model(model, s_format, q_format)
419419

420420
# decompress model from disk # TODO try also using a model saved from prev step
421+
from transformers.utils.quantization_config import CompressedTensorsConfig
422+
421423
true_decompressed_model = AutoModelForCausalLM.from_pretrained(
422-
model_stub, device_map="meta"
424+
comp_stub,
425+
quantization_config=CompressedTensorsConfig(run_compressed=False),
426+
torch_dtype=torch.half,
423427
)
424-
compressor.decompress(comp_stub, true_decompressed_model)
425428
true_decompressed = dict(true_decompressed_model.state_dict())
429+
true_decompressed = {
430+
name: value
431+
for name, value in true_decompressed.items()
432+
if not name.endswith("zero_point")
433+
} # ignore zero points
426434

427435
# decompress model
428436
compressor.decompress_model(model)
429437
decompressed = dict(model.state_dict())
430438

431439
# equivalent to decompressing from disk
432-
breakpoint()
433440
assert decompressed.keys() == true_decompressed.keys()
434441
for key in decompressed.keys():
435-
assert torch.all(decompressed[key] == true_decompressed[key]), f"{key}"
442+
assert torch.allclose(decompressed[key], true_decompressed[key]), f"{key}"
436443
del true_decompressed
437444

438445
exit(0)

0 commit comments

Comments
 (0)