Commit 1a39c0a

just move files
1 parent 76712f7 commit 1a39c0a

File tree

5 files changed: +253 -241 lines changed


bitsandbytes/backends/triton/__init__.py

Whitespace-only changes.

bitsandbytes/backends/triton/ops.py

Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
from collections.abc import Sequence
import warnings

import torch

from .utils import _FP4_QUANT_TABLE, _NF4_QUANT_TABLE

try:
    from . import triton_kernels

    triton_available = True
except ImportError as e:
    print("Import error:", e)
    triton_available = False


# torch compile:
# 1.53s call tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional::test_dynamic_blockwise_quantization[signed=F-256-nested=T-bf16-xpu]
#
# triton:
# 1.07s call tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional::test_dynamic_blockwise_quantization[signed=F-256-nested=T-bf16-xpu]
@torch.compile
def quantize_blockwise_torch(A, code, blocksize):
    n = A.numel()
    blocks = -(n // -blocksize)

    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
    quantized_out = torch.empty_like(A.flatten(), dtype=torch.uint8)

    rem = n % blocksize
    has_rem = rem > 0
    blocks = n // blocksize + has_rem
    A_reshaped = A.reshape(n)
    A_com = A_reshaped[: n - rem]
    A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
    absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
    scaled_A = torch.clamp(A_com_reshaped / absmax[: blocks - has_rem].view(-1, 1), -1, 1)
    scaled_A = scaled_A.reshape(-1)
    if has_rem:
        absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
        scaled_A_rem = torch.clamp((A_reshaped[n - rem :] / absmax[-1]), -1, 1)
        scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)

    diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
    quantized_out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
    quantized_out = quantized_out.reshape(A.shape)
    return quantized_out, absmax


def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")

    n = A.numel()
    blocks = -(n // -blocksize)

    absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
    out = torch.empty_like(A.flatten(), dtype=torch.uint8)

    triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
    out = out.reshape(A.shape)

    return out, absmax


def dequantize_blockwise(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")

    out = torch.empty_like(A, dtype=dtype, device=A.device)
    triton_kernels.dequant_int8_blockwise(
        A,
        code,
        absmax,
        out,
        blocksize,
    )

    return out


def dequantize_blockwise_inplace(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
    torch._check_is_size(blocksize)
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
    torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")

    triton_kernels.dequant_int8_blockwise(
        A,
        code,
        absmax,
        out,
        blocksize,
    )


# torch compile
# 1.01s call tests/test_functional.py::TestQuantize4BitFunctional::test_4bit_quant[64-fp4-fp32-xpu]
#
# triton
# 0.80s call tests/test_functional.py::TestQuantize4BitFunctional::test_4bit_quant[64-fp4-fp32-xpu]
@torch.compile
def quantize_4bit_torch(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    # Divide into blocks and normalize
    blocks = A.reshape(-1, blocksize)
    absmax = blocks.abs().max(dim=1).values.float()
    scaled = blocks / absmax.unsqueeze(-1)
    if quant_type == "fp4":
        quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _FP4_QUANT_TABLE), dim=-1, keepdim=True).to(
            torch.uint8
        )
    else:
        quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to(
            torch.uint8
        )
    packed = quantized[::2] << 4 | quantized[1::2]
    if quant_storage != torch.uint8:
        packed = packed.squeeze().view(quant_storage).unsqueeze(1)
    return packed, absmax.float()


def quantize_4bit(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}")
    torch._check(
        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
    )

    n = A.numel()

    # TODO: Support when weight matrix is not divisible by blocksize
    torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}")

    blocks = -(n // -(blocksize * 2))

    absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
    out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)

    if quant_type == "fp4":
        triton_kernels.quantize_4bit_blockwise_triton(A, blocksize, _FP4_QUANT_TABLE, blocks, absmax, out)
    else:
        triton_kernels.quantize_4bit_blockwise_triton(A, blocksize, _NF4_QUANT_TABLE, blocks, absmax, out)
    packed = out

    if quant_storage != torch.uint8:
        packed = out.squeeze().view(quant_storage).unsqueeze(1)

    return packed, absmax.float()


def dequantize_4bit(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on XPU, got {quant_type}")
    torch._check(
        dtype in [torch.bfloat16, torch.float16, torch.float32],
        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
    )
    # torch._check(
    #     A.dtype == torch.uint8,
    #     lambda: f"Blockwise 4bit dequantization on XPU only supports uint8 storage, got {A.dtype}",
    # )
    # Check if this is fine and fast
    if A.dtype != torch.uint8:
        A = A.squeeze().view(torch.uint8).unsqueeze(1)

    out = torch.empty(shape, dtype=dtype, device=A.device)

    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
    return out


def dequantize_4bit_inplace(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
    torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
    triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)


def gemv_4bit(
    A: torch.Tensor,
    B: torch.Tensor,
    shapeB: Sequence[int],
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
) -> torch.Tensor:
    # TODO: We need to determine whether `code` is NF4, FP4, or other.
    # Right now we assume NF4, as this is the only one supported on CPU.
    quant_type = "fp4" if code[1] > 0 else "nf4"
    B_dq = dequantize_4bit(B, absmax, blocksize, quant_type, shapeB, A.dtype)

    # For some reason directly passing code causes errors in some cases like:
    # tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=128-uint8-fp32-fc1-fp4-DQ_True-xpu]
    #
    # B_dq = torch.empty(shapeB, dtype=A.dtype, device=A.device)
    # code = code.to(A.device)
    # if B.dtype != torch.uint8:
    #     B = B.squeeze().view(torch.uint8).unsqueeze(1)

    # triton_kernels._dequantize_4bit_impl_passing_code(
    #     B,
    #     absmax,
    #     blocksize,
    #     code,
    #     dtype=A.dtype,
    #     out=B_dq,
    # )

    # User called gemv with B.t(), so we need to transpose it back.
    # if B.shape[0] == 1:
    #     B_dq = B_dq.t()

    return torch.nn.functional.linear(
        A,
        B_dq,
        bias=None,
    )
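
For orientation, below is a minimal usage sketch for the ops this file adds. It is not part of the commit: it assumes the module is importable as bitsandbytes.backends.triton.ops with a working triton_kernels import, that a Triton-capable device (shown here as "xpu") is present, and that bitsandbytes.functional.create_dynamic_map() supplies the 256-entry 8-bit code table.

# Hypothetical round-trip check for quantize_blockwise / dequantize_blockwise.
# Assumptions (not defined by this commit): a Triton-capable device and the
# dynamic 8-bit code table from bitsandbytes.functional.create_dynamic_map().
import torch

from bitsandbytes.backends.triton import ops
from bitsandbytes.functional import create_dynamic_map

device = "xpu"  # or "cuda", depending on the backend build
code = create_dynamic_map(signed=True).to(device)  # 256-entry code table

A = torch.randn(4096, dtype=torch.float32, device=device)

# One uint8 index per element, one absmax scale per 256-element block.
q, absmax = ops.quantize_blockwise(A, code, blocksize=256)

# Reconstruct and inspect the blockwise quantization error.
A_dq = ops.dequantize_blockwise(q, absmax, code, blocksize=256, dtype=torch.float32)
print((A - A_dq).abs().mean())

# The 4-bit path follows the same shape conventions: two 4-bit codes are packed
# per uint8 byte, and absmax holds one scale per block.
W = torch.randn(64, 64, dtype=torch.float32, device=device)
packed, absmax4 = ops.quantize_4bit(W, blocksize=64, quant_type="nf4", quant_storage=torch.uint8)
W_dq = ops.dequantize_4bit(packed, absmax4, blocksize=64, quant_type="nf4", shape=W.shape, dtype=W.dtype)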
