Skip to content

Enable CPU/XPU native and ipex path #1628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 38 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
ba79025
enable ipex
jiqing-feng May 7, 2025
958d75b
fix cpu 8bit quantization
jiqing-feng May 7, 2025
f5c0b01
fix int8 and nf4 cpu inference
jiqing-feng May 7, 2025
7f2d8a8
add cpu fp4 and rem
jiqing-feng May 8, 2025
97d5bd1
fix dequantize nf4 xpu
jiqing-feng May 8, 2025
5563c35
Merge branch 'main' into ipex
jiqing-feng May 8, 2025
7b72673
fix ipex op
jiqing-feng May 9, 2025
52e32af
fix dequantize nf4 name
jiqing-feng May 9, 2025
fda3d70
fix dequantize nf4 ipex
jiqing-feng May 9, 2025
5ce3296
Merge branch 'main' into ipex
jiqing-feng May 9, 2025
f51678e
fix matmul8bitfp
jiqing-feng May 9, 2025
7c9281c
enable cpu tests
jiqing-feng May 9, 2025
83cea6b
fix format
jiqing-feng May 9, 2025
bc8723e
fix quantize blockwise output shape
jiqing-feng May 9, 2025
3c07023
fix quant_storage bf16 and gemv cpu
jiqing-feng May 9, 2025
9fbed05
fix cpu tests
jiqing-feng May 12, 2025
59e682d
Merge branch 'main' into ipex
jiqing-feng May 12, 2025
c17e2ff
fix xpu tests
jiqing-feng May 12, 2025
974c60a
fix lib
jiqing-feng May 12, 2025
a21c290
skip xpu dequantize blockwise op check
jiqing-feng May 12, 2025
a5d4a27
fix matmul8bit
jiqing-feng May 12, 2025
959a0d4
skip not used function teests
jiqing-feng May 12, 2025
f44d4a2
fix matmul8bit fp
jiqing-feng May 12, 2025
b9f3c40
check ipex before MatMul8bitFp
jiqing-feng May 12, 2025
21cf8c1
update ipex install guide
jiqing-feng May 12, 2025
a9e5c4a
update install guide
jiqing-feng May 13, 2025
1a77949
Merge branch 'main' into ipex
jiqing-feng May 14, 2025
b0cd993
Merge branch 'main' into ipex
jiqing-feng May 15, 2025
539f5d4
fix error log
jiqing-feng May 15, 2025
005afe0
fix error lof
jiqing-feng May 15, 2025
4471ada
Merge branch 'main' into ipex
jiqing-feng May 16, 2025
cddeec6
update comment
jiqing-feng May 16, 2025
8492010
Merge branch 'bitsandbytes-foundation:main' into ipex
jiqing-feng May 21, 2025
25d01a4
move torch op to default
jiqing-feng May 21, 2025
8ff8947
revert ipex check
jiqing-feng May 21, 2025
82651f9
fix code tabledevice
jiqing-feng May 21, 2025
413bba9
fix code table device
jiqing-feng May 21, 2025
cf8bc14
fix xpu ops
jiqing-feng May 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing_extensions import deprecated

import bitsandbytes.functional as F
from bitsandbytes.functional import ipex_cpu, ipex_xpu

# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
Expand Down Expand Up @@ -298,6 +299,64 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
return grad_A, grad_B, None, grad_bias, None


class MatMul8bitFp(torch.autograd.Function):
# For Intel CPU and XPU, the double quant has many unsafe operations which will breaks the finetune.
# Moreover, the MatMul8bitLt is much slower than MatMul8bitFp in finetune.
# The MatMul8bitLt has more mechanisms in computing grad.
# We don't have fast kernel for quant/dequant 8bit in CPU/XPU, so it's very slow.
# We'd like to use dequant + matmul to run finetune currently.

@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
if state.has_fp16_weights or state.CB is None:
has_grad = getattr(B, "grad", None) is not None
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed:
B = B.contiguous()

if (state.is_training and not has_grad) or state.CB is None or state.SCB is None:
state.reset_grads()
state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
B = state.CB

CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
output = torch.nn.functional.linear(A, CB, bias)
# to pass the test: tests/test_modules.py::test_linear8bitlt_no_fp16_weights[2.0-xpu]
state.idx = False
ctx.state = state
ctx.dtype_A = A.dtype
ctx.grad_shape = A.shape
ctx.A = A
ctx.dtype_bias = None if bias is None else bias.dtype
return output

@staticmethod
def backward(ctx, grad_output):
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
A = ctx.A
state = ctx.state
grad_A = grad_B = grad_bias = None
if req_gradBias:
# compute grad_bias first before changing grad_output dtype
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)

# Cast grad_output to fp16
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

if req_gradB:
grad_B = torch.matmul(A.t(), grad_output).t()

if req_gradA:
if state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output.to(ctx.dtype_A), CB).view(ctx.grad_shape)
else:
raise Exception("State must contain CB matrix for backward")

return grad_A, grad_B, None, grad_bias, None


class MatMul4Bit(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
Expand Down Expand Up @@ -366,6 +425,10 @@ def matmul(
state = state or MatmulLtState()
if threshold > 0.0:
state.threshold = threshold
# MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU
if state.is_training:
if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu" and ipex_xpu):
return MatMul8bitFp.apply(A, B, out, bias, state)
return MatMul8bitLt.apply(A, B, out, bias, state)


Expand All @@ -378,6 +441,17 @@ def matmul_4bit(
):
assert quant_state is not None

if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
if getattr(quant_state, "ipex", False):
# IPEX CPU will change weight to 4D so don't need transpose
B = B.t() if B.dim() == 2 else B
out = F.gemv_4bit(A, B, out, state=quant_state)
if bias is not None:
out += bias
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)

if A.numel() == A.shape[-1] and A.requires_grad == False:
if A.shape[-1] % quant_state.blocksize != 0:
warn(
Expand Down
199 changes: 133 additions & 66 deletions bitsandbytes/backends/cpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,42 @@ def _(A: torch.Tensor, B: torch.Tensor):
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}")

n = A.numel()
blocks = -(n // -blocksize)

absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)

lib.cquantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(n),
)

# Only FP32 has c++ kernrl
if A.dtype == torch.float32:
blocks = -(n // -blocksize)

absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty_like(A, dtype=torch.uint8)

lib.cquantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(n),
)
else:
rem = n % blocksize
has_rem = rem > 0
blocks = n // blocksize + has_rem
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
A_reshaped = A.reshape(n)
A_com = A_reshaped[: n - rem]
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
scaled_A = scaled_A.reshape(-1)
if has_rem:
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)

diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)

return out, absmax

Expand All @@ -50,18 +70,28 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on cpu, got {dtype}")

out = torch.empty_like(A, dtype=dtype)

lib.cdequantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(A.numel()),
)
# Only FP32 has c++ kernrl
if dtype == torch.float32:
out = torch.empty_like(A, dtype=dtype)

lib.cdequantize_blockwise_cpu_fp32(
get_ptr(code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_longlong(blocksize),
ct.c_longlong(A.numel()),
)
else:
out = code[A.reshape(-1).int()]
blocks = out.shape[-1] // blocksize
res = out.shape[-1] % blocksize
if res != 0:
out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
out = out[: blocks * blocksize + res]
out = out.reshape(A.shape)

return out

Expand All @@ -88,31 +118,63 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
dtype=torch.float32,
device="cpu",
)
_FP4_QUANT_TABLE = torch.tensor(
[
0.0000,
0.0052,
0.6667,
1.0000,
0.3333,
0.5000,
0.1667,
0.2500,
0.0000,
-0.0052,
-0.6667,
-1.0000,
-0.3333,
-0.5000,
-0.1667,
-0.2500,
],
dtype=torch.float32,
device="cpu",
)
CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE}


@register_kernel("bitsandbytes::quantize_4bit", "cpu")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4 on CPU, got {quant_type}")
torch._check(
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
)

n = A.numel()

# TODO: Support when weight matrix is not divisible by blocksize
torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")

# Divide into blocks and normalize
blocks = A.reshape(-1, blocksize)
absmax = blocks.abs().max(dim=1).values.float()
scaled = blocks / absmax.unsqueeze(-1)
full_blocks = n // blocksize
rem = n % blocksize
blocks = full_blocks + 1 if rem else full_blocks
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
A_flattened = A.reshape(n)

# Scale full blocks of the tensor to [-1, 1]
A_full_blocks = A_flattened[: n - rem].reshape(n // blocksize, blocksize)
absmax[:full_blocks] = torch.abs(A_full_blocks).max(dim=-1)[0]
scaled = torch.clamp(A_full_blocks * (1 / absmax[:full_blocks].view(-1, 1)), -1, 1).reshape(-1)

# Scale any partial block
if rem:
A_rem = A_flattened[-rem:]
absmax[-1] = torch.abs(A_rem).max()
scaled_rem = torch.clamp(A_rem * (1 / absmax[-1]), -1, 1)
scaled = torch.cat([scaled, scaled_rem], dim=0)

# Quantize with the lookup table
quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8)
quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - CODE[quant_type]), dim=-1, keepdim=True).to(torch.uint8)

# Pack two quantized values per byte
packed = quantized[::2] << 4 | quantized[1::2]
Expand All @@ -133,32 +195,45 @@ def _(
dtype: torch.dtype,
) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4 on CPU, got {quant_type}")
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)
torch._check(
A.dtype == torch.uint8,
lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
)

A = A.view(-1, 1)

# Grab upper and lower nibbles. Using int64 for indexing in the LUT.
upper = (A >> 4).to(torch.int64)
lower = (A & 0x0F).to(torch.int64)

# Expand to blocks
blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)
# Enable non uint8 dtype
if A.dtype != torch.uint8:
A = A.view(torch.uint8)

A = A.reshape(-1)
# Map nf4 to [-1, 1]
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
n = out_dq.numel()
out_dq[1::2] = A & 0xF
out_dq[::2] = A >> 4
# code is fp32, cast to dtype to avoid the mismatch issue
code = CODE[quant_type].to(dtype)
out_dq = code[out_dq]

# Apply scales
if out_dq.numel() != n:
assert out_dq.numel() == n + 1
out_dq = torch.narrow(out_dq, 0, 0, n)
blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0
rem = n % blocksize
has_rem = rem > 0

out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1)
if has_rem:
out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1)
out[n - rem :] = out_dq[n - rem :] * absmax[-1]
else:
out = out_dq.view(-1, blocksize) * absmax.view(-1, 1)

out = out.reshape(-1, *shape[1:]).to(dtype)

# Dequantize
blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None]

# Reshape to original shape
blocks = blocks.reshape(-1, *shape[1:])

return blocks.to(dtype)
return out


@register_kernel("bitsandbytes::gemv_4bit", "cpu")
Expand All @@ -170,17 +245,9 @@ def _(
code: torch.Tensor,
blocksize: int,
) -> torch.Tensor:
# TODO: We need to determine whether `code` is NF4, FP4, or other.
# Right now we assume NF4, as this is the only one supported on CPU.

B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(
B,
absmax,
blocksize,
"nf4",
shape=shapeB,
dtype=A.dtype,
)
# Applied from dequantize_4bit
quant_type = "fp4" if code[1] > 0 else "nf4"
B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(B, absmax, blocksize, quant_type, shapeB, A.dtype)

# User called gemv with B.t(), so we need to transpose it back.
# if B.shape[0] == 1:
Expand Down
19 changes: 18 additions & 1 deletion bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,28 @@ def get_native_library() -> BNBNativeLibrary:
return BNBNativeLibrary(dll)


try:
# to support Intel CPU/GPU (XPU) backend
import intel_extension_for_pytorch as ipex

ipex_cpu = ipex if ipex._C._has_cpu() else None
ipex_xpu = ipex if ipex._C._has_xpu() else None
except BaseException:
ipex_cpu = None
ipex_xpu = None


try:
lib = get_native_library()
if not ipex_cpu:
logger.warning(
"The installed version of bitsandbytes was compiled without IPEX support. "
"You can install ipex by running `pip install intel_extension_for_pytorch`to get better performance if you use the Intel CPU.",
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like extra noise that we'd want to avoid.

Something to point out is that we still plan to ship libbitsandbytes_cpu in our wheels, so for most users, it's going to load a CPU, CUDA, or eventually ROCm or Metal library and we'll hit this logging line. At most we should really only raise this warning when:

  1. We're on a platform with IPEX CPU support. My understanding is this is limited to Linux x86-64.
  2. We expect the user to be using CPU, i.e. no CUDA, XPU, or MPS accelerators available.
    On torch >= 2.6 we could just use torch.accelerator.is_available() and on older versions I think we can overlook privateuse1 backends like HPU or Ascend NPU.
  3. There's some expectation of IPEX being beneficial. We don't want to prompt users to install it if e.g. it needs AVX512 or AMX support to be effective. This is something I can't speak to directly but defer to Intel folks to determine.

Any other thoughts @Titus-von-Koeller ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You were right. I also agree that the log should only exist if no devices like cuda/xpu are available and the CPU is an Intel product.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have changed it, please review again. Thanks!

except Exception as e:
error_msg = str(e)
logger.error(f"bitsandbytes library load error: {error_msg}\n", exc_info=True)
if not ipex_xpu:
logger.error(f"bitsandbytes library load error: {error_msg}\n", exc_info=True)

# create a mock with error messaging as fallback
lib = ErrorHandlerMockBNBNativeLibrary(error_msg)
Loading