
Commit ea15027

[xpu/triton] Add triton dequantization kernel
This PR adds an XPU backend and a Triton kernel for dequantization of the nf4 dtype. Triton is an optional import.

Tests:
- tests/test_functional.py::TestQuantize4BitFunctional: supported nf4/fp4 cases
- tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional: quantize_blockwise is implemented with a binary search, which runs faster on XPU
- tests/test_linear4bit.py

Contains a workaround (WA) for gemv_4bit on XPU: for some reason, directly passing `code` causes errors in several tests, for example:
`tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=128-uint8-fp32-fc1-fp4-DQ_True-xpu]`

Signed-off-by: Dmitrii Makarenko <dmitrii.makarenko@intel.com>
1 parent 42bc729 commit ea15027
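For orientation (not part of the commit), a minimal usage sketch of the new path, assuming the usual bitsandbytes.functional entry points dispatch to these XPU kernels when the input lives on an xpu device and triton is installed:

import torch
import bitsandbytes.functional as F

if torch.xpu.is_available():
    A = torch.randn(128, 128, dtype=torch.float16, device="xpu")

    # NF4 quantization: `packed` stores two 4-bit codes per byte, `state` keeps absmax/blocksize.
    packed, state = F.quantize_4bit(A, blocksize=64, quant_type="nf4")

    # Dequantization goes through the Triton kernel registered in backends/xpu/ops.py.
    A_dq = F.dequantize_4bit(packed, state)
    print((A - A_dq).abs().max())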

File tree

4 files changed: +565 additions, -0 deletions

bitsandbytes/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,9 @@
 if torch.cuda.is_available():
     from .backends.cuda import ops as cuda_ops
 
+if torch.xpu.is_available():
+    from .backends.xpu import ops as xpu_ops
+
 
 def _import_backends():
     """

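For context (not part of the commit): on a machine with an XPU device, importing bitsandbytes after this change should also register the XPU implementations from backends/xpu/ops.py. A rough sanity check, assuming the `bitsandbytes` torch op namespace visible in the register_kernel calls below:

import torch
import bitsandbytes  # runs the conditional backend imports shown above

if torch.xpu.is_available():
    # The "bitsandbytes::dequantize_4bit" schema registered in ops.py should now
    # resolve as a torch op (namespace assumed from the diff).
    print(torch.ops.bitsandbytes.dequantize_4bit)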
bitsandbytes/backends/xpu/__init__.py

Whitespace-only changes.

bitsandbytes/backends/xpu/ops.py

Lines changed: 250 additions & 0 deletions
from collections.abc import Sequence
import warnings

import torch

from ..._ops import register_kernel

try:
    from . import triton_kernels

    triton_available = True
except ImportError:
    triton_available = False


@torch.compile
def quantize_blockwise_torch(A, blocksize, code, absmax, quantized_out):
    n = A.numel()
    rem = n % blocksize
    has_rem = rem > 0
    blocks = n // blocksize + has_rem
    A_reshaped = A.reshape(n)
    A_com = A_reshaped[: n - rem]
    A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
    absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
    scaled_A = torch.clamp(A_com_reshaped / absmax[: blocks - has_rem].view(-1, 1), -1, 1)
    scaled_A = scaled_A.reshape(-1)
    if has_rem:
        absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
        scaled_A_rem = torch.clamp((A_reshaped[n - rem :] / absmax[-1]), -1, 1)
        scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)

    diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
    quantized_out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
    return quantized_out, absmax


def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")

    n = A.numel()
    blocks = -(n // -blocksize)

    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
    out = torch.empty_like(A.flatten(), dtype=torch.uint8)

    # out, absmax = quantize_blockwise_torch(A, blocksize, code, absmax, out)

    triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
    out = out.reshape(A.shape)
    return out, absmax


def dequantize_blockwise(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")

    out = torch.empty_like(A, dtype=dtype, device=A.device)
    triton_kernels.dequant_int8_fp16(
        A,
        code,
        absmax,
        out,
        blocksize,
    )

    return out


_NF4_QUANT_TABLE = torch.tensor(
    [
        -1.0,
        -0.6961928009986877,
        -0.5250730514526367,
        -0.39491748809814453,
        -0.28444138169288635,
        -0.18477343022823334,
        -0.09105003625154495,
        0.0,
        0.07958029955625534,
        0.16093020141124725,
        0.24611230194568634,
        0.33791524171829224,
        0.44070982933044434,
        0.5626170039176941,
        0.7229568362236023,
        1.0,
    ],
    dtype=torch.float32,
    device="xpu",
)

_FP4_QUANT_TABLE = torch.tensor(
    [
        0.0000,
        0.0052,
        0.6667,
        1.0000,
        0.3333,
        0.5000,
        0.1667,
        0.2500,
        0.0000,
        -0.0052,
        -0.6667,
        -1.0000,
        -0.3333,
        -0.5000,
        -0.1667,
        -0.2500,
    ],
    dtype=torch.float32,
    device="xpu",
)


@register_kernel("bitsandbytes::quantize_4bit", "xpu")
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
    torch._check(
        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
    )

    n = A.numel()

    # TODO: Support when weight matrix is not divisible by blocksize
    torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")

    # Divide into blocks and normalize
    blocks = A.reshape(-1, blocksize)
    absmax = blocks.abs().max(dim=1).values.float()
    scaled = blocks / absmax.unsqueeze(-1)

    # Quantize with the lookup table
    if quant_type == "fp4":
        quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _FP4_QUANT_TABLE), dim=-1, keepdim=True).to(
            torch.uint8
        )
    else:
        quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(
            torch.uint8
        )

    # Pack two quantized values per byte
    packed = quantized[::2] << 4 | quantized[1::2]

    if quant_storage != torch.uint8:
        packed = packed.squeeze().view(quant_storage).unsqueeze(1)

    return packed, absmax.float()


def dequantize_4bit(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on XPU, got {quant_type}")
    torch._check(
        dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )
    # torch._check(
    #     A.dtype == torch.uint8,
    #     lambda: f"Blockwise 4bit dequantization on XPU only supports uint8 storage, got {A.dtype}",
    # )
    # Check if this is fine and fast
    if A.dtype != torch.uint8:
        A = A.squeeze().view(torch.uint8).unsqueeze(1)

    out = torch.empty(shape, dtype=dtype, device=A.device)

    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
    return out


def dequantize_4bit_inplace(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)


def gemv_4bit(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    # TODO: We need to determine whether `code` is NF4, FP4, or other.
    # Right now we assume NF4, as this is the only one supported on CPU.
    quant_type = "fp4" if code[1] > 0 else "nf4"
    B_dq = dequantize_4bit(B, absmax, blocksize, quant_type, shapeB, A.dtype)

    # For some reason directly passing code causes errors in some cases like:
    # tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=128-uint8-fp32-fc1-fp4-DQ_True-xpu]
    #
    # B_dq = torch.empty(shapeB, dtype=A.dtype, device=A.device)
    # if B.dtype != torch.uint8:
    #     B = B.squeeze().view(torch.uint8).unsqueeze(1)

    # triton_kernels._dequantize_4bit_impl_passing_code(
    #     B,
    #     absmax,
    #     blocksize,
    #     code,
    #     dtype=A.dtype,
    #     out=B_dq,
    # )

    # User called gemv with B.t(), so we need to transpose it back.
    # if B.shape[0] == 1:
    #     B_dq = B_dq.t()

    return torch.nn.functional.linear(
        A,
        B_dq,
        bias=None,
    )


if triton_available:
    register_kernel("bitsandbytes::quantize_blockwise", "xpu")(quantize_blockwise)
    register_kernel("bitsandbytes::dequantize_blockwise", "xpu")(dequantize_blockwise)
    register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(dequantize_4bit_inplace)
    register_kernel("bitsandbytes::dequantize_4bit", "xpu")(dequantize_4bit)
    register_kernel("bitsandbytes::gemv_4bit", "xpu")(gemv_4bit)
else:
    warnings.warn("XPU available, but triton package is missing.")
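
To make the packing convention in the quantize_4bit kernel above concrete, here is a small self-contained illustration (not part of the commit; it runs on CPU): nearest-code lookup against the NF4 table, two 4-bit indices packed per byte with the even element in the high nibble, then unpacking and rescaling by the per-block absmax.

import torch

# NF4 code values copied from _NF4_QUANT_TABLE above (CPU copy for illustration).
nf4 = torch.tensor([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224,
    0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0,
])

blocksize = 8
A = torch.randn(2 * blocksize)

# Per-block absmax scaling, as in the kernel.
blocks = A.reshape(-1, blocksize)
absmax = blocks.abs().max(dim=1).values
scaled = blocks / absmax.unsqueeze(-1)

# Nearest-code lookup, then pack two 4-bit indices per byte
# (matches `quantized[::2] << 4 | quantized[1::2]`: even elements land in the high nibble).
idx = torch.argmin((scaled.reshape(-1, 1) - nf4).abs(), dim=-1).to(torch.uint8)
packed = idx[::2] << 4 | idx[1::2]

# Unpack and dequantize: high nibble first, then low nibble, then rescale by absmax.
unpacked = torch.stack([packed >> 4, packed & 0x0F], dim=1).reshape(-1)
A_dq = (nf4[unpacked.long()].reshape(-1, blocksize) * absmax.unsqueeze(-1)).reshape(-1)

print((A - A_dq).abs().max())  # small quantization error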

0 commit comments