Skip to content

Commit aaa71d7

Browse files
authored
Enable CPU/XPU native and ipex path (#1628)
* enable ipex Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix cpu 8bit quantization Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix int8 and nf4 cpu inference Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * add cpu fp4 and rem Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix dequantize nf4 xpu Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix ipex op Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix dequantize nf4 name Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix dequantize nf4 ipex Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix matmul8bitfp Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * enable cpu tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix format Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix quantize blockwise output shape Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix quant_storage bf16 and gemv cpu Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix cpu tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix xpu tests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix lib Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * skip xpu dequantize blockwise op check Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix matmul8bit Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * skip not used function teests Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix matmul8bit fp Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * check ipex before MatMul8bitFp Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * update ipex install guide Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * update install guide Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix error log Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix error lof Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * update comment Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * move torch op to default Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * revert ipex check Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix code tabledevice Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix code table device Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * fix xpu ops Signed-off-by: jiqing-feng <jiqing.feng@intel.com> --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent 1e54f91 commit aaa71d7

File tree

18 files changed

+621
-258
lines changed

18 files changed

+621
-258
lines changed

bitsandbytes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
if torch.cuda.is_available():
3535
from .backends.cuda import ops as cuda_ops
3636

37+
if torch.xpu.is_available():
38+
from .backends.xpu import ops as xpu_ops
39+
3740

3841
def _import_backends():
3942
"""

bitsandbytes/_ops.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import torch
66

7+
from .cextension import ipex_cpu, ipex_xpu
8+
79
_IS_TORCH_GTE_24 = False
810

911
if hasattr(torch.library, "register_fake"):
@@ -327,3 +329,22 @@ def _(
327329
)
328330
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
329331
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
332+
333+
334+
if ipex_cpu or ipex_xpu:
335+
# Register the dequantize_nf4_ipex implementation
336+
torch.library.define(
337+
"bitsandbytes::dequantize_nf4_ipex",
338+
"(Tensor A, Tensor absmax, int blocksize, int[] shape, ScalarType dtype) -> Tensor",
339+
)
340+
341+
@register_fake("bitsandbytes::dequantize_nf4_ipex")
342+
def _(
343+
A: torch.Tensor,
344+
absmax: torch.Tensor,
345+
blocksize: int,
346+
shape: Sequence[int],
347+
dtype: torch.dtype,
348+
) -> torch.Tensor:
349+
torch._check_is_size(blocksize)
350+
return torch.empty(shape, dtype=dtype, device=A.device)

bitsandbytes/autograd/_functions.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing_extensions import deprecated
99

1010
import bitsandbytes.functional as F
11+
from bitsandbytes.functional import ipex_cpu, ipex_xpu
1112

1213
# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
1314
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
@@ -298,6 +299,63 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
298299
return grad_A, grad_B, None, grad_bias, None
299300

300301

302+
class MatMul8bitFp(torch.autograd.Function):
303+
# For Intel CPU and XPU MatMul8bitFp is much faster (~3x) than MatMul8bitLt in finetune.
304+
# Because the MatMul8bitLt has more mechanisms in computing grad.
305+
# We don't have fast kernel for quant/dequant 8bit in CPU/XPU, so it's very slow.
306+
# We'd like to use dequant + matmul to run finetune with good performance.
307+
308+
@staticmethod
309+
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
310+
if state.has_fp16_weights or state.CB is None:
311+
has_grad = getattr(B, "grad", None) is not None
312+
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
313+
if is_transposed:
314+
B = B.contiguous()
315+
316+
if (state.is_training and not has_grad) or state.CB is None or state.SCB is None:
317+
state.reset_grads()
318+
state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
319+
B = state.CB
320+
321+
CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
322+
output = torch.nn.functional.linear(A, CB, bias)
323+
# to pass the test: tests/test_modules.py::test_linear8bitlt_no_fp16_weights[2.0-xpu]
324+
state.idx = False
325+
ctx.state = state
326+
ctx.dtype_A = A.dtype
327+
ctx.grad_shape = A.shape
328+
ctx.A = A
329+
ctx.dtype_bias = None if bias is None else bias.dtype
330+
return output
331+
332+
@staticmethod
333+
def backward(ctx, grad_output):
334+
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
335+
A = ctx.A
336+
state = ctx.state
337+
grad_A = grad_B = grad_bias = None
338+
if req_gradBias:
339+
# compute grad_bias first before changing grad_output dtype
340+
grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
341+
342+
# Cast grad_output to fp16
343+
if len(grad_output.shape) == 3:
344+
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
345+
346+
if req_gradB:
347+
grad_B = torch.matmul(A.t(), grad_output).t()
348+
349+
if req_gradA:
350+
if state.CB is not None:
351+
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
352+
grad_A = torch.matmul(grad_output.to(ctx.dtype_A), CB).view(ctx.grad_shape)
353+
else:
354+
raise Exception("State must contain CB matrix for backward")
355+
356+
return grad_A, grad_B, None, grad_bias, None
357+
358+
301359
class MatMul4Bit(torch.autograd.Function):
302360
# forward is the same, but we added the fallback for pre-turing GPUs
303361
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
@@ -366,6 +424,10 @@ def matmul(
366424
state = state or MatmulLtState()
367425
if threshold > 0.0:
368426
state.threshold = threshold
427+
# MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU
428+
if state.is_training:
429+
if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu" and ipex_xpu):
430+
return MatMul8bitFp.apply(A, B, out, bias, state)
369431
return MatMul8bitLt.apply(A, B, out, bias, state)
370432

371433

@@ -378,6 +440,17 @@ def matmul_4bit(
378440
):
379441
assert quant_state is not None
380442

443+
if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
444+
if getattr(quant_state, "ipex", False):
445+
# IPEX CPU will change weight to 4D so don't need transpose
446+
B = B.t() if B.dim() == 2 else B
447+
out = F.gemv_4bit(A, B, out, state=quant_state)
448+
if bias is not None:
449+
out += bias
450+
return out
451+
else:
452+
return MatMul4Bit.apply(A, B, out, bias, quant_state)
453+
381454
if A.numel() == A.shape[-1] and A.requires_grad == False:
382455
if A.shape[-1] % quant_state.blocksize != 0:
383456
warn(

bitsandbytes/backends/cpu/ops.py

Lines changed: 76 additions & 149 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from ..._ops import register_kernel
99
from ...cextension import lib
10+
from ..utils import ipex_cpu
1011

1112
# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
1213
# However, we can overflow if we use this without AVX512_VNNI support.
@@ -26,22 +27,42 @@ def _(A: torch.Tensor, B: torch.Tensor):
2627
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
2728
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
2829
torch._check_is_size(blocksize)
29-
torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on cpu, got {A.dtype}")
3030

3131
n = A.numel()
32-
blocks = -(n // -blocksize)
3332

34-
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
35-
out = torch.empty_like(A, dtype=torch.uint8)
36-
37-
lib.cquantize_blockwise_cpu_fp32(
38-
get_ptr(code),
39-
get_ptr(A),
40-
get_ptr(absmax),
41-
get_ptr(out),
42-
ct.c_longlong(blocksize),
43-
ct.c_longlong(n),
44-
)
33+
# Only FP32 has c++ kernrl
34+
if A.dtype == torch.float32:
35+
blocks = -(n // -blocksize)
36+
37+
absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
38+
out = torch.empty_like(A, dtype=torch.uint8)
39+
40+
lib.cquantize_blockwise_cpu_fp32(
41+
get_ptr(code),
42+
get_ptr(A),
43+
get_ptr(absmax),
44+
get_ptr(out),
45+
ct.c_longlong(blocksize),
46+
ct.c_longlong(n),
47+
)
48+
else:
49+
rem = n % blocksize
50+
has_rem = rem > 0
51+
blocks = n // blocksize + has_rem
52+
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
53+
A_reshaped = A.reshape(n)
54+
A_com = A_reshaped[: n - rem]
55+
A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
56+
absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
57+
scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
58+
scaled_A = scaled_A.reshape(-1)
59+
if has_rem:
60+
absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
61+
scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
62+
scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
63+
64+
diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
65+
out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
4566

4667
return out, absmax
4768

@@ -50,144 +71,50 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor
5071
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
5172
torch._check_is_size(blocksize)
5273
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
53-
torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on cpu, got {dtype}")
54-
55-
out = torch.empty_like(A, dtype=dtype)
5674

57-
lib.cdequantize_blockwise_cpu_fp32(
58-
get_ptr(code),
59-
get_ptr(A),
60-
get_ptr(absmax),
61-
get_ptr(out),
62-
ct.c_longlong(blocksize),
63-
ct.c_longlong(A.numel()),
64-
)
75+
# Only FP32 has c++ kernrl
76+
if dtype == torch.float32:
77+
out = torch.empty_like(A, dtype=dtype)
78+
79+
lib.cdequantize_blockwise_cpu_fp32(
80+
get_ptr(code),
81+
get_ptr(A),
82+
get_ptr(absmax),
83+
get_ptr(out),
84+
ct.c_longlong(blocksize),
85+
ct.c_longlong(A.numel()),
86+
)
87+
else:
88+
out = code[A.reshape(-1).int()]
89+
blocks = out.shape[-1] // blocksize
90+
res = out.shape[-1] % blocksize
91+
if res != 0:
92+
out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
93+
out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
94+
out = out[: blocks * blocksize + res]
95+
out = out.reshape(A.shape)
6596

6697
return out
6798

6899

69-
_NF4_QUANT_TABLE = torch.tensor(
70-
[
71-
-1.0,
72-
-0.6961928009986877,
73-
-0.5250730514526367,
74-
-0.39491748809814453,
75-
-0.28444138169288635,
76-
-0.18477343022823334,
77-
-0.09105003625154495,
78-
0.0,
79-
0.07958029955625534,
80-
0.16093020141124725,
81-
0.24611230194568634,
82-
0.33791524171829224,
83-
0.44070982933044434,
84-
0.5626170039176941,
85-
0.7229568362236023,
86-
1.0,
87-
],
88-
dtype=torch.float32,
89-
device="cpu",
90-
)
91-
92-
93-
@register_kernel("bitsandbytes::quantize_4bit", "cpu")
94-
def _(
95-
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
96-
) -> tuple[torch.Tensor, torch.Tensor]:
97-
torch._check_is_size(blocksize)
98-
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
99-
torch._check(
100-
A.dtype in [torch.bfloat16, torch.float16, torch.float32],
101-
lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
102-
)
103-
104-
n = A.numel()
105-
106-
# TODO: Support when weight matrix is not divisible by blocksize
107-
torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")
108-
109-
# Divide into blocks and normalize
110-
blocks = A.reshape(-1, blocksize)
111-
absmax = blocks.abs().max(dim=1).values.float()
112-
scaled = blocks / absmax.unsqueeze(-1)
113-
114-
# Quantize with the lookup table
115-
quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(torch.uint8)
116-
117-
# Pack two quantized values per byte
118-
packed = quantized[::2] << 4 | quantized[1::2]
119-
120-
if quant_storage != torch.uint8:
121-
packed = packed.squeeze().view(quant_storage).unsqueeze(1)
122-
123-
return packed, absmax.float()
124-
125-
126-
@register_kernel("bitsandbytes::dequantize_4bit", "cpu")
127-
def _(
128-
A: torch.Tensor,
129-
absmax: torch.Tensor,
130-
blocksize: int,
131-
quant_type: str,
132-
shape: Sequence[int],
133-
dtype: torch.dtype,
134-
) -> torch.Tensor:
135-
torch._check_is_size(blocksize)
136-
torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
137-
torch._check(
138-
dtype in [torch.bfloat16, torch.float16, torch.float32],
139-
lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
140-
)
141-
torch._check(
142-
A.dtype == torch.uint8,
143-
lambda: f"Blockwise 4bit dequantization on CPU only supports uint8 storage, got {A.dtype}",
144-
)
145-
146-
A = A.view(-1, 1)
147-
148-
# Grab upper and lower nibbles. Using int64 for indexing in the LUT.
149-
upper = (A >> 4).to(torch.int64)
150-
lower = (A & 0x0F).to(torch.int64)
151-
152-
# Expand to blocks
153-
blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize)
154-
155-
# Dequantize
156-
blocks = _NF4_QUANT_TABLE[blocks] * absmax[:, None]
157-
158-
# Reshape to original shape
159-
blocks = blocks.reshape(-1, *shape[1:])
160-
161-
return blocks.to(dtype)
162-
163-
164-
@register_kernel("bitsandbytes::gemv_4bit", "cpu")
165-
def _(
166-
A: torch.Tensor,
167-
B: torch.Tensor,
168-
shapeB: Sequence[int],
169-
absmax: torch.Tensor,
170-
code: torch.Tensor,
171-
blocksize: int,
172-
) -> torch.Tensor:
173-
# TODO: We need to determine whether `code` is NF4, FP4, or other.
174-
# Right now we assume NF4, as this is the only one supported on CPU.
175-
176-
B_dq = torch.ops.bitsandbytes.dequantize_4bit.default(
177-
B,
178-
absmax,
179-
blocksize,
180-
"nf4",
181-
shape=shapeB,
182-
dtype=A.dtype,
183-
)
184-
185-
# User called gemv with B.t(), so we need to transpose it back.
186-
# if B.shape[0] == 1:
187-
# B_dq = B_dq.t()
188-
189-
return torch.nn.functional.linear(
190-
A,
191-
B_dq,
192-
bias=None,
193-
)
100+
if ipex_cpu:
101+
from bitsandbytes.utils import _reverse_4bit_compress_format
102+
103+
@register_kernel("bitsandbytes::dequantize_nf4_ipex", "cpu")
104+
def _(
105+
A: torch.Tensor,
106+
absmax: torch.Tensor,
107+
blocksize: int,
108+
shape: Sequence[int],
109+
dtype: torch.dtype,
110+
) -> torch.Tensor:
111+
ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2)
112+
A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1)
113+
return torch.ops.bitsandbytes.dequantize_4bit.default(
114+
A,
115+
absmax,
116+
blocksize,
117+
"nf4",
118+
shape,
119+
dtype,
120+
)

0 commit comments

Comments
 (0)