Skip to content

Commit 945f7c1

Browse files
Fix CI regression (#1666)
* Tests: xfail opcheck for 4bit quantization with floating storage dtypes * Tests: xfail opcheck for 4bit quantization with floating storage dtypes * Tests: skip test_gemv_eye_4bit on CPU with bf16 when not supported by torch * Tests: skip test_gemv_eye_4bit on CPU with bf16 when not supported by torch
1 parent a2a74ed commit 945f7c1

File tree

4 files changed

+12
-6
lines changed

4 files changed

+12
-6
lines changed

bitsandbytes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
if torch.cuda.is_available():
3535
from .backends.cuda import ops as cuda_ops
3636

37-
if torch.xpu.is_available():
37+
if hasattr(torch, "xpu") and torch.xpu.is_available():
3838
from .backends.xpu import ops as xpu_ops
3939

4040

bitsandbytes/backends/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
1.0,
3131
],
3232
dtype=torch.float32,
33-
device="xpu" if torch.xpu.is_available() else "cpu", # Only cpu/xpu use this table for now.
33+
device="xpu"
34+
if hasattr(torch, "xpu") and torch.xpu.is_available()
35+
else "cpu", # Only cpu/xpu use this table for now.
3436
)
3537
_FP4_QUANT_TABLE = torch.tensor(
3638
[
@@ -52,6 +54,8 @@
5254
-0.2500,
5355
],
5456
dtype=torch.float32,
55-
device="xpu" if torch.xpu.is_available() else "cpu", # Only cpu/xpu use this table for now.
57+
device="xpu"
58+
if hasattr(torch, "xpu") and torch.xpu.is_available()
59+
else "cpu", # Only cpu/xpu use this table for now.
5660
)
5761
CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE}

tests/test_functional.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,6 +1330,9 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
13301330
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
13311331
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
13321332
def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
1333+
if device == "cpu" and dtype == torch.bfloat16 and torch.__version__ < (2, 3):
1334+
pytest.skip("eye does not support bfloat16 on CPU in torch < 2.3")
1335+
13331336
dims = 10
13341337
torch.random.manual_seed(np.random.randint(0, 412424242))
13351338
dims = get_test_dims(0, 8192, n=dims)

tests/test_ops.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,8 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize
167167
assert absmax.device == A.device
168168
assert absmax.dtype == torch.float32
169169

170-
# TODO: Enable it
171-
if device in ("cpu", "xpu") and storage_dtype == torch.bfloat16:
172-
pytest.skip("CPU bf16 storage_dtype will fail on torch op check")
170+
if storage_dtype != torch.uint8:
171+
pytest.xfail("opcheck fails for storage_dtype != torch.uint8")
173172

174173
opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype))
175174

0 commit comments

Comments
 (0)