Skip to content

Commit db8d603

Browse files
committed
align with ipex code
1 parent 47bee0d commit db8d603

File tree

5 files changed

+13
-8
lines changed

5 files changed

+13
-8
lines changed

bitsandbytes/backends/triton/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def quantize_4bit(
8383
n = A.numel()
8484

8585
# TODO: Support when weight matrix is not divisible by blocksize
86-
torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")
86+
# torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")
8787

8888
blocks = -(n // -(blocksize * 2))
8989

bitsandbytes/backends/xpu/ops.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from ..._ops import register_kernel
77
from ..utils import ipex_xpu
88

9-
if torch.__version__ >= (2, 7):
9+
# With default torch, error:
10+
# NotImplementedError: The operator 'aten::_int_mm' for XPU
11+
if ipex_xpu and torch.__version__ >= (2, 7):
12+
1013
@register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
1114
def _(A: torch.Tensor, B: torch.Tensor):
1215
return torch._int_mm(
@@ -16,6 +19,7 @@ def _(A: torch.Tensor, B: torch.Tensor):
1619

1720

1821
if ipex_xpu:
22+
1923
@register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu")
2024
def _(
2125
A: torch.Tensor,

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def to(self, *args, **kwargs):
677677
if device is not None and device.type != "meta" and self.data.device.type == "cpu":
678678
if device.type != "cpu" or self.data.dtype != torch.int8:
679679
return self._quantize(device)
680-
elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu"):
680+
elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu") and (ipex_cpu or ipex_xpu):
681681
self.CB = self.data
682682

683683
new_param = Int8Params(

tests/test_functional.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,11 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
137137
abserr = sum(diffs) / len(diffs)
138138
relerr = sum(reldiffs) / len(reldiffs)
139139
if signed:
140-
threshold_abserr = 0.0036 if device in ("cpu", "xpu") else 0.0035
140+
threshold_abserr = 0.0036 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0035
141141
assert abserr < 0.0036
142142
assert relerr < 0.015
143143
else:
144-
assert abserr < 0.00175 if device in ("cpu", "xpu") else 0.0023
144+
assert abserr < 0.00175 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0023
145145
assert relerr < 0.012
146146
assert A2.dtype == dtype
147147

@@ -172,7 +172,7 @@ def test_blockwise_cpu_large(self, hidden, blocksize):
172172
@pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits"))
173173
@pytest.mark.parametrize("method", ["linear", "fp8", "dynamic", "quantile"])
174174
def test_few_bit_quant(self, device, bits, method):
175-
if device in ("cpu", "xpu") and bits != 8:
175+
if device in ("cpu", "xpu") and bits != 8 and (F.ipex_cpu or F.ipex_xpu):
176176
pytest.skip("CPU/XPU implementation only supports 8 bits")
177177

178178
abserrs = []

tests/test_ops.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55

66
import bitsandbytes
7+
from bitsandbytes.functional import ipex_xpu
78
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter
89

910
# torch.library.opcheck is only available in torch 2.4 and later.
@@ -144,7 +145,7 @@ def test_dequantize_blockwise(self, device, dtype, blocksize):
144145
assert out.device == A.device
145146

146147
# TODO: Enable it
147-
if device == "xpu":
148+
if device == "xpu" and ipex_xpu:
148149
pytest.skip("XPU implementation have torch.op inside torch.op, it will fail on op check")
149150

150151
opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype))
@@ -170,7 +171,7 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
170171
if storage_dtype != torch.uint8:
171172
pytest.xfail("opcheck fails for storage_dtype != torch.uint8")
172173

173-
opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype))
174+
opcheck(torch.ops.bitsandbytes.quantize_4bit.default, (A, blocksize, quant_type, storage_dtype))
174175

175176
@pytest.mark.parametrize("device", get_available_devices())
176177
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=id_formatter("dtype"))

0 commit comments

Comments (0)