
Commit fbb2d00

[xpu/triton] Add triton dequantization kernel
This PR adds an XPU backend and a triton kernel for dequantization of the nf4 dtype. Triton is an optional import.

Tests:
- tests/test_functional.py::TestQuantize4BitFunctional (supported nf4/fp4 cases)
- tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional (implemented quantize_blockwise with a binary search that runs faster on XPU)
- tests/test_linear4bit.py

Contains a workaround for gemv_4bit on XPU: for some reason, directly passing `code` causes errors in several tests, for example:
`tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=128-uint8-fp32-fc1-fp4-DQ_True-xpu]`

Signed-off-by: Dmitrii Makarenko <dmitrii.makarenko@intel.com>
1 parent 4011273 commit fbb2d00
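A minimal sketch of how the new path can be exercised, assuming a PyTorch build with XPU support and the optional triton package installed; the calls go through the existing bitsandbytes.functional API, which dispatches to the "xpu" kernels registered in ops.py below:

import torch
import bitsandbytes.functional as F

# Assumes torch.xpu.is_available() is True and the triton import in
# backends/xpu/ops.py succeeded, so the "xpu" kernels are registered.
A = torch.randn(4096, 4096, dtype=torch.float16, device="xpu")

# 4-bit blockwise quantization (nf4) and dequantization on XPU.
packed, state = F.quantize_4bit(A, blocksize=64, quant_type="nf4")
A_dq = F.dequantize_4bit(packed, state)

print(A_dq.shape, A_dq.dtype, (A - A_dq).abs().mean().item())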

File tree

4 files changed: +635 -0 lines changed


bitsandbytes/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,9 @@
 if torch.cuda.is_available():
     from .backends.cuda import ops as cuda_ops
 
+if torch.xpu.is_available():
+    from .backends.xpu import ops as xpu_ops
+
 
 def _import_backends():
     """

bitsandbytes/backends/xpu/__init__.py

Whitespace-only changes.
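The conditional import in __init__.py only brings in backends.xpu when an XPU device is visible, and ops.py below registers the triton-backed kernels only if its own triton import succeeds (otherwise it warns). A quick, illustrative way to confirm which path is active under those assumptions:

import torch
import bitsandbytes  # importing the package runs the backend imports shown above

print("XPU visible:", torch.xpu.is_available())
# If the triton import in backends/xpu/ops.py failed, the import error is printed
# and a warning "XPU available, but triton package is missing." is emitted instead
# of registering the "xpu" kernels.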

bitsandbytes/backends/xpu/ops.py

Lines changed: 255 additions & 0 deletions
@@ -0,0 +1,255 @@
from collections.abc import Sequence
import warnings

import torch

from ..._ops import register_kernel
from .utils import _FP4_QUANT_TABLE, _NF4_QUANT_TABLE

try:
    from . import triton_kernels

    triton_available = True
except ImportError as e:
    print("Import error:", e)
    triton_available = False


# torch compile:
# 1.53s call tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional::test_dynamic_blockwise_quantization[signed=F-256-nested=T-bf16-xpu]
#
# triton:
# 1.07s call tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional::test_dynamic_blockwise_quantization[signed=F-256-nested=T-bf16-xpu]
@torch.compile
def quantize_blockwise_torch(A, code, blocksize):
    n = A.numel()
    blocks = -(n // -blocksize)

    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
    quantized_out = torch.empty_like(A.flatten(), dtype=torch.uint8)

    rem = n % blocksize
    has_rem = rem > 0
    blocks = n // blocksize + has_rem
    A_reshaped = A.reshape(n)
    A_com = A_reshaped[: n - rem]
    A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
    absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
    scaled_A = torch.clamp(A_com_reshaped / absmax[: blocks - has_rem].view(-1, 1), -1, 1)
    scaled_A = scaled_A.reshape(-1)
    if has_rem:
        absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
        scaled_A_rem = torch.clamp((A_reshaped[n - rem :] / absmax[-1]), -1, 1)
        scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)

    diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
    quantized_out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
    quantized_out = quantized_out.reshape(A.shape)
    return quantized_out, absmax

def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")

    n = A.numel()
    blocks = -(n // -blocksize)

    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
    out = torch.empty_like(A.flatten(), dtype=torch.uint8)

    triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
    out = out.reshape(A.shape)

    return out, absmax


def dequantize_blockwise(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")

    out = torch.empty_like(A, dtype=dtype, device=A.device)
    triton_kernels.dequant_int8_blockwise(
        A,
        code,
        absmax,
        out,
        blocksize,
    )

    return out


def dequantize_blockwise_inplace(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")

    triton_kernels.dequant_int8_blockwise(
        A,
        code,
        absmax,
        out,
        blocksize,
    )

# torch compile
# 1.01s call tests/test_functional.py::TestQuantize4BitFunctional::test_4bit_quant[64-fp4-fp32-xpu]
#
# triton
# 0.80s call tests/test_functional.py::TestQuantize4BitFunctional::test_4bit_quant[64-fp4-fp32-xpu]
@torch.compile
def quantize_4bit_torch(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    # Divide into blocks and normalize
    blocks = A.reshape(-1, blocksize)
    absmax = blocks.abs().max(dim=1).values.float()
    scaled = blocks / absmax.unsqueeze(-1)
    if quant_type == "fp4":
        quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _FP4_QUANT_TABLE), dim=-1, keepdim=True).to(
            torch.uint8
        )
    else:
        quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(
            torch.uint8
        )
    packed = quantized[::2] << 4 | quantized[1::2]
    if quant_storage != torch.uint8:
        packed = packed.squeeze().view(quant_storage).unsqueeze(1)
    return packed, absmax.float()


def quantize_4bit(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
    torch._check(
        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
    )

    n = A.numel()

    # TODO: Support when weight matrix is not divisible by blocksize
    torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")

    blocks = -(n // -(blocksize * 2))

    absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
    out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)

    if quant_type == "fp4":
        triton_kernels.quantize_4bit_blockwise_triton(A, blocksize, _FP4_QUANT_TABLE, blocks, absmax, out)
    else:
        triton_kernels.quantize_4bit_blockwise_triton(A, blocksize, _NF4_QUANT_TABLE, blocks, absmax, out)
    packed = out

    if quant_storage != torch.uint8:
        packed = out.squeeze().view(quant_storage).unsqueeze(1)

    return packed, absmax.float()

def dequantize_4bit(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on XPU, got {quant_type}")
    torch._check(
        dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )
    # torch._check(
    #     A.dtype == torch.uint8,
    #     lambda: f"Blockwise 4bit dequantization on XPU only supports uint8 storage, got {A.dtype}",
    # )
    # Check if this is fine and fast
    if A.dtype != torch.uint8:
        A = A.squeeze().view(torch.uint8).unsqueeze(1)

    out = torch.empty(shape, dtype=dtype, device=A.device)

    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
    return out


def dequantize_4bit_inplace(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)


def gemv_4bit(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    # TODO: We need to determine whether `code` is NF4, FP4, or other.
    # Right now we assume NF4, as this is the only one supported on CPU.
    quant_type = "fp4" if code[1] > 0 else "nf4"
    B_dq = dequantize_4bit(B, absmax, blocksize, quant_type, shapeB, A.dtype)

    # For some reason directly passing code causes errors in some cases like:
    # tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=128-uint8-fp32-fc1-fp4-DQ_True-xpu]
    #
    # B_dq = torch.empty(shapeB, dtype=A.dtype, device=A.device)
    # code = code.to(A.device)
    # if B.dtype != torch.uint8:
    #     B = B.squeeze().view(torch.uint8).unsqueeze(1)

    # triton_kernels._dequantize_4bit_impl_passing_code(
    #     B,
    #     absmax,
    #     blocksize,
    #     code,
    #     dtype=A.dtype,
    #     out=B_dq,
    # )

    # User called gemv with B.t(), so we need to transpose it back.
    # if B.shape[0] == 1:
    #     B_dq = B_dq.t()

    return torch.nn.functional.linear(
        A,
        B_dq,
        bias=None,
    )

if triton_available:
    register_kernel("bitsandbytes::quantize_blockwise", "xpu")(quantize_blockwise)
    register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu")(dequantize_blockwise_inplace)
    register_kernel("bitsandbytes::dequantize_blockwise", "xpu")(dequantize_blockwise)
    register_kernel("bitsandbytes::quantize_4bit", "xpu")(quantize_4bit)
    register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(dequantize_4bit_inplace)
    register_kernel("bitsandbytes::dequantize_4bit", "xpu")(dequantize_4bit)
    register_kernel("bitsandbytes::gemv_4bit", "xpu")(gemv_4bit)
else:
    warnings.warn("XPU available, but triton package is missing.")
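For reference, the int8 blockwise scheme these kernels operate on stores one absmax scale per block and, per element, a uint8 index into the 256-entry code table (see quantize_blockwise_torch above: values are scaled to [-1, 1] by the block absmax and snapped to the nearest code entry). Dequantization is therefore a table lookup scaled by the block's absmax. A small eager-PyTorch sketch of that mapping, useful only for checking the triton output on small inputs (the helper name is illustrative, not part of this commit):

import torch

def dequantize_blockwise_reference(A, absmax, code, blocksize, dtype):
    # A: uint8 indices into the 256-entry `code` table, one per element.
    # absmax: one scale per block of `blocksize` elements.
    flat = A.reshape(-1).long()
    scales = absmax.reshape(-1).repeat_interleave(blocksize)[: flat.numel()]
    return (code.to(A.device)[flat] * scales).to(dtype).reshape(A.shape)

Comparing this against the output of the registered bitsandbytes::dequantize_blockwise kernel on the same inputs is a quick sanity check for the triton path.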
