Skip to content

Commit cfb698c

Browse files
committed
fix dtype issue
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent fda4889 commit cfb698c

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

tests/test_compressors/model_compressors/test_model_compressor.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def _get_combined_config(s_config, q_config):
393393
],
394394
)
395395
def test_compress_decompress_model(model_stub, comp_stub, q_format, s_format):
396-
model = AutoModelForCausalLM.from_pretrained(model_stub)
396+
model = AutoModelForCausalLM.from_pretrained(model_stub, torch_dtype=torch.float32)
397397
compressor = ModelCompressor.from_pretrained_model(model, s_format, q_format)
398398

399399
# compress model by eagerly compressing state dict
@@ -423,7 +423,7 @@ def test_compress_decompress_model(model_stub, comp_stub, q_format, s_format):
423423
true_decompressed_model = AutoModelForCausalLM.from_pretrained(
424424
comp_stub,
425425
quantization_config=CompressedTensorsConfig(run_compressed=False),
426-
torch_dtype=torch.half,
426+
torch_dtype=torch.float32,
427427
)
428428
true_decompressed = dict(true_decompressed_model.state_dict())
429429
true_decompressed = {
@@ -441,5 +441,3 @@ def test_compress_decompress_model(model_stub, comp_stub, q_format, s_format):
441441
for key in decompressed.keys():
442442
assert torch.allclose(decompressed[key], true_decompressed[key]), f"{key}"
443443
del true_decompressed
444-
445-
exit(0)

0 commit comments

Comments
 (0)