diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 806a98f0..d816f855 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -189,7 +189,7 @@ def _initialize_scale_zero_point( else: # TODO: consider erroring out in the future as if the dtype if not one of these, # there is likely bug - if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]: + if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32, torch.float64]: scale_dtype = torch.float16 zp_dtype = quantization_args.pytorch_dtype()