|
| 1 | +from collections.abc import Sequence |
| 2 | +import warnings |
| 3 | + |
| 4 | +import torch |
| 5 | + |
| 6 | +from ..._ops import register_kernel |
| 7 | +from .utils import _FP4_QUANT_TABLE, _NF4_QUANT_TABLE |
| 8 | + |
| 9 | +try: |
| 10 | + from . import triton_kernels |
| 11 | + |
| 12 | + triton_available = True |
| 13 | +except ImportError as e: |
| 14 | + print("Import error:", e) |
| 15 | + triton_available = False |
| 16 | + |
| 17 | + |
| 18 | +# torch compile: |
| 19 | +# 1.53s call tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional::test_dynamic_blockwise_quantization[signed=F-256-nested=T-bf16-xpu] |
| 20 | +# |
| 21 | +# triton: |
| 22 | +# 1.07s call tests/test_functional.py::Test8BitBlockwiseQuantizeFunctional::test_dynamic_blockwise_quantization[signed=F-256-nested=T-bf16-xpu] |
| 23 | +@torch.compile |
| 24 | +def quantize_blockwise_torch(A, code, blocksize): |
| 25 | + n = A.numel() |
| 26 | + blocks = -(n // -blocksize) |
| 27 | + |
| 28 | + absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype) |
| 29 | + quantized_out = torch.empty_like(A.flatten(), dtype=torch.uint8) |
| 30 | + |
| 31 | + rem = n % blocksize |
| 32 | + has_rem = rem > 0 |
| 33 | + blocks = n // blocksize + has_rem |
| 34 | + A_reshaped = A.reshape(n) |
| 35 | + A_com = A_reshaped[: n - rem] |
| 36 | + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) |
| 37 | + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] |
| 38 | + scaled_A = torch.clamp(A_com_reshaped / absmax[: blocks - has_rem].view(-1, 1), -1, 1) |
| 39 | + scaled_A = scaled_A.reshape(-1) |
| 40 | + if has_rem: |
| 41 | + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() |
| 42 | + scaled_A_rem = torch.clamp((A_reshaped[n - rem :] / absmax[-1]), -1, 1) |
| 43 | + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) |
| 44 | + |
| 45 | + diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device)) |
| 46 | + quantized_out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape) |
| 47 | + quantized_out = quantized_out.reshape(A.shape) |
| 48 | + return quantized_out, absmax |
| 49 | + |
| 50 | + |
| 51 | +def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: |
| 52 | + torch._check_is_size(blocksize) |
| 53 | + # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}") |
| 54 | + |
| 55 | + n = A.numel() |
| 56 | + blocks = -(n // -blocksize) |
| 57 | + |
| 58 | + absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype) |
| 59 | + out = torch.empty_like(A.flatten(), dtype=torch.uint8) |
| 60 | + |
| 61 | + triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out) |
| 62 | + out = out.reshape(A.shape) |
| 63 | + |
| 64 | + return out, absmax |
| 65 | + |
| 66 | + |
| 67 | +def dequantize_blockwise( |
| 68 | + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype |
| 69 | +) -> torch.Tensor: |
| 70 | + torch._check_is_size(blocksize) |
| 71 | + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") |
| 72 | + # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}") |
| 73 | + |
| 74 | + out = torch.empty_like(A, dtype=dtype, device=A.device) |
| 75 | + triton_kernels.dequant_int8_blockwise( |
| 76 | + A, |
| 77 | + code, |
| 78 | + absmax, |
| 79 | + out, |
| 80 | + blocksize, |
| 81 | + ) |
| 82 | + |
| 83 | + return out |
| 84 | + |
| 85 | + |
| 86 | +def dequantize_blockwise_inplace( |
| 87 | + A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor |
| 88 | +) -> None: |
| 89 | + torch._check_is_size(blocksize) |
| 90 | + torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") |
| 91 | + torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}") |
| 92 | + torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}") |
| 93 | + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") |
| 94 | + |
| 95 | + triton_kernels.dequant_int8_blockwise( |
| 96 | + A, |
| 97 | + code, |
| 98 | + absmax, |
| 99 | + out, |
| 100 | + blocksize, |
| 101 | + ) |
| 102 | + |
| 103 | + |
| 104 | +# torch compile |
| 105 | +# 1.01s call tests/test_functional.py::TestQuantize4BitFunctional::test_4bit_quant[64-fp4-fp32-xpu] |
| 106 | +# |
| 107 | +# triton |
| 108 | +# 0.80s call tests/test_functional.py::TestQuantize4BitFunctional::test_4bit_quant[64-fp4-fp32-xpu] |
| 109 | +@torch.compile |
| 110 | +def quantize_4bit_torch( |
| 111 | + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype |
| 112 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 113 | + # Divide into blocks and normalize |
| 114 | + blocks = A.reshape(-1, blocksize) |
| 115 | + absmax = blocks.abs().max(dim=1).values.float() |
| 116 | + scaled = blocks / absmax.unsqueeze(-1) |
| 117 | + if quant_type == "fp4": |
| 118 | + quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _FP4_QUANT_TABLE), dim=-1, keepdim=True).to( |
| 119 | + torch.uint8 |
| 120 | + ) |
| 121 | + else: |
| 122 | + quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - _NF4_QUANT_TABLE), dim=-1, keepdim=True).to( |
| 123 | + torch.uint8 |
| 124 | + ) |
| 125 | + packed = quantized[::2] << 4 | quantized[1::2] |
| 126 | + if quant_storage != torch.uint8: |
| 127 | + packed = packed.squeeze().view(quant_storage).unsqueeze(1) |
| 128 | + return packed, absmax.float() |
| 129 | + |
| 130 | + |
| 131 | +def quantize_4bit( |
| 132 | + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype |
| 133 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 134 | + torch._check_is_size(blocksize) |
| 135 | + # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on CPU, got {quant_type}") |
| 136 | + torch._check( |
| 137 | + A.dtype in [torch.bfloat16, torch.float16, torch.float32], |
| 138 | + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", |
| 139 | + ) |
| 140 | + |
| 141 | + n = A.numel() |
| 142 | + |
| 143 | + # TODO: Support when weight matrix is not divisible by blocksize |
| 144 | + torch._check(n % blocksize == 0, lambda: f"n must be divisible by blocksize, got {n} and {blocksize}") |
| 145 | + |
| 146 | + blocks = -(n // -(blocksize * 2)) |
| 147 | + |
| 148 | + absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype) |
| 149 | + out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8) |
| 150 | + |
| 151 | + if quant_type == "fp4": |
| 152 | + triton_kernels.quantize_4bit_blockwise_triton(A, blocksize, _FP4_QUANT_TABLE, blocks, absmax, out) |
| 153 | + else: |
| 154 | + triton_kernels.quantize_4bit_blockwise_triton(A, blocksize, _NF4_QUANT_TABLE, blocks, absmax, out) |
| 155 | + packed = out |
| 156 | + |
| 157 | + if quant_storage != torch.uint8: |
| 158 | + packed = out.squeeze().view(quant_storage).unsqueeze(1) |
| 159 | + |
| 160 | + return packed, absmax.float() |
| 161 | + |
| 162 | + |
| 163 | +def dequantize_4bit( |
| 164 | + A: torch.Tensor, |
| 165 | + absmax: torch.Tensor, |
| 166 | + blocksize: int, |
| 167 | + quant_type: str, |
| 168 | + shape: Sequence[int], |
| 169 | + dtype: torch.dtype, |
| 170 | +) -> torch.Tensor: |
| 171 | + torch._check_is_size(blocksize) |
| 172 | + # torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4 on XPU, got {quant_type}") |
| 173 | + torch._check( |
| 174 | + dtype in [torch.bfloat16, torch.float16, torch.float32], |
| 175 | + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", |
| 176 | + ) |
| 177 | + # torch._check( |
| 178 | + # A.dtype == torch.uint8, |
| 179 | + # lambda: f"Blockwise 4bit dequantization on XPU only supports uint8 storage, got {A.dtype}", |
| 180 | + # ) |
| 181 | + # Check if this is fine and fast |
| 182 | + if A.dtype != torch.uint8: |
| 183 | + A = A.squeeze().view(torch.uint8).unsqueeze(1) |
| 184 | + |
| 185 | + out = torch.empty(shape, dtype=dtype, device=A.device) |
| 186 | + |
| 187 | + triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) |
| 188 | + return out |
| 189 | + |
| 190 | + |
| 191 | +def dequantize_4bit_inplace( |
| 192 | + A: torch.Tensor, |
| 193 | + absmax: torch.Tensor, |
| 194 | + blocksize: int, |
| 195 | + quant_type: str, |
| 196 | + shape: Sequence[int], |
| 197 | + dtype: torch.dtype, |
| 198 | + out: torch.Tensor, |
| 199 | +) -> None: |
| 200 | + torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") |
| 201 | + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") |
| 202 | + triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) |
| 203 | + |
| 204 | + |
| 205 | +def gemv_4bit( |
| 206 | + A: torch.Tensor, |
| 207 | + B: torch.Tensor, |
| 208 | + shapeB: Sequence[int], |
| 209 | + absmax: torch.Tensor, |
| 210 | + code: torch.Tensor, |
| 211 | + blocksize: int, |
| 212 | +) -> torch.Tensor: |
| 213 | + # TODO: We need to determine whether `code` is NF4, FP4, or other. |
| 214 | + # Right now we assume NF4, as this is the only one supported on CPU. |
| 215 | + quant_type = "fp4" if code[1] > 0 else "nf4" |
| 216 | + B_dq = dequantize_4bit(B, absmax, blocksize, quant_type, shapeB, A.dtype) |
| 217 | + |
| 218 | + # For some reason directly passing code causes errors in some cases like: |
| 219 | + # tests/test_functional.py::TestQuantize4BitFunctional::test_gemv_4bit[dim=128-uint8-fp32-fc1-fp4-DQ_True-xpu] |
| 220 | + # |
| 221 | + # B_dq = torch.empty(shapeB, dtype=A.dtype, device=A.device) |
| 222 | + # code = code.to(A.device) |
| 223 | + # if B.dtype != torch.uint8: |
| 224 | + # B = B.squeeze().view(torch.uint8).unsqueeze(1) |
| 225 | + |
| 226 | + # triton_kernels._dequantize_4bit_impl_passing_code( |
| 227 | + # B, |
| 228 | + # absmax, |
| 229 | + # blocksize, |
| 230 | + # code, |
| 231 | + # dtype=A.dtype, |
| 232 | + # out=B_dq, |
| 233 | + # ) |
| 234 | + |
| 235 | + # User called gemv with B.t(), so we need to transpose it back. |
| 236 | + # if B.shape[0] == 1: |
| 237 | + # B_dq = B_dq.t() |
| 238 | + |
| 239 | + return torch.nn.functional.linear( |
| 240 | + A, |
| 241 | + B_dq, |
| 242 | + bias=None, |
| 243 | + ) |
| 244 | + |
| 245 | + |
| 246 | +if triton_available: |
| 247 | + register_kernel("bitsandbytes::quantize_blockwise", "xpu")(quantize_blockwise) |
| 248 | + register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu")(dequantize_blockwise_inplace) |
| 249 | + register_kernel("bitsandbytes::dequantize_blockwise", "xpu")(dequantize_blockwise) |
| 250 | + register_kernel("bitsandbytes::quantize_4bit", "xpu")(quantize_4bit) |
| 251 | + register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(dequantize_4bit_inplace) |
| 252 | + register_kernel("bitsandbytes::dequantize_4bit", "xpu")(dequantize_4bit) |
| 253 | + register_kernel("bitsandbytes::gemv_4bit", "xpu")(gemv_4bit) |
| 254 | +else: |
| 255 | + warnings.warn("XPU available, but trtion package is missing.") |
0 commit comments