From fa9facf4e7310890f204e0d6842d6150596ed0ff Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 2 Jun 2025 13:02:06 -0400 Subject: [PATCH 1/4] Tests: xfail opcheck for 4bit quantization with floating storage dtypes --- tests/test_ops.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_ops.py b/tests/test_ops.py index 9a0ae3338..7da19c012 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -167,9 +167,8 @@ def test_quantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksize assert absmax.device == A.device assert absmax.dtype == torch.float32 - # TODO: Enable it - if device in ("cpu", "xpu") and storage_dtype == torch.bfloat16: - pytest.skip("CPU bf16 storage_dtype will fail on torch op check") + if storage_dtype != torch.uint8: + pytest.xfail("opcheck fails for storage_dtype != torch.uint8") opcheck(torch.ops.bitsandbytes.quantize_4bit, (A, blocksize, quant_type, storage_dtype)) From 308afcfcb6f90e0ca32e1cca246e6ab8ed97c3d4 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 2 Jun 2025 13:05:14 -0400 Subject: [PATCH 2/4] Tests: xfail opcheck for 4bit quantization with floating storage dtypes --- bitsandbytes/__init__.py | 2 +- bitsandbytes/backends/utils.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 5014e8240..c747398ce 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -34,7 +34,7 @@ if torch.cuda.is_available(): from .backends.cuda import ops as cuda_ops -if torch.xpu.is_available(): +if hasattr(torch, "xpu") and torch.xpu.is_available(): from .backends.xpu import ops as xpu_ops diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py index cc88ffae1..bf277e7ea 100755 --- a/bitsandbytes/backends/utils.py +++ b/bitsandbytes/backends/utils.py @@ -30,7 +30,9 @@ 1.0, ], 
dtype=torch.float32, - device="xpu" if torch.xpu.is_available() else "cpu", # Only cpu/xpu use this table for now. + device="xpu" + if hasattr(torch, "xpu") and torch.xpu.is_available() + else "cpu", # Only cpu/xpu use this table for now. ) _FP4_QUANT_TABLE = torch.tensor( [ @@ -52,6 +54,8 @@ -0.2500, ], dtype=torch.float32, - device="xpu" if torch.xpu.is_available() else "cpu", # Only cpu/xpu use this table for now. + device="xpu" + if hasattr(torch, "xpu") and torch.xpu.is_available() + else "cpu", # Only cpu/xpu use this table for now. ) CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE} From 4ee6c601fbe392456e159e2dea3b69f26ed67725 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 2 Jun 2025 13:38:50 -0400 Subject: [PATCH 3/4] Tests: skip test_gemv_eye_4bit on CPU with bf16 when not supported by torch --- tests/test_functional.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_functional.py b/tests/test_functional.py index fa4a14ae9..34c7d4704 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1330,6 +1330,9 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): + if dtype == torch.bfloat16 and torch.__version__ < (2, 3): + pytest.skip("eye does not support bfloat16 on CPU in torch < 2.3") + dims = 10 torch.random.manual_seed(np.random.randint(0, 412424242)) dims = get_test_dims(0, 8192, n=dims) From ca78ddd083dab859ebeeaeb5338d4873846009e6 Mon Sep 17 00:00:00 2001 From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com> Date: Mon, 2 Jun 2025 14:15:39 -0400 Subject: [PATCH 4/4] Tests: skip test_gemv_eye_4bit on CPU with bf16 when not supported by torch --- 
tests/test_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index 34c7d4704..6a94205e8 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1330,7 +1330,7 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant): - if dtype == torch.bfloat16 and torch.__version__ < (2, 3): + if device == "cpu" and dtype == torch.bfloat16 and torch.__version__ < (2, 3): pytest.skip("eye does not support bfloat16 on CPU in torch < 2.3") dims = 10