Skip to content

Enable CPU/XPU native and ipex path #1628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 38 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
ba79025
enable ipex
jiqing-feng May 7, 2025
958d75b
fix cpu 8bit quantization
jiqing-feng May 7, 2025
f5c0b01
fix int8 and nf4 cpu inference
jiqing-feng May 7, 2025
7f2d8a8
add cpu fp4 and rem
jiqing-feng May 8, 2025
97d5bd1
fix dequantize nf4 xpu
jiqing-feng May 8, 2025
5563c35
Merge branch 'main' into ipex
jiqing-feng May 8, 2025
7b72673
fix ipex op
jiqing-feng May 9, 2025
52e32af
fix dequantize nf4 name
jiqing-feng May 9, 2025
fda3d70
fix dequantize nf4 ipex
jiqing-feng May 9, 2025
5ce3296
Merge branch 'main' into ipex
jiqing-feng May 9, 2025
f51678e
fix matmul8bitfp
jiqing-feng May 9, 2025
7c9281c
enable cpu tests
jiqing-feng May 9, 2025
83cea6b
fix format
jiqing-feng May 9, 2025
bc8723e
fix quantize blockwise output shape
jiqing-feng May 9, 2025
3c07023
fix quant_storage bf16 and gemv cpu
jiqing-feng May 9, 2025
9fbed05
fix cpu tests
jiqing-feng May 12, 2025
59e682d
Merge branch 'main' into ipex
jiqing-feng May 12, 2025
c17e2ff
fix xpu tests
jiqing-feng May 12, 2025
974c60a
fix lib
jiqing-feng May 12, 2025
a21c290
skip xpu dequantize blockwise op check
jiqing-feng May 12, 2025
a5d4a27
fix matmul8bit
jiqing-feng May 12, 2025
959a0d4
skip not used function teests
jiqing-feng May 12, 2025
f44d4a2
fix matmul8bit fp
jiqing-feng May 12, 2025
b9f3c40
check ipex before MatMul8bitFp
jiqing-feng May 12, 2025
21cf8c1
update ipex install guide
jiqing-feng May 12, 2025
a9e5c4a
update install guide
jiqing-feng May 13, 2025
1a77949
Merge branch 'main' into ipex
jiqing-feng May 14, 2025
b0cd993
Merge branch 'main' into ipex
jiqing-feng May 15, 2025
539f5d4
fix error log
jiqing-feng May 15, 2025
005afe0
fix error lof
jiqing-feng May 15, 2025
4471ada
Merge branch 'main' into ipex
jiqing-feng May 16, 2025
cddeec6
update comment
jiqing-feng May 16, 2025
8492010
Merge branch 'bitsandbytes-foundation:main' into ipex
jiqing-feng May 21, 2025
25d01a4
move torch op to default
jiqing-feng May 21, 2025
8ff8947
revert ipex check
jiqing-feng May 21, 2025
82651f9
fix code tabledevice
jiqing-feng May 21, 2025
413bba9
fix code table device
jiqing-feng May 21, 2025
cf8bc14
fix xpu ops
jiqing-feng May 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,29 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
return grad_A, grad_B, None, grad_bias, None


class MatMul8bitFp(torch.autograd.Function):
# For Intel CPU and XPU, the double quant has many unsafe operations which will breaks the finetune.
# We'd like to use dequant + matmul to run finetune currently.

@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
CB = B.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)).t()
output = torch.matmul(A, CB).to(A.dtype)
ctx.state = state
ctx.dtype_A = A.dtype
ctx.grad_shape = A.shape
return output

@staticmethod
def backward(ctx, grad_output):
state = ctx.state
B = state.CxB if state.CxB is not None else state.CB
CB = B.to(ctx.dtype_A).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)

return grad_A, None, None, None, None


class MatMul4Bit(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
Expand Down Expand Up @@ -366,6 +389,8 @@ def matmul(
state = state or MatmulLtState()
if threshold > 0.0:
state.threshold = threshold
if A.device.type in ("cpu", "xpu") and state.is_training:
return MatMul8bitFp.apply(A, B, out, bias, state)
return MatMul8bitLt.apply(A, B, out, bias, state)


Expand All @@ -378,6 +403,16 @@ def matmul_4bit(
):
assert quant_state is not None

if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
if getattr(quant_state, "ipex", False):
B = B.t() if B.dim() == 2 else B
out = F.gemv_4bit(A, B, out, state=quant_state)
if bias is not None:
out += bias
return out
else:
return MatMul4Bit.apply(A, B, out, bias, quant_state)

if A.numel() == A.shape[-1] and A.requires_grad == False:
if A.shape[-1] % quant_state.blocksize != 0:
warn(
Expand Down
126 changes: 84 additions & 42 deletions bitsandbytes/backends/cpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,31 +88,62 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
dtype=torch.float32,
device="cpu",
)
_FP4_QUANT_TABLE = torch.tensor(
[
0.0000,
0.0052,
0.6667,
1.0000,
0.3333,
0.5000,
0.1667,
0.2500,
0.0000,
-0.0052,
-0.6667,
-1.0000,
-0.3333,
-0.5000,
-0.1667,
-0.2500,
],
dtype=torch.float32,
device="cpu",
)
CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE}


@register_kernel("bitsandbytes::quantize_4bit", "cpu")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4 on CPU, got {quant_type}")
torch._check(
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
)

n = A.numel()

# TODO: Support when weight matrix is not divisible by blocksize
torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")

# Divide into blocks and normalize
blocks = A.reshape(-1, blocksize)
absmax = blocks.abs().max(dim=1).values.float()
scaled = blocks / absmax.unsqueeze(-1)

blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0
rem = n % blocksize
has_rem = rem > 0

# Scale tensor to [-1, 1]
absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype)
A_reshaped = A.reshape(n)
A_com_reshaped = A_reshaped[: n - rem].reshape(n // blocksize, blocksize)
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
scaled = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
scaled = scaled.reshape(-1)
if has_rem:
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
scaled_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
scaled = torch.cat([scaled, scaled_rem], dim=0)
# Quantize with the lookup table
quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8)
quant_table = CODE[quant_type]
quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - quant_table), dim=-1, keepdim=True).to(torch.uint8)

# Pack two quantized values per byte
packed = quantized[::2] << 4 | quantized[1::2]
Expand All @@ -133,32 +164,47 @@ def _(
dtype: torch.dtype,
) -> torch.Tensor:
torch._check_is_size(blocksize)
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4 on CPU, got {quant_type}")
torch._check(
dtype in [torch.bfloat16, torch.float16, torch.float32],
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
)
torch._check(
A.dtype == torch.uint8,
lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
)

A = A.view(-1, 1)

# Grab upper and lower nibbles. Using int64 for indexing in the LUT.
upper = (A >> 4).to(torch.int64)
lower = (A & 0x0F).to(torch.int64)

# Expand to blocks
blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)

# Dequantize
blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None]
# Enable non uint8 dtype
device = A.device
if A.dtype != torch.uint8:
bytes_value = A.cpu().numpy().tobytes()
A = torch.frombuffer(bytes_value, dtype=torch.uint8).to(device)

A = A.reshape(-1)
# Map nf4 to [-1, 1]
out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device)
n = out_dq.numel()
out_dq[1::2] = A & 0xF
out_dq[::2] = A >> 4
# code is fp32, cast to dtype to avoid the mismatch issue
code = CODE[quant_type].to(dtype)
out_dq = code[out_dq]

# Apply scales
if out_dq.numel() != n:
assert out_dq.numel() == n + 1
out_dq = torch.narrow(out_dq, 0, 0, n)
blocks = n // blocksize
blocks += 1 if n % blocksize > 0 else 0
rem = n % blocksize
has_rem = rem > 0

out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1)
if has_rem:
out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1)
out[n - rem :] = out_dq[n - rem :] * absmax[-1]
else:
out = out_dq.view(-1, blocksize) * absmax.view(-1, 1)

out = out.reshape(-1, *shape[1:]).to(dtype)

# Reshape to original shape
blocks = blocks.reshape(-1, *shape[1:])

return blocks.to(dtype)
return out


@register_kernel("bitsandbytes::gemv_4bit", "cpu")
Expand All @@ -170,17 +216,13 @@ def _(
code: torch.Tensor,
blocksize: int,
) -> torch.Tensor:
# TODO: We need to determine whether `code` is NF4, FP4, or other.
# Right now we assume NF4, as this is the only one supported on CPU.

B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(
B,
absmax,
blocksize,
"nf4",
shape=shapeB,
dtype=A.dtype,
)
# Applied from dequantize_4bit
B = B.view(-1, 1)
upper = (B >> 4).to(torch.int64)
lower = (B & 0x0F).to(torch.int64)
blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)
B_dq = code[blocks] * absmax[:, None]
B_dq = B_dq.reshape(-1, *shapeB[1:]).to(A.dtype)

# User called gemv with B.t(), so we need to transpose it back.
# if B.shape[0] == 1:
Expand Down
44 changes: 32 additions & 12 deletions bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,40 @@ def get_native_library() -> BNBNativeLibrary:
return BNBNativeLibrary(dll)


try:
# to support Intel CPU/GPU (XPU) backend
import intel_extension_for_pytorch as ipex

ipex_cpu = ipex if ipex._C._has_cpu() else None
ipex_xpu = ipex if ipex._C._has_xpu() else None
except BaseException:
ipex_cpu = None
ipex_xpu = None


try:
lib = get_native_library()
if not ipex_cpu:
logger.warning(
"The installed version of bitsandbytes was compiled without IPEX support. "
"You can install ipex by running `pip install intel_extension_for_pytorch`to get better performance if you use the Intel CPU.",
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like extra noise that we'd want to avoid.

Something to point out is that we still plan to ship libbitsandbytes_cpu in our wheels, so for most users, it's going to load a CPU, CUDA, or eventually ROCm or Metal library and we'll hit this logging line. At most we should really only raise this warning when:

  1. We're on a platform with IPEX CPU support. My understanding is this is limited to Linux x86-64.
  2. We expect the user to be using CPU, i.e. no CUDA, XPU, or MPS accelerators available.
    On torch >= 2.6 we could just use torch.accelerator.is_available() and on older versions I think we can overlook privateuse1 backends like HPU or Ascend NPU.
  3. There's some expectation of IPEX being beneficial. We don't want to prompt users to install it if e.g. it needs AVX512 or AMX support to be effective. This is something I can't speak to directly but defer to Intel folks to determine.

Any other thoughts @Titus-von-Koeller ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You were right. I also agree that the log should only exist if no devices like cuda/xpu are available and the CPU is an Intel product.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have changed it, please review again. Thanks!

except Exception as e:
lib = None
logger.error(f"Could not load bitsandbytes native library: {e}", exc_info=True)
if torch.cuda.is_available():
logger.warning(
"""
CUDA Setup failed despite CUDA being available. Please run the following command to get more information:

python -m bitsandbytes

Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/bitsandbytes-foundation/bitsandbytes/issues
""",
if not ipex_xpu:
logger.error(
f"Could not load bitsandbytes native library: {e}. If you use Intel CPU or XPU, please pip install intel_extension_for_pytorch by following the instruction in https://pytorch-extension.intel.com/installation.\n",
exc_info=True,
)
if torch.cuda.is_available():
logger.warning(
"""
CUDA Setup failed despite CUDA being available. Please run the following command to get more information:

python -m bitsandbytes

Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/bitsandbytes-foundation/bitsandbytes/issues
""",
)
78 changes: 76 additions & 2 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from torch import Tensor
from typing_extensions import deprecated

from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
from bitsandbytes.utils import pack_dict_to_tensor, reverse_4bit_compress_format, unpack_tensor_to_dict

from .cextension import lib
from .cextension import ipex_cpu, ipex_xpu, lib

name2qmap = {}

Expand Down Expand Up @@ -1123,6 +1123,15 @@ def dequantize_4bit(
if absmax.dtype != torch.float32:
absmax = absmax.float()

# IPEX format is different, we need extra process.
if getattr(quant_state, "ipex", False) and quant_state.quant_type == "nf4":
if A.device.type == "xpu":
out = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t()
return out
elif A.device.type == "cpu":
ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2)
A = reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1)

if out is not None:
torch.ops.bitsandbytes.dequantize_4bit.out(
A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out
Expand Down Expand Up @@ -1710,6 +1719,25 @@ def gemv_4bit(
if state.nested:
absmax = dequantize_blockwise(absmax, state.state2) + state.offset

if getattr(state, "ipex", False) and state.quant_type == "nf4":
# compute_dtype: 1 indicates fp16, 2 indicates bf16
compute_dtype = 2 if A.dtype == torch.bfloat16 else 1
out = torch.ops.torch_ipex.woq_linear(
A,
B,
"nf4",
state.shape,
state.new_scales,
state.new_zeros,
None,
None,
state.blocksize,
compute_dtype,
1,
state.compensation,
)
return out

if out is not None:
torch.ops.bitsandbytes.gemv_4bit.out(
A,
Expand Down Expand Up @@ -2508,3 +2536,49 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
return x.to(dtype)
else:
return None


def enable_ipex_fusion(linear, x):
quant_state = linear.weight.quant_state

if quant_state.nested:
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
if absmax.dtype != torch.float32:
absmax = absmax.float()

quant_state.absmax = absmax
quant_state.nested = False
delattr(quant_state, "state2")

if x.device.type == "cpu" and ipex_cpu:
converted_weight = reverse_4bit_compress_format(linear.weight.data)
new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
"nf4",
quant_state.shape, # weight shape
quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
None, # zero_points
None, # bias
None, # batch_size
quant_state.blocksize,
2,
)
elif x.device.type == "xpu" and ipex_xpu:
new_weight = reverse_4bit_compress_format(linear.weight.data)
new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
new_zeros = None
compensation = None
new_scales = list(new_scales)
if not linear.training and not x.requires_grad:
new_weight = new_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
else:
raise ValueError(
"Please check the device and ipex version. The device should be cpu or xpu while ipex version should >= 2.7"
)

linear.weight.data = new_weight.data
linear.weight.quant_state.ipex = True
linear.weight.quant_state.new_scales = new_scales
linear.weight.quant_state.new_zeros = new_zeros
linear.weight.quant_state.compensation = compensation
Loading