diff --git a/bitsandbytes/backends/triton/triton_kernels.py b/bitsandbytes/backends/triton/kernels_4bit.py similarity index 78% rename from bitsandbytes/backends/triton/triton_kernels.py rename to bitsandbytes/backends/triton/kernels_4bit.py index 03ffa187d..0e94f49e8 100644 --- a/bitsandbytes/backends/triton/triton_kernels.py +++ b/bitsandbytes/backends/triton/kernels_4bit.py @@ -4,167 +4,6 @@ import triton.language as tl -# @triton.autotune( -# configs=[ -# # triton.Config({'SPLIT_SIZE': 64}), -# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 128}), -# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), -# triton.Config({"SPLIT_SIZE": 256}), -# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32), -# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), -# triton.Config({"SPLIT_SIZE": 512}), -# # triton.Config({'SPLIT_SIZE': 1024}), -# ], -# key=["num_paired_elements", "QUANT_BLOCK"], -# ) -@triton.jit -def dequant_8bit_kernel( - a_ptr, - c_ptr, - quant_ptr, - absmax_ptr, - num_paired_elements, - QUANT_BLOCK: tl.constexpr, - SPLIT_SIZE: tl.constexpr, -): - pid = tl.program_id(axis=0) - block_start = pid * SPLIT_SIZE - offsets = block_start + tl.arange(0, SPLIT_SIZE) - mask = offsets < num_paired_elements - - a = tl.load(a_ptr + offsets, mask) - a = a.to(tl.uint8) - - # apply conversion - scaled_int8 = tl.load(quant_ptr + a, mask) - - abs_blocks_lim = (num_paired_elements // QUANT_BLOCK) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK - abs_offsets = offsets // QUANT_BLOCK - mask_blocked = offsets < abs_blocks_lim - - absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked) - # apply scales - out_dq = scaled_int8 * absmax - - offs = block_start + tl.arange(0, SPLIT_SIZE) - mask = offs < num_paired_elements - tl.store(c_ptr + offs, out_dq, mask) - - -def dequant_int8_blockwise( - A_nf4: torch.Tensor, - quant_state_code: torch.Tensor, - absmax: torch.Tensor, - out: torch.Tensor, - quant_blocksize: int = 64, -): - number_of_paired_elements = A_nf4.numel() - - SPLIT_SIZE = 256 - # grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),) - grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),) - dequant_8bit_kernel[grid]( - A_nf4, - out, - quant_state_code, - absmax, - number_of_paired_elements, - quant_blocksize, - SPLIT_SIZE, - ) - return out - - -# @triton.autotune( -# configs=[ -# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32), -# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32), -# triton.Config({"SPLIT_NUM_BLOCKS": 1}), -# triton.Config({"SPLIT_NUM_BLOCKS": 2}), -# ], -# key=["n_elements"], -# ) -@triton.jit -def quantize_blockwise_kernel( - A_ptr, - code_ptr, - absmax_ptr, - out_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, - CODE_SIZE: tl.constexpr, - SPLIT_NUM_BLOCKS: tl.constexpr, 
-): - block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS - thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE) - - offsets = block_start_idx * BLOCK_SIZE + thread_idx - mask = offsets < n_elements - - A = tl.load(A_ptr + offsets, mask=mask, other=0.0) - - # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS) - A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE)) - - # Calculating absamax for each block - absmax = tl.max(tl.abs(A_reshaped), axis=1) - tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax) - - A_normalized = A_reshaped / absmax[:, None] - A_normalized = tl.clamp(A_normalized, -1.0, 1.0) - - lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32) - upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32) - - for _ in range(8): # ceil(log2(code_size)) = 8, actually, in general case should be input parameter - pivot = (lower_pivot + upper_pivot) // 2 - val = tl.load(code_ptr + pivot) - is_higher = A_normalized > val # code[pivot] - lower_pivot = tl.where(is_higher, pivot, lower_pivot) - upper_pivot = tl.where(is_higher, upper_pivot, pivot) - - # Choose closest level - lower_val = tl.load(code_ptr + lower_pivot) - upper_val = tl.load(code_ptr + upper_pivot) - lower_dist = tl.abs(A_normalized - lower_val) - upper_dist = tl.abs(A_normalized - upper_val) - quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8) - - # too slow approach - # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :]) - # quantized = tl.argmin(diff, axis=2).to(tl.uint8) - - quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,)) - tl.store(out_ptr + offsets, quantized_flat, mask=mask) - - -def quantize_blockwise_triton(A, blocksize, code, blocks, absmax, quantized_out): - n = A.numel() - - split_num_blocks = 1 - grid = (triton.cdiv(blocks, split_num_blocks),) - # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),) - quantize_blockwise_kernel[grid]( - A_ptr=A, - code_ptr=code, - absmax_ptr=absmax, - out_ptr=quantized_out, - n_elements=n, - BLOCK_SIZE=blocksize, - CODE_SIZE=code.numel(), - SPLIT_NUM_BLOCKS=split_num_blocks, - ) - - return quantized_out, absmax - - # Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dQuantizeFP4 # @triton.autotune( # configs=[ @@ -587,7 +426,7 @@ def dequant_nf4_kernel( tl.store(c_ptr + offs, out_dq, mask) -def _dequantize_4bit_impl( +def dequantize_4bit_impl( A: torch.Tensor, absmax: torch.Tensor, blocksize: int, @@ -611,7 +450,7 @@ def _dequantize_4bit_impl( dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE) -def _dequantize_4bit_impl_passing_code( +def dequantize_4bit_impl_passing_code( A: torch.Tensor, absmax: torch.Tensor, blocksize: int, diff --git a/bitsandbytes/backends/triton/kernels_8bit_quant.py b/bitsandbytes/backends/triton/kernels_8bit_quant.py new file mode 100644 index 000000000..42f97b83c --- /dev/null +++ b/bitsandbytes/backends/triton/kernels_8bit_quant.py @@ -0,0 +1,238 @@ +import torch + +import triton +import triton.language as tl + + +# @triton.autotune( +# configs=[ +# # triton.Config({'SPLIT_SIZE': 64}), +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 64, 
'grf_mode': 'auto'}, num_stages=4, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128}), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_SIZE": 256}), +# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32), +# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32), +# triton.Config({"SPLIT_SIZE": 512}), +# # triton.Config({'SPLIT_SIZE': 1024}), +# ], +# key=["num_paired_elements", "QUANT_BLOCK"], +# ) +@triton.jit +def dequant_8bit_kernel( + a_ptr, + c_ptr, + quant_ptr, + absmax_ptr, + num_paired_elements, + QUANT_BLOCK: tl.constexpr, + SPLIT_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * SPLIT_SIZE + offsets = block_start + tl.arange(0, SPLIT_SIZE) + mask = offsets < num_paired_elements + + a = tl.load(a_ptr + offsets, mask) + a = a.to(tl.uint8) + + # apply conversion + scaled_int8 = tl.load(quant_ptr + a, mask) + + abs_blocks_lim = (num_paired_elements // QUANT_BLOCK) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK + abs_offsets = offsets // QUANT_BLOCK + mask_blocked = offsets < abs_blocks_lim + + absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked) + # apply scales + out_dq = scaled_int8 * absmax + + offs = block_start + tl.arange(0, SPLIT_SIZE) + mask = offs < num_paired_elements + tl.store(c_ptr + offs, out_dq, mask) + + +def dequant_8bit_blockwise( + a: torch.Tensor, + absmax: torch.Tensor, + quant_state_code: torch.Tensor, + quant_blocksize: int = 64, + dtype: torch.dtype = None, + out: torch.Tensor = None, +): + number_of_paired_elements = a.numel() + if out is None: + if dtype is None: + raise ValueError("If out is None, dtype must be specified") + out = torch.empty_like(a, dtype=dtype, device=a.device) + + SPLIT_SIZE = 256 + # grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),) + grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),) + dequant_8bit_kernel[grid]( + a, + out, + quant_state_code, + absmax, + number_of_paired_elements, + quant_blocksize, + SPLIT_SIZE, + ) + return out + + +# @triton.autotune( +# configs=[ +# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32), +# triton.Config({"SPLIT_NUM_BLOCKS": 1}), +# triton.Config({"SPLIT_NUM_BLOCKS": 2}), +# ], +# key=["n_elements"], +# ) +@triton.jit +def quantize_8bit_blockwise_kernel( + A_ptr, + code_ptr, + absmax_ptr, + out_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + CODE_SIZE: tl.constexpr, + SPLIT_NUM_BLOCKS: tl.constexpr, +): + block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS + thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE) + + offsets = block_start_idx * BLOCK_SIZE + thread_idx + mask = offsets < n_elements + + A = tl.load(A_ptr + offsets, mask=mask, other=0.0) + + # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS) + A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE)) + + # Calculating absamax for each block + absmax = tl.max(tl.abs(A_reshaped), axis=1) + tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax) + + A_normalized = A_reshaped / absmax[:, None] + 
A_normalized = tl.clamp(A_normalized, -1.0, 1.0) + + lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32) + upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32) + + for _ in range(8): # ceil(log2(code_size)) = 8, actually, in general case should be input parameter + pivot = (lower_pivot + upper_pivot) // 2 + val = tl.load(code_ptr + pivot) + is_higher = A_normalized > val # code[pivot] + lower_pivot = tl.where(is_higher, pivot, lower_pivot) + upper_pivot = tl.where(is_higher, upper_pivot, pivot) + + # Choose closest level + lower_val = tl.load(code_ptr + lower_pivot) + upper_val = tl.load(code_ptr + upper_pivot) + lower_dist = tl.abs(A_normalized - lower_val) + upper_dist = tl.abs(A_normalized - upper_val) + quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8) + + # too slow approach + # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :]) + # quantized = tl.argmin(diff, axis=2).to(tl.uint8) + + quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,)) + tl.store(out_ptr + offsets, quantized_flat, mask=mask) + + +def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None): + n = A.numel() + blocks = -(n // -blocksize) + + if absmax is None: + absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype) + if out is None: + out = torch.empty_like(A.flatten(), dtype=torch.uint8) + + split_num_blocks = 1 + grid = (triton.cdiv(blocks, split_num_blocks),) + # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),) + quantize_8bit_blockwise_kernel[grid]( + A_ptr=A, + code_ptr=code, + absmax_ptr=absmax, + out_ptr=out, + n_elements=n, + BLOCK_SIZE=blocksize, + CODE_SIZE=code.numel(), + SPLIT_NUM_BLOCKS=split_num_blocks, + # num_warps=1, + # num_stages=2, + ) + out = out.reshape(A.shape) + + return out, absmax + + +@triton.jit +def quantize_8bit_blockwise_core( + a, + qmap_ptr, + CODE_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + N_PER_TH: tl.constexpr, +): + # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS) + a_reshaped = tl.reshape(a, (N_PER_TH, BLOCK_SIZE)) + + # Calculating absamax for each block + absmax = tl.max(tl.abs(a_reshaped), axis=1) + + a_normalized = a_reshaped / absmax[:, None] + a_normalized = tl.clamp(a_normalized, -1.0, 1.0) + + lower_pivot = tl.zeros((N_PER_TH, BLOCK_SIZE), dtype=tl.int32) + upper_pivot = tl.full((N_PER_TH, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32) + + # ceil(log2(code_size)) = 8, actually, in general case should be input parameter + for _ in range(8): + pivot = (lower_pivot + upper_pivot) // 2 + val = tl.load(qmap_ptr + pivot) + is_higher = a_normalized > val # code[pivot] + lower_pivot = tl.where(is_higher, pivot, lower_pivot) + upper_pivot = tl.where(is_higher, upper_pivot, pivot) + + # Choose closest level + lower_val = tl.load(qmap_ptr + lower_pivot) + upper_val = tl.load(qmap_ptr + upper_pivot) + lower_dist = tl.abs(a_normalized - lower_val) + upper_dist = tl.abs(a_normalized - upper_val) + quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8) + + quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * N_PER_TH,)) + return quantized_flat, absmax + + +@triton.jit +def dequant_8bit_kernel_util( + codes_ptr, + offsets, + qmap_ptr, + absmax_ptr, + mask, + BLOCK_SIZE: tl.constexpr, +): + codes = tl.load(codes_ptr + offsets, mask, other=0).to(tl.uint8) + abs_offsets = offsets // BLOCK_SIZE + absmax = tl.load(absmax_ptr + abs_offsets, mask=mask, other=0.0, 
eviction_policy="evict_last") + + # apply conversion + scaled_int8 = tl.load(qmap_ptr + codes, mask) + # apply scales + out_dq = scaled_int8 * absmax + return out_dq diff --git a/bitsandbytes/backends/triton/kernels_optim.py b/bitsandbytes/backends/triton/kernels_optim.py new file mode 100644 index 000000000..530ef472d --- /dev/null +++ b/bitsandbytes/backends/triton/kernels_optim.py @@ -0,0 +1,722 @@ +import math +from typing import Optional + +import torch + +import triton +import triton.language as tl +from triton.language.extra import libdevice + +from .kernels_8bit_quant import ( + dequant_8bit_blockwise, + dequant_8bit_kernel_util, + quantize_8bit_blockwise_core, + quantize_blockwise_triton, +) + +########################################### +# Pure torch implementation for reference # +########################################### + + +@torch.compile +def _dequantize_blockwise_pytorch( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, +) -> torch.Tensor: + """ + Pure PyTorch reference implementation for block-wise dequantization. + """ + if A.numel() == 0: + return torch.empty_like(A, dtype=dtype) + + A_flat = A.flatten() + num_elements = A_flat.numel() + + dequantized_flat = code.to(A.device)[A_flat.long()].to(dtype) + + num_blocks = math.ceil(num_elements / blocksize) + pad_len = num_blocks * blocksize - num_elements + if pad_len > 0: + dequantized_flat = torch.nn.functional.pad(dequantized_flat, (0, pad_len)) + + dequantized_blocks = dequantized_flat.reshape(num_blocks, blocksize) + + rescaled_blocks = dequantized_blocks * absmax.unsqueeze(1).to(dtype) + + rescaled_flat = rescaled_blocks.flatten() + if pad_len > 0: + rescaled_flat = rescaled_flat[:-pad_len] + + return rescaled_flat.reshape(A.shape) + + +@torch.compile +def _quantize_blockwise_pytorch( + A: torch.Tensor, + code: torch.Tensor, + blocksize: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Pure PyTorch reference implementation for block-wise quantization. 
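+    Used only as a reference/fallback path: the argmin search below materializes a
+    (num_blocks, blocksize, code_size) distance tensor, so it needs far more memory
+    than the Triton quantization kernels.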
+ """ + if A.numel() == 0: + return torch.empty_like(A, dtype=torch.uint8), torch.empty(0, dtype=torch.float32, device=A.device) + + A_flat = A.flatten() + num_elements = A_flat.numel() + + num_blocks = math.ceil(num_elements / blocksize) + + pad_len = num_blocks * blocksize - num_elements + if pad_len > 0: + A_flat = torch.nn.functional.pad(A_flat, (0, pad_len)) + + A_blocks = A_flat.reshape(num_blocks, blocksize) + + absmax = torch.max(torch.abs(A_blocks), dim=1, keepdim=True)[0] + absmax[absmax == 0] = 1.0 + + scaled_blocks = A_blocks / absmax + + # Inefficient but straightforward quantization, takes a lot of memory + diff = torch.abs(scaled_blocks.unsqueeze(2) - code.to(A.device)) + quantized_indices = torch.argmin(diff, dim=2).to(torch.uint8) + + quantized_flat = quantized_indices.flatten() + if pad_len > 0: + quantized_flat = quantized_flat[:-pad_len] + + return quantized_flat.reshape(A.shape), absmax.flatten() + + +# Main updated function +def optimizer_update_8bit_blockwise_pytorch( + p: torch.Tensor, + g: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, # ADEMIX + alpha: float, # ADEMIX + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float, + gnorm_scale: float, + skip_zeros: bool, + # ADEMIX + n: int, + *, + optimizer_name: str, +) -> None: + """ + Pure PyTorch implementation of the 8-bit block-wise optimizer update step. + This version ensures high-precision updates for float16 parameters. + """ + if skip_zeros: + raise ValueError("skip_zeros is not supported on XPU yet.") + + blocksize = 256 + + with torch.no_grad(): + # Dequantize states to perform updates in 32-bit precision + if optimizer_name == "ademamix" and absmax1.ndim == 2: + # For AdEMAMix, state1 holds two EMAs, so absmax1 is stacked. 
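+            # (state1 is stacked as shape (2, ...) holding the two EMAs m1 and m2, and
+            # absmax1 is stacked the same way with one row of per-block scales per EMA,
+            # so each half is dequantized against its own absmax row below.)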
+ s1_1_fp32 = _dequantize_blockwise_pytorch(state1[0], absmax1[0], qmap1, blocksize, torch.float32) + s1_2_fp32 = _dequantize_blockwise_pytorch(state1[1], absmax1[1], qmap1, blocksize, torch.float32) + state1_fp32 = torch.stack([s1_1_fp32, s1_2_fp32]) + else: + state1_fp32 = _dequantize_blockwise_pytorch(state1, absmax1, qmap1, blocksize, torch.float32) + + state2_fp32 = None + if state2 is not None: + state2_fp32 = _dequantize_blockwise_pytorch(state2, absmax2, qmap2, blocksize, torch.float32) + + grad = g.float() * gnorm_scale + + # Create a 32-bit copy of the parameter for high-precision updates + p_fp32 = p.data.float() + + if optimizer_name == "adam": + state1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1) + state2_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + bias_correction1 = 1.0 - beta1**step + bias_correction2 = 1.0 - beta2**step + + denom = (state2_fp32.sqrt() / math.sqrt(bias_correction2)).add_(eps) + + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + p_fp32.addcdiv_(state1_fp32, denom, value=-lr / bias_correction1) + + elif optimizer_name == "ademamix": + m1_fp32, m2_fp32 = state1_fp32[0], state1_fp32[1] + nu_fp32 = state2_fp32 + + m1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1) + m2_fp32.mul_(beta3).add_(grad, alpha=1.0 - beta3) + nu_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + bias_correction1 = 1.0 - beta1**step + bias_correction2 = math.sqrt(1.0 - beta2**step) + + update = (m1_fp32 / bias_correction1 + alpha * m2_fp32) / (nu_fp32.sqrt() / bias_correction2 + eps) + + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + + p_fp32.add_(update, alpha=-lr) + state1_fp32 = torch.stack([m1_fp32, m2_fp32]) + + elif optimizer_name == "momentum": + grad.add_(p_fp32, alpha=weight_decay) + if step == 1: + state1_fp32.copy_(grad) + else: + state1_fp32.mul_(beta1).add_(grad) + p_fp32.add_(state1_fp32, alpha=-lr) + + elif optimizer_name == "rmsprop": + grad.add_(p_fp32, alpha=weight_decay) + state1_fp32.mul_(beta1).addcmul_(grad, grad, value=1.0 - beta1) + p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr) + + elif optimizer_name == "lion": + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + + update_dir = torch.sign(state1_fp32.mul(beta1) + grad.mul(1.0 - beta1)) + p_fp32.add_(update_dir, alpha=-lr) + + state1_fp32.mul_(beta2).add_(grad, alpha=1.0 - beta2) + + elif optimizer_name == "adagrad": + grad.add_(p_fp32, alpha=weight_decay) + state1_fp32.addcmul_(grad, grad, value=1.0) + p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr) + + else: + raise NotImplementedError( + f"Pure PyTorch implementation for optimizer '{optimizer_name}' is not available." 
+ ) + + # Copy the updated 32-bit parameter back to the original tensor + p.data.copy_(p_fp32) + + # Re-quantize states and update state tensors in-place + if optimizer_name == "ademamix": + new_m1_8bit, new_absmax_m1 = _quantize_blockwise_pytorch(state1_fp32[0], qmap1, blocksize) + new_m2_8bit, new_absmax_m2 = _quantize_blockwise_pytorch(state1_fp32[1], qmap1, blocksize) + state1[0].copy_(new_m1_8bit) + state1[1].copy_(new_m2_8bit) + absmax1[0].copy_(new_absmax_m1) + absmax1[1].copy_(new_absmax_m2) + + new_state2_8bit, new_absmax2 = _quantize_blockwise_pytorch(state2_fp32, qmap2, blocksize) + state2.copy_(new_state2_8bit) + absmax2.copy_(new_absmax2) + else: + new_state1_8bit, new_absmax1 = _quantize_blockwise_pytorch(state1_fp32, qmap1, blocksize) + state1.copy_(new_state1_8bit) + absmax1.copy_(new_absmax1) + + if state2_fp32 is not None: + new_state2_8bit, new_absmax2 = _quantize_blockwise_pytorch(state2_fp32, qmap2, blocksize) + state2.copy_(new_state2_8bit) + absmax2.copy_(new_absmax2) + + +####################################### +# Mixed torch + triton implementation # +####################################### + + +# Much more memory efficient due to using triton for quantization/dequantization +def optimizer_update_8bit_blockwise_triton_quant( + p: torch.Tensor, + g: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, # ADEMIX + alpha: float, # ADEMIX + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float, + gnorm_scale: float, + skip_zeros: bool, + # ADEMIX + n: int, + *, + optimizer_name: str, +) -> None: + """ + Pure PyTorch implementation of the 8-bit block-wise optimizer update step. + This version ensures high-precision updates for float16 parameters. + """ + if skip_zeros and not torch.any(g): + return + + blocksize = 256 + grad = g.float() * gnorm_scale + + with torch.no_grad(): + # Create a 32-bit copy of the parameter for high-precision updates + p_fp32 = p.data.float() + + # Dequantize states to perform updates in 32-bit precision + if optimizer_name == "ademamix" and absmax1.ndim == 2: + # For AdEMAMix, state1 holds two EMAs, so absmax1 is stacked. 
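+            # (Same stacked (2, ...) layout as in the pure PyTorch reference above; each
+            # EMA half is dequantized with the Triton kernel here and re-quantized the
+            # same way further down, keeping peak memory low.)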
+ s1_1_fp32 = dequant_8bit_blockwise(state1[0], absmax1[0], qmap1, blocksize, dtype=torch.float32) + s1_2_fp32 = dequant_8bit_blockwise(state1[1], absmax1[1], qmap1, blocksize, dtype=torch.float32) + state1_fp32 = torch.stack([s1_1_fp32, s1_2_fp32]) + else: + state1_fp32 = dequant_8bit_blockwise(state1, absmax1, qmap1, blocksize, dtype=torch.float32) + + state2_fp32 = None + if state2 is not None: + state2_fp32 = dequant_8bit_blockwise(state2, absmax2, qmap2, blocksize, dtype=torch.float32) + + # Apply optimizer-specific update logic + if optimizer_name == "adam": + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + + state1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1) + state2_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + bias_correction1 = 1.0 - beta1**step + bias_correction2 = 1.0 - beta2**step + + denom = (state2_fp32.sqrt() / math.sqrt(bias_correction2)).add_(eps) + p_fp32.addcdiv_(state1_fp32, denom, value=-lr / bias_correction1) + + elif optimizer_name == "ademamix": + m1_fp32, m2_fp32 = state1_fp32[0], state1_fp32[1] + nu_fp32 = state2_fp32 + + m1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1) + m2_fp32.mul_(beta3).add_(grad, alpha=1.0 - beta3) + nu_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + + bias_correction1 = 1.0 - beta1**step + bias_correction2 = math.sqrt(1.0 - beta2**step) + + update = (m1_fp32 / bias_correction1 + alpha * m2_fp32) / (nu_fp32.sqrt() / bias_correction2 + eps) + + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + + p_fp32.add_(update, alpha=-lr) + state1_fp32 = torch.stack([m1_fp32, m2_fp32]) + + elif optimizer_name == "momentum": + grad.add_(p_fp32, alpha=weight_decay) + if step == 1: + state1_fp32.copy_(grad) + else: + state1_fp32.mul_(beta1).add_(grad) + p_fp32.add_(state1_fp32, alpha=-lr) + + elif optimizer_name == "rmsprop": + grad.add_(p_fp32, alpha=weight_decay) + state1_fp32.mul_(beta1).addcmul_(grad, grad, value=1.0 - beta1) + p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr) + + elif optimizer_name == "lion": + if weight_decay > 0.0: + p_fp32.mul_(1.0 - lr * weight_decay) + + update_dir = torch.sign(state1_fp32.mul(beta1) + grad.mul(1.0 - beta1)) + p_fp32.add_(update_dir, alpha=-lr) + + state1_fp32.mul_(beta2).add_(grad, alpha=1.0 - beta2) + + elif optimizer_name == "adagrad": + grad.add_(p_fp32, alpha=weight_decay) + state1_fp32.addcmul_(grad, grad, value=1.0) + p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr) + + else: + raise NotImplementedError( + f"Pure PyTorch implementation for optimizer '{optimizer_name}' is not available." 
+ ) + + # Copy the updated 32-bit parameter back to the original tensor + p.data.copy_(p_fp32) + + # Re-quantize states and update state tensors in-place + if optimizer_name == "ademamix": + new_m1_8bit, new_absmax_m1 = quantize_blockwise_triton(state1_fp32[0], qmap1, blocksize) + new_m2_8bit, new_absmax_m2 = quantize_blockwise_triton(state1_fp32[1], qmap1, blocksize) + state1[0].copy_(new_m1_8bit) + state1[1].copy_(new_m2_8bit) + absmax1[0].copy_(new_absmax_m1) + absmax1[1].copy_(new_absmax_m2) + + new_state2_8bit, new_absmax2 = quantize_blockwise_triton(state2_fp32, qmap2, blocksize) + state2.copy_(new_state2_8bit) + absmax2.copy_(new_absmax2) + else: + new_state1_8bit, new_absmax1 = quantize_blockwise_triton(state1_fp32, qmap1, blocksize) + state1.copy_(new_state1_8bit) + absmax1.copy_(new_absmax1) + + if state2_fp32 is not None: + new_state2_8bit, new_absmax2 = quantize_blockwise_triton(state2_fp32, qmap2, blocksize) + state2.copy_(new_state2_8bit) + absmax2.copy_(new_absmax2) + + +######################### +# Triton implementation # +######################### + +MOMENTUM = 0 +RMSPROP = 1 +ADAGRAD = 2 +ADAM = 3 +# LION should be larger than MOMENTUM, RMSPROP, ADAGRAD due to comparison in kernels +LION = 4 +ADEMAMIX = 5 + +name2optimizer_id = { + "momentum": MOMENTUM, + "rmsprop": RMSPROP, + "adagrad": ADAGRAD, + "adam": ADAM, + "lion": LION, + "ademamix": ADEMAMIX, +} + + +@triton.jit +def _optimizer_update_1state_8bit_blockwise_triton_kernel( + # Tensors + p_ptr, + g_ptr, + state1_ptr, + state2_ptr, + beta1: tl.constexpr, + beta2: tl.constexpr, + beta3, + alpha, + eps: tl.constexpr, + step, + lr, + qmap1_ptr, + qmap2_ptr, + absmax1_ptr, + absmax2_ptr, + weight_decay, + gnorm_scale, + # Meta-parameters + n_elements, + BLOCK_SIZE_N: tl.constexpr, + N_PER_TH: tl.constexpr, + OPTIMIZER_ID: tl.constexpr, +): + """ + Triton kernel for 8-bit optimizers that use one momentum state. + Supports: Momentum, RMSprop, Adagrad, Lion. + """ + # 1. Boilerplate: pid, offsets, mask + pid = tl.program_id(axis=0) + block_start_idx = pid * N_PER_TH + offsets = block_start_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N * N_PER_TH) + mask = offsets < n_elements + + # 2. Load and dequantize tensors + g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) * gnorm_scale + p = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + s1 = dequant_8bit_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) + + # 3. Optimizer-specific updates + # LION + if weight_decay > 0.0 and OPTIMIZER_ID == 2: + p *= 1.0 - lr * weight_decay + # Apply weight decay for momentum, rmsprop, adagrad + elif weight_decay > 0.0: + g += p * weight_decay + + # Momentum update + if OPTIMIZER_ID == 0: # MOMENTUM + if step == 1: + s1 = g + else: + s1 = s1 * beta1 + g + p -= lr * s1 + + # RMSprop update + elif OPTIMIZER_ID == 1: # RMSPROP + s1 = s1 * beta1 + (1.0 - beta1) * g * g + p -= lr * (g / (tl.sqrt(s1) + eps)) + + # Adagrad update + elif OPTIMIZER_ID == 2: # ADAGRAD + s1 += g * g + p -= lr * (g / (tl.sqrt(s1) + eps)) + + # Lion update + elif OPTIMIZER_ID == 4: # LION + val = s1 * beta1 + (1.0 - beta1) * g + update = tl.where(val > 0.0, 1.0, tl.where(val < 0.0, -1.0, 0.0)) + p -= lr * update + s1 = s1 * beta2 + (1.0 - beta2) * g + + # 4. 
Store updated parameter and requantized state + tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask) + s1_codes, new_absmax1 = quantize_8bit_blockwise_core(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state1_ptr + offsets, s1_codes, mask=mask) + tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1) + + +@triton.jit +def _optimizer_update_2state_8bit_blockwise_triton_kernel( + # Tensors + p_ptr, + g_ptr, + state1_ptr, + state2_ptr, + beta1: tl.constexpr, + beta2: tl.constexpr, + # ademamix changes alpha and beta3 + beta3, + # ademamix changes alpha and beta3 + alpha, + eps: tl.constexpr, + step, + lr, + qmap1_ptr, + qmap2_ptr, + absmax1_ptr, + absmax2_ptr, + weight_decay: tl.constexpr, + gnorm_scale: tl.constexpr, + # Meta-parameters + n_elements, + BLOCK_SIZE_N: tl.constexpr, + N_PER_TH: tl.constexpr, + OPTIMIZER_ID: tl.constexpr, +): + """ + Triton kernel for 8-bit optimizers that use two momentum states. + Supports: Adam, AdEMAMix. + """ + # 1. Boilerplate: pid, offsets, mask + pid = tl.program_id(axis=0) + block_start_idx = pid * N_PER_TH + offsets = block_start_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N * N_PER_TH) + mask = offsets < n_elements + + # 2. Load and dequantize tensors + g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) * gnorm_scale + p = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + + # 3. Optimizer-specific updates + if OPTIMIZER_ID == 3: # ADAM + s1 = dequant_8bit_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) + s2 = dequant_8bit_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N) + + s1 = s1 * beta1 + (1.0 - beta1) * g + s2 = s2 * beta2 + (1.0 - beta2) * g * g + + bias_correction1 = 1.0 - libdevice.pow(beta1, step) + bias_correction2 = 1.0 - libdevice.pow(beta2, step) + + if weight_decay > 0.0: + p *= 1.0 - lr * weight_decay + + denom = tl.sqrt(s2) / tl.sqrt(bias_correction2) + eps + p -= (lr / bias_correction1) * (s1 / denom) + + # Store updated parameter + tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask) + + # Requantize and store states + s1_codes, new_absmax1 = quantize_8bit_blockwise_core(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state1_ptr + offsets, s1_codes, mask=mask) + tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1) + + s2_codes, new_absmax2 = quantize_8bit_blockwise_core(s2, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state2_ptr + offsets, s2_codes, mask=mask) + tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax2) + + elif OPTIMIZER_ID == 5: # ADEMAMIX + # AdEMAMix has a stacked state1 (m1, m2) and state2 (nu) + m1 = dequant_8bit_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N) + m2 = dequant_8bit_kernel_util( + state1_ptr + n_elements, offsets, qmap1_ptr, absmax1_ptr + n_elements // BLOCK_SIZE_N, mask, BLOCK_SIZE_N + ) + nu = dequant_8bit_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N) + + m1 = m1 * beta1 + (1.0 - beta1) * g + m2 = m2 * beta3 + (1.0 - beta3) * g + nu = nu * beta2 + (1.0 - beta2) * g * g + + bias_correction1 = 1.0 - libdevice.pow(beta1, step) + bias_correction2 = tl.sqrt(1.0 - libdevice.pow(beta2, step)) + + update = (m1 / bias_correction1 + alpha * m2) / (tl.sqrt(nu) / bias_correction2 + eps) + + if weight_decay > 0.0: + p *= 1.0 - lr * weight_decay + + p -= lr * update + + # Store updated parameter + tl.store(p_ptr + offsets, 
p.to(p_ptr.dtype.element_ty), mask=mask) + + # Requantize and store all three states + m1_codes, new_absmax_m1 = quantize_8bit_blockwise_core(m1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state1_ptr + offsets, m1_codes, mask=mask) + tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_m1) + + m2_codes, new_absmax_m2 = quantize_8bit_blockwise_core(m2, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state1_ptr + n_elements + offsets, m2_codes, mask=mask) + tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH) + n_elements // BLOCK_SIZE_N, new_absmax_m2) + + nu_codes, new_absmax_nu = quantize_8bit_blockwise_core(nu, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH) + tl.store(state2_ptr + offsets, nu_codes, mask=mask) + tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_nu) + + +def optimizer_update_1state_8bit_blockwise( + p: torch.Tensor, + g: torch.Tensor, + state1: torch.Tensor, + state2: torch.Tensor, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: torch.Tensor, + absmax1: torch.Tensor, + absmax2: torch.Tensor, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + n: int = 0, # Deprecated: n is ignored, p.numel() is used instead. + *, + optimizer_id: int, +): + if skip_zeros: + raise NotImplementedError("skip_zeros is not supported on XPU yet") + + BLOCK_SIZE = 256 + if n % BLOCK_SIZE != 0: + raise ValueError(f"Matrix size ({n}) must be a multiple of BLOCK_SIZE ({BLOCK_SIZE}) for block-wise updates.") + N_PER_TH = 1 # Number of blocks processed per thread. + grid = (triton.cdiv(p.numel(), BLOCK_SIZE * N_PER_TH),) + + _optimizer_update_1state_8bit_blockwise_triton_kernel[grid]( + p, + g, + state1, + state2, + beta1, + beta2, + beta3, + alpha, + eps, + step, + lr, + qmap1, + qmap2, + absmax1, + absmax2, + weight_decay, + gnorm_scale, + p.numel(), + BLOCK_SIZE_N=BLOCK_SIZE, + N_PER_TH=N_PER_TH, + OPTIMIZER_ID=optimizer_id, + num_warps=2, + ) + + +def optimizer_update_2state_8bit_blockwise( + p: torch.Tensor, + g: torch.Tensor, + state1: torch.Tensor, + state2: torch.Tensor, + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: torch.Tensor, + absmax1: torch.Tensor, + absmax2: torch.Tensor, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + n: int = 0, # Deprecated: n is ignored, p.numel() is used instead. + *, + optimizer_id: int, +): + if skip_zeros: + raise NotImplementedError("skip_zeros is not supported on XPU yet") + + if optimizer_id == ADEMAMIX: + # Handle AdEMAMix's stacked state tensors + if state1.dim() < 2 or state1.shape[0] != 2: + raise ValueError( + f"For ademamix, state1 must be a stacked tensor of shape (2, ...), but got {state1.shape}" + ) + if absmax1.dim() < 2 or absmax1.shape[0] != 2: + raise ValueError( + f"For ademamix, absmax1 must be a stacked tensor of shape (2, ...), but got {absmax1.shape}" + ) + + BLOCK_SIZE = 256 + N_PER_TH = 1 # Number of blocks processed per thread. 
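+    # Each program instance handles N_PER_TH quantization blocks of BLOCK_SIZE elements,
+    # so the 1-D launch grid below covers ceil(p.numel() / (BLOCK_SIZE * N_PER_TH)) programs.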
+ grid = (triton.cdiv(p.numel(), BLOCK_SIZE * N_PER_TH),) + + _optimizer_update_2state_8bit_blockwise_triton_kernel[grid]( + p, + g, + state1, + state2, + beta1, + beta2, + beta3, + alpha, + eps, + step, + lr, + qmap1, + qmap2, + absmax1, + absmax2, + weight_decay, + gnorm_scale, + p.numel(), + BLOCK_SIZE_N=BLOCK_SIZE, + N_PER_TH=N_PER_TH, + OPTIMIZER_ID=optimizer_id, + num_warps=2, + ) diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py index 058c2747d..d3cd9136a 100644 --- a/bitsandbytes/backends/triton/ops.py +++ b/bitsandbytes/backends/triton/ops.py @@ -1,8 +1,9 @@ from collections.abc import Sequence +from functools import partial import torch -from . import triton_kernels +from . import kernels_4bit, kernels_8bit_quant, kernels_optim # currently codes unused, kept for reference # Should be the same for quant/dequant @@ -16,19 +17,9 @@ def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) # torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}") - - n = A.numel() - blocks = -(n // -blocksize) - - absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype) - out = torch.empty_like(A.flatten(), dtype=torch.uint8) - with torch_accelerator_module.device(A.device): - triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out) - - out = out.reshape(A.shape) - - return out, absmax.float() + out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A, code, blocksize) + return out, absmax.float() def dequantize_blockwise( @@ -37,22 +28,24 @@ def dequantize_blockwise( torch._check_is_size(blocksize) torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") # torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}") - - out = torch.empty_like(A, dtype=dtype, device=A.device) with torch_accelerator_module.device(A.device): - triton_kernels.dequant_int8_blockwise( + out = kernels_8bit_quant.dequant_8bit_blockwise( A, - code, absmax, - out, + code, blocksize, + dtype=dtype, ) - return out def dequantize_blockwise_inplace( - A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, ) -> None: torch._check_is_size(blocksize) torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}") @@ -61,12 +54,13 @@ def dequantize_blockwise_inplace( torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") with torch_accelerator_module.device(A.device): - triton_kernels.dequant_int8_blockwise( + kernels_8bit_quant.dequant_8bit_blockwise( A, - code, absmax, - out, + code, blocksize, + dtype=dtype, + out=out, ) @@ -91,7 +85,7 @@ def quantize_4bit( out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8) with torch_accelerator_module.device(A.device): - triton_kernels.quantize_4bit_blockwise_triton( + kernels_4bit.quantize_4bit_blockwise_triton( A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out ) packed = out @@ -125,9 +119,8 @@ def dequantize_4bit( A = A.squeeze().view(torch.uint8).unsqueeze(1) out = torch.empty(shape, dtype=dtype, device=A.device) - with torch_accelerator_module.device(A.device): - triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + 
kernels_4bit.dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) return out @@ -144,7 +137,7 @@ def dequantize_4bit_inplace( torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}") torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") with torch_accelerator_module.device(A.device): - triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) + kernels_4bit.dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out) def gemv_4bit( @@ -161,7 +154,7 @@ def gemv_4bit( B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device) with torch_accelerator_module.device(A.device): - triton_kernels._dequantize_4bit_impl_passing_code( + kernels_4bit.dequantize_4bit_impl_passing_code( B, absmax, blocksize, @@ -170,8 +163,32 @@ def gemv_4bit( out=B_dq_triton, ) - return torch.nn.functional.linear( - A, - B_dq_triton, - bias=None, - ) + return torch.nn.functional.linear( + A, + B_dq_triton, + bias=None, + ) + + +# optimizer_update_8bit_blockwise = kernels_optim.optimizer_update_8bit_blockwise_pytorch +# optimizer_update_8bit_blockwise = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_pytorch) # 60ms +# optimizer_update_8bit_blockwise = kernels_optim.optimizer_update_8bit_blockwise_triton_quant #2.8ms +# optimizer_update_8bit_blockwise = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_triton_quant) # 2.3ms + +# adam_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="adam") +# momentum_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="momentum") +# rmsprop_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="rmsprop") +# lion_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="lion") +# adagrad_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="adagrad") +# ademamix_8bit_blockwise_grad = partial(optimizer_update_8bit_blockwise, optimizer_name="ademamix") + +# ~0.95ms for adam +update_1state = kernels_optim.optimizer_update_1state_8bit_blockwise +update_2state = kernels_optim.optimizer_update_2state_8bit_blockwise +momentum_8bit_blockwise_grad = partial(update_1state, optimizer_id=kernels_optim.name2optimizer_id["momentum"]) +rmsprop_8bit_blockwise_grad = partial(update_1state, optimizer_id=kernels_optim.name2optimizer_id["rmsprop"]) +lion_8bit_blockwise_grad = partial(update_1state, optimizer_id=kernels_optim.name2optimizer_id["lion"]) +adagrad_8bit_blockwise_grad = partial(update_1state, optimizer_id=kernels_optim.name2optimizer_id["adagrad"]) + +ademamix_8bit_blockwise_grad = partial(update_2state, optimizer_id=kernels_optim.name2optimizer_id["ademamix"]) +adam_8bit_blockwise_grad = partial(update_2state, optimizer_id=kernels_optim.name2optimizer_id["adam"]) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 9b446a2de..503d32002 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -13,6 +13,14 @@ from torch import Tensor from typing_extensions import deprecated +from bitsandbytes.backends.triton.ops import ( + adagrad_8bit_blockwise_grad, + adam_8bit_blockwise_grad, + ademamix_8bit_blockwise_grad, + lion_8bit_blockwise_grad, + momentum_8bit_blockwise_grad, + rmsprop_8bit_blockwise_grad, +) from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib @@ 
-84,34 +92,34 @@ str2optimizer8bit_blockwise = { "adam": ( - lib.cadam_8bit_blockwise_grad_fp32, - lib.cadam_8bit_blockwise_grad_fp16, - lib.cadam_8bit_blockwise_grad_bf16, + adam_8bit_blockwise_grad, # lib.cadam_8bit_blockwise_grad_fp32, + adam_8bit_blockwise_grad, # lib.cadam_8bit_blockwise_grad_fp16, + adam_8bit_blockwise_grad, # lib.cadam_8bit_blockwise_grad_fp16, ), "momentum": ( - lib.cmomentum_8bit_blockwise_grad_fp32, - lib.cmomentum_8bit_blockwise_grad_fp16, - lib.cmomentum_8bit_blockwise_grad_bf16, + momentum_8bit_blockwise_grad, + momentum_8bit_blockwise_grad, + momentum_8bit_blockwise_grad, ), "rmsprop": ( - lib.crmsprop_8bit_blockwise_grad_fp32, - lib.crmsprop_8bit_blockwise_grad_fp16, - lib.crmsprop_8bit_blockwise_grad_bf16, + rmsprop_8bit_blockwise_grad, + rmsprop_8bit_blockwise_grad, + rmsprop_8bit_blockwise_grad, ), "lion": ( - lib.clion_8bit_blockwise_grad_fp32, - lib.clion_8bit_blockwise_grad_fp16, - lib.clion_8bit_blockwise_grad_bf16, + lion_8bit_blockwise_grad, + lion_8bit_blockwise_grad, + lion_8bit_blockwise_grad, ), "adagrad": ( - lib.cadagrad_8bit_blockwise_grad_fp32, - lib.cadagrad_8bit_blockwise_grad_fp16, - lib.cadagrad_8bit_blockwise_grad_bf16, + adagrad_8bit_blockwise_grad, + adagrad_8bit_blockwise_grad, + adagrad_8bit_blockwise_grad, ), "ademamix": ( - lib.cademamix_8bit_blockwise_grad_fp32, - lib.cademamix_8bit_blockwise_grad_fp16, - lib.cademamix_8bit_blockwise_grad_bf16, + ademamix_8bit_blockwise_grad, + ademamix_8bit_blockwise_grad, + ademamix_8bit_blockwise_grad, ), } @@ -422,8 +430,8 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]): for t in tensors: # NULL pointers and paged tensors are OK. if t is not None and not getattr(t, "is_paged", False): - on_gpu &= t.is_cuda - gpu_ids.add(t.device.index) + on_gpu &= t.device.type != "cpu" + gpu_ids.add((t.device.type, t.device.index)) if not on_gpu: raise RuntimeError( @@ -1466,28 +1474,53 @@ def optimizer_update_8bit_blockwise( is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) - with _cuda_device_of(g): + # print("p device: ", p.device, " g device: ", g.device) + # print("p device type: ", p.device, " g device type: ", g.device) + if p.device.type == "xpu": optim_func( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(beta3), - ct.c_float(alpha), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), + p, + g, + state1, + state2, + float(beta1), + float(beta2), + float(beta3), + float(alpha), + float(eps), + int(step), + float(lr), + qmap1, + qmap2, + absmax1, + absmax2, + float(weight_decay), + float(gnorm_scale), + bool(skip_zeros), + int(g.numel()), ) + else: + with _cuda_device_of(g): + optim_func( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(beta3), + ct.c_float(alpha), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index ee1781a8b..36537be04 100644 --- 
a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -10,6 +10,7 @@ import torch import bitsandbytes.functional as F +from bitsandbytes.utils import sync_gpu class MockArgs: @@ -289,11 +290,11 @@ def step(self, closure=None): self.prefetch_state(p) self.update_step(group, p, gindex, pindex) - torch.cuda.synchronize() + sync_gpu(p) if self.is_paged: # all paged operations are asynchronous, we need # to sync to make sure all tensors are in the right state - torch.cuda.synchronize() + sync_gpu(loss) return loss diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 7920e2188..a3b043ba0 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -209,3 +209,10 @@ def unpack_tensor_to_dict(tensor_data): LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3} INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()} + + +def sync_gpu(t: torch.Tensor): + if t.device.type == "cuda": + torch.cuda.synchronize() + elif t.device.type == "xpu": + torch.xpu.synchronize() diff --git a/tests/test_optim.py b/tests/test_optim.py index 75e5a1714..0a998ba3e 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -11,7 +11,8 @@ import bitsandbytes as bnb import bitsandbytes.functional as F -from tests.helpers import describe_dtype, id_formatter +from bitsandbytes.utils import sync_gpu +from tests.helpers import describe_dtype, get_available_devices, id_formatter # import apex @@ -168,7 +169,8 @@ def rm_path(path): @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2")) -def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name): +@pytest.mark.parametrize("device", get_available_devices(), ids=id_formatter("device")) +def test_optimizer32bit(dim1, dim2, gtype, optim_name, device): if optim_name.startswith("paged_") and sys.platform == "win32": pytest.skip("Paged optimizers can have issues on Windows.") @@ -176,7 +178,7 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name): pytest.skip() if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 p2 = p1.clone() p1 = p1.float() @@ -191,7 +193,7 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name): atol, rtol = 1e-4, 1e-3 for i in range(k): - g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 + g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01 p1.grad = g.clone().float() p2.grad = g.clone() @@ -201,7 +203,7 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name): for name1, name2 in str2statenames[optim_name]: torch.testing.assert_close( torch_optimizer.state[p1][name1], - bnb_optimizer.state[p2][name2].cuda(), + bnb_optimizer.state[p2][name2].to(device), atol=atol, rtol=rtol, ) @@ -247,7 +249,8 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name): @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype) -def test_global_config(requires_cuda, dim1, dim2, gtype): +@pytest.mark.parametrize("device", get_available_devices()) +def test_global_config(dim1, 
dim2, gtype, device): if dim1 == 1 and dim2 == 1: return p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1 @@ -263,9 +266,9 @@ def test_global_config(requires_cuda, dim1, dim2, gtype): bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8) bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3]) - p1 = p1.cuda() - p2 = p2.cuda() - p3 = p3.cuda() + p1 = p1.to(device) + p2 = p2.to(device) + p3 = p3.to(device) adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps) @@ -275,9 +278,9 @@ def test_global_config(requires_cuda, dim1, dim2, gtype): atol, rtol = 1e-4, 1e-3 for i in range(50): - g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 - g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 - g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001 + g1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001 + g2 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001 + g3 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001 p1.grad = g1 p2.grad = g2 p3.grad = g3 @@ -302,13 +305,14 @@ def test_global_config(requires_cuda, dim1, dim2, gtype): @pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype) @pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1")) -def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name): +@pytest.mark.parametrize("device", get_available_devices()) +def test_optimizer8bit(dim1, dim2, gtype, optim_name, device): torch.set_printoptions(precision=6) if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 p2 = p1.clone() p1 = p1.float() blocksize = 256 @@ -330,12 +334,12 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name): relerrors = [] for i in range(50): - g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 + g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01 p1.grad = g.clone().float() p2.grad = g.clone() - bnb_optimizer.step() torch_optimizer.step() + bnb_optimizer.step() # since Lion can have pretty noisy updates where things lie at the boundary assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0) @@ -368,7 +372,7 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name): ) num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0 - # assert num_not_close.sum().item() < 20 + assert num_not_close.sum().item() < 20 dequant_states.append(s1.clone()) err = torch.abs(p1 - p2) @@ -549,25 +553,25 @@ def test_adam_percentile_clipping(requires_cuda, dim1, dim2, gtype, optim_bits): @pytest.mark.parametrize("gtype", [torch.float32, torch.bfloat16, torch.float16], ids=describe_dtype) @pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt")) @pytest.mark.benchmark -def test_benchmark_blockwise(dim1, dim2, gtype, optim_name): +def test_benchmark_blockwise(dim1, dim2, gtype, optim_name, device): if dim1 == 1 and dim2 == 1: return - p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 bnb_optimizer = str2optimizers[optim_name][1]([p1]) - g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01 + g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01 
p1.grad = g total_steps = 500 for i in range(total_steps): if i == total_steps // 5: # 100 iterations for burn-in - torch.cuda.synchronize() + sync_gpu(p1) t0 = time.time() bnb_optimizer.step() - torch.cuda.synchronize() + sync_gpu(p1) s = time.time() - t0 print("") params = (total_steps - total_steps // 5) * dim1 * dim2
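
A minimal usage sketch of the new path (illustrative, not part of the patch): with this change, functional.optimizer_update_8bit_blockwise dispatches to the Triton-backed kernels whenever the parameter lives on an XPU device, so the existing bnb.optim 8-bit optimizers are intended to work unchanged there. The snippet assumes an XPU-enabled PyTorch build; everything else uses the public bitsandbytes API, and the tensor sizes are arbitrary example values.

import torch
import bitsandbytes as bnb

device = "xpu"  # assumption: an XPU-enabled PyTorch build; the patch routes this device to the Triton kernels
p = torch.nn.Parameter(torch.randn(1024, 1024, device=device, dtype=torch.float16))
opt = bnb.optim.Adam8bit([p], lr=1e-3)  # keeps optimizer state in 8-bit blockwise form

p.grad = torch.randn_like(p) * 0.01
opt.step()  # states are dequantized, updated in fp32, and re-quantized blockwise (blocksize 256 in this patch)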