Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions bitsandbytes/backends/cuda/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr

from ..._ops import register_kernel
from ...cextension import HIP_ENVIRONMENT, lib
from ...cextension import ROCM_WARP_SIZE_64, lib


@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
Expand Down Expand Up @@ -211,7 +211,7 @@ def _get_col_absmax(
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)

if HIP_ENVIRONMENT:
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
Expand Down Expand Up @@ -269,7 +269,7 @@ def _(
def _dequantize_blockwise_impl(
A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
if HIP_ENVIRONMENT:
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
Expand Down Expand Up @@ -303,7 +303,7 @@ def _dequantize_blockwise_impl(
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
if HIP_ENVIRONMENT:
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
Expand Down Expand Up @@ -385,7 +385,7 @@ def _dequantize_4bit_impl(
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
if HIP_ENVIRONMENT:
if ROCM_WARP_SIZE_64:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
else:
torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
Expand Down
9 changes: 8 additions & 1 deletion bitsandbytes/cextension.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
import torch

from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple, get_rocm_gpu_arch
from bitsandbytes.cuda_specs import (
CUDASpecs,
get_cuda_specs,
get_cuda_version_tuple,
get_rocm_gpu_arch,
get_rocm_warpsize,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -298,6 +304,7 @@ def get_native_library() -> BNBNativeLibrary:


ROCM_GPU_ARCH = get_rocm_gpu_arch()
# True when get_rocm_warpsize() reports a warp/wavefront size of 64.
ROCM_WARP_SIZE_64 = get_rocm_warpsize() == 64

HIP_ENVIRONMENT = False
BNB_BACKEND = "CPU"
Expand Down
26 changes: 26 additions & 0 deletions bitsandbytes/cuda_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,29 @@ def get_rocm_gpu_arch() -> str:
""",
)
return "unknown"


def get_rocm_warpsize() -> int:
    """Return the warp (wavefront) size to assume for the current torch build.

    When ``torch.version.hip`` is falsy the function returns 32 without
    running anything. Otherwise it runs ``rocminfo`` and returns the first
    two-digit ``Wavefront Size`` value found in its standard output.

    Returns:
        The detected size, 32 for non-HIP builds, or 64 when the ``rocminfo``
        output contains no matching line or when detection raises.
    """
    log = logging.getLogger(__name__)
    try:
        if not torch.version.hip:
            # nvidia cards always use 32 warp size
            return 32

        proc = subprocess.run(["rocminfo"], capture_output=True, text=True)
        found = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", proc.stdout)
        # default to 64 to be safe when no wavefront line could be parsed
        return int(found.group(1)) if found else 64
    except Exception as err:
        log.error(f"Could not detect ROCm warp size: {err}. Defaulting to 64. (some 4-bit functions may not work!)")
        if torch.cuda.is_available():
            log.warning(
                "\nROCm warp size detection failed despite ROCm being available.\n",
            )
    # Only reached after a failure was logged above.
    return 64
14 changes: 7 additions & 7 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict

from .cextension import HIP_ENVIRONMENT, lib
from .cextension import ROCM_WARP_SIZE_64, lib

name2qmap = {}

Expand Down Expand Up @@ -806,7 +806,7 @@ def quantize_fp4(
quant_storage=torch.uint8,
):
if blocksize is None:
blocksize = 64 if not HIP_ENVIRONMENT else 128
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)


Expand All @@ -819,7 +819,7 @@ def quantize_nf4(
quant_storage=torch.uint8,
):
if blocksize is None:
blocksize = 64 if not HIP_ENVIRONMENT else 128
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)


Expand Down Expand Up @@ -857,7 +857,7 @@ def quantize_4bit(
"""

if blocksize is None:
blocksize = 64 if not HIP_ENVIRONMENT else 128
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128

input_shape = A.shape

Expand Down Expand Up @@ -912,7 +912,7 @@ def dequantize_fp4(
blocksize: Optional[int] = None,
) -> torch.Tensor:
if blocksize is None:
blocksize = 64 if not HIP_ENVIRONMENT else 128
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")


Expand All @@ -924,7 +924,7 @@ def dequantize_nf4(
blocksize: Optional[int] = None,
) -> torch.Tensor:
if blocksize is None:
blocksize = 64 if not HIP_ENVIRONMENT else 128
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")


Expand Down Expand Up @@ -964,7 +964,7 @@ def dequantize_4bit(
"""

if blocksize is None:
blocksize = 64 if not HIP_ENVIRONMENT else 128
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128

if quant_state is None:
assert absmax is not None and out is not None
Expand Down
4 changes: 2 additions & 2 deletions bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch.nn.functional as F

import bitsandbytes as bnb
from bitsandbytes.cextension import HIP_ENVIRONMENT
from bitsandbytes.cextension import ROCM_WARP_SIZE_64
from bitsandbytes.functional import QuantState
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer
Expand Down Expand Up @@ -221,7 +221,7 @@ def __new__(
data = torch.empty(0)

if blocksize is None:
blocksize = 64 if not HIP_ENVIRONMENT else 128
blocksize = 64 if not ROCM_WARP_SIZE_64 else 128

self = torch.Tensor._make_subclass(cls, data, requires_grad)
self.blocksize = blocksize
Expand Down
8 changes: 6 additions & 2 deletions csrc/common_hip.cuh
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
#pragma once

#define BNB_WARP_SIZE warpSize
#ifdef __GFX9__
#define BNB_WARP_SIZE 64
#else
#define BNB_WARP_SIZE 32
#endif

// These are set based on current BNB support for CDNA 2 & RDNA 3. Update as needed for future archs
#define BNB_MAX_THREADS_PER_SM 2048
#define BNB_MAX_THREADS_PER_CU 2048
#define BNB_BF16_AVAILABLE true
69 changes: 41 additions & 28 deletions csrc/kernels.hip
Original file line number Diff line number Diff line change
Expand Up @@ -1881,7 +1881,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
// rowStats [rows]
// out [rows, cols]
template<typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
__launch_bounds__(1024, BNB_MAX_THREADS_PER_CU / 1024)
__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {

// For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
Expand Down Expand Up @@ -1945,7 +1945,7 @@ __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStat
}

template<typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
__launch_bounds__(1024, BNB_MAX_THREADS_PER_CU / 1024)
__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) {
using BlockReduceT = hipcub::BlockReduce<float, THREADS>;

Expand Down Expand Up @@ -2057,11 +2057,6 @@ __global__ void kdequant_mm_int32_fp16(
#define DENORM 1.0f/127.0f
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
#if defined(__GFX9__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
{
Expand All @@ -2082,9 +2077,9 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
const int local_row_idx = rowidx[offset];

const int warp_id = threadIdx.x / WARP_SIZE;
const int warp_idx = threadIdx.x % WARP_SIZE;
const int warp_offset = (warp_id*WARP_SIZE)*SPMM_ITEMS;
const int warp_id = threadIdx.x / BNB_WARP_SIZE;
const int warp_idx = threadIdx.x % BNB_WARP_SIZE;
const int warp_offset = (warp_id*BNB_WARP_SIZE)*SPMM_ITEMS;
const int num_items = BITS == 8 ? 8 : 8;
int idx_col_B = warp_offset;
int local_idx_col_B_offset = 0;
Expand All @@ -2104,7 +2099,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
}

// each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=8192
// we expect each warp to be SPMM_ITEMS*WARP_SIZE apart
// we expect each warp to be SPMM_ITEMS*BNB_WARP_SIZE apart
// we have a total of 128 bytes for the bank with a bank size of 4 bytes
// added 3 bytes = 6 values between warps should reduce bank conflicts
__shared__ half smem_dequant_stats[SMEM_SIZE];
Expand Down Expand Up @@ -2657,15 +2652,15 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
{

// per threadblock:
// load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
// load step-by-step in chunks of [BNB_WARP_SIZE,warps]: 1xBNB_WARP_SIZE * [BNB_WARP_SIZE,warps] -> [1,warps]
// 4 warps -> 4 loads per iter
// 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
typedef hipcub::WarpReduce<float, WARP_SIZE> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/WARP_SIZE];
// 1xBNB_WARP_SIZE * BNB_WARP_SIZEx4 -> 1x4 outputs per thread block
typedef hipcub::WarpReduce<float, BNB_WARP_SIZE> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/BNB_WARP_SIZE];

const int warp_idx = threadIdx.x / WARP_SIZE;
const int warp_lane = threadIdx.x % WARP_SIZE;
const int row_B = (THREADS/WARP_SIZE)*blockIdx.x + warp_idx;
const int warp_idx = threadIdx.x / BNB_WARP_SIZE;
const int warp_lane = threadIdx.x % BNB_WARP_SIZE;
const int row_B = (THREADS/BNB_WARP_SIZE)*blockIdx.x + warp_idx;
const int offset_B = ldb * row_B;
const int num_values_8bit = num_values_4bit/2;
float local_C = 0.0f;
Expand All @@ -2684,7 +2679,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc

// A: [1, K]
// B: [M, K]
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += WARP_SIZE*num_values_4bit)
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE*num_values_4bit)
{
const int inner_idx_halved = inner_idx/2;

Expand Down Expand Up @@ -2996,23 +2991,29 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
//MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
#endif

MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
//MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
#endif

MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
//MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
#endif

MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
Expand All @@ -3021,23 +3022,29 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
//MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
#endif

MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
//MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
#endif

MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
//MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
#endif

MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit)
Expand All @@ -3046,23 +3053,29 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit)
//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
#endif

MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4)
//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
#endif

MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4)
//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
#if BNB_WARP_SIZE == 32
MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
#endif

template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
Expand Down
Loading
Loading