diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 1aed09219..892e5f94a 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -671,8 +671,9 @@ def to(self, *args, **kwargs): if device is not None and device.type != "meta" and self.data.device.type == "cpu": if device.type != "cpu" or self.data.dtype != torch.int8: return self._quantize(device) - elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu") and (ipex_cpu or ipex_xpu): - self.CB = self.data + elif self.data.dtype == torch.int8: + if device.type == "cpu" or (device.type == "xpu" and ipex_xpu): + self.CB = self.data new_param = Int8Params( super().to(device=device, dtype=dtype, non_blocking=non_blocking), diff --git a/tests/test_functional.py b/tests/test_functional.py index 2e2e898cc..f66139a71 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -137,11 +137,10 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, abserr = sum(diffs) / len(diffs) relerr = sum(reldiffs) / len(reldiffs) if signed: - threshold_abserr = 0.0036 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0035 assert abserr < 0.0036 assert relerr < 0.015 else: - assert abserr < 0.00175 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0023 + assert abserr < 0.00175 if (device in "cpu") or (device in "xpu" and F.ipex_xpu) else 0.0023 assert relerr < 0.012 assert A2.dtype == dtype @@ -172,8 +171,10 @@ def test_blockwise_cpu_large(self, hidden, blocksize): @pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits")) @pytest.mark.parametrize("method", ["linear", "fp8", "dynamic"]) def test_few_bit_quant(self, device, bits, method): - if bits != 8 and (device == "cpu" or (device == "xpu" and F.ipex_xpu)): - pytest.skip("CPU/XPU implementation only supports 8 bits") + if device in "cpu" and bits != 8: + pytest.skip("CPU implementation only supports 8 bits") + if device in "xpu" and bits != 8 and F.ipex_xpu: + pytest.skip("XPU ipex implementation only supports 8 bits") abserrs = [] relerrs = []