
Commit 849d944

Deprecation cleanup (#1669)
* Deprecation cleanup: remove histogram_scatter_add_2d
* Deprecation cleanup: vectorwise_mm_dequant
* Deprecation cleanup: vectorwise_quant
* Remove unused test
* Optimizer test cleanup
* Deprecations: remove estimate_quantiles, create_quantile_map
* Move deprecated test
1 parent 76d3e2b commit 849d944

File tree

10 files changed (+109, -543 lines changed)

bitsandbytes/functional.py

Lines changed: 0 additions & 216 deletions
@@ -401,23 +401,6 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
     return torch.tensor(data, dtype=torch.float32)
 
 
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def create_quantile_map(A, total_bits=8):
-    q = estimate_quantiles(A, num_quantiles=2**total_bits - 1)
-    q = q.tolist()
-    q.append(0)
-
-    gap = 256 - len(q)
-    for i in range(gap):
-        q.append(0)
-
-    q.sort()
-
-    q = Tensor(q)
-    q = q / q.abs().max()
-    return q
-
-
 def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
     """Verifies that the input tensors are all on the same device.
 
@@ -474,74 +457,6 @@ def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
     return ct.c_void_p(A.data_ptr())
 
 
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def estimate_quantiles(
-    A: Tensor,
-    out: Optional[torch.Tensor] = None,
-    offset: float = 1 / 512,
-    num_quantiles=256,
-) -> Tensor:
-    """
-    Estimates 256 equidistant quantiles on the input tensor eCDF.
-
-    Uses SRAM-Quantiles algorithm to quickly estimate 256 equidistant quantiles
-    via the eCDF of the input tensor `A`. This is a fast but approximate algorithm
-    and the extreme quantiles close to 0 and 1 have high variance / large estimation
-    errors. These large errors can be avoided by using the offset variable which trims
-    the distribution. The default offset value of 1/512 ensures minimum entropy encoding -- it
-    trims 1/512 = 0.2% from each side of the distrivution. An offset value of 0.01 to 0.02
-    usually has a much lower error but is not a minimum entropy encoding. Given an offset
-    of 0.02 equidistance points in the range [0.02, 0.98] are used for the quantiles.
-
-    Parameters
-    ----------
-    A : torch.Tensor
-        The input tensor. Any shape.
-    out : torch.Tensor
-        Tensor with the 256 estimated quantiles.
-    offset : float
-        The offset for the first and last quantile from 0 and 1. Default: 1/(2*num_quantiles)
-    num_quantiles : int
-        The number of equally spaced quantiles.
-
-    Returns
-    -------
-    torch.Tensor:
-        The 256 quantiles in float32 datatype.
-    """
-    if A.numel() < 256:
-        raise NotImplementedError(
-            f"Quantile estimation needs at least 256 values in the Tensor, but Tensor had only {A.numel()} values.",
-        )
-    if num_quantiles > 256:
-        raise NotImplementedError(
-            f"Currently only a maximum of 256 equally spaced quantiles are supported, but the argument num_quantiles={num_quantiles}",
-        )
-    if num_quantiles < 256 and offset == 1 / (512):
-        # override default arguments
-        offset = 1 / (2 * num_quantiles)
-
-    if out is None:
-        out = torch.zeros((256,), dtype=torch.float32, device=A.device)
-
-    with _cuda_device_of(A):
-        is_on_gpu([A, out])
-
-        if A.dtype == torch.float32:
-            lib.cestimate_quantiles_fp32(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
-        elif A.dtype == torch.float16:
-            lib.cestimate_quantiles_fp16(get_ptr(A), get_ptr(out), ct.c_float(offset), ct.c_int(A.numel()))
-        else:
-            raise NotImplementedError(f"Not supported data type {A.dtype}")
-
-    if num_quantiles < 256:
-        step = round(256 / num_quantiles)
-        idx = torch.linspace(0, 255, num_quantiles).long().to(A.device)
-        out = out[idx]
-
-    return out
-
-
 class QuantState:
     """container for quantization state components to work with Params4bit and similar classes"""

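For downstream code that still needs a quantile-based code book, the behaviour of the removed pair can be approximated in plain PyTorch with torch.quantile. The sketch below is not part of this commit and the helper name build_quantile_map is illustrative only: it draws 2**total_bits - 1 equidistant quantiles with the same 1/512 tail trim, then pads and normalizes the result the way create_quantile_map did.

import torch

def build_quantile_map(A: torch.Tensor, total_bits: int = 8, offset: float = 1 / 512) -> torch.Tensor:
    # Hypothetical stand-in for the removed create_quantile_map/estimate_quantiles pair.
    num_quantiles = 2**total_bits - 1
    probs = torch.linspace(offset, 1 - offset, num_quantiles, device=A.device)
    q = torch.quantile(A.float().flatten(), probs)        # exact quantiles instead of the SRAM estimate
    q = torch.cat([q, q.new_zeros(256 - num_quantiles)])  # pad to 256 entries with zeros, as before
    q, _ = q.sort()
    return q / q.abs().max()                              # normalize the code book to [-1, 1]

# Example: code = build_quantile_map(torch.randn(4096))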
@@ -1601,25 +1516,6 @@ def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile:
     return current_gnorm, clip_value, gnorm_scale
 
 
-@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
-def histogram_scatter_add_2d(histogram: Tensor, index1: Tensor, index2: Tensor, source: Tensor):
-    assert len(histogram.shape) == 2
-    assert histogram.dtype == torch.float32
-    assert source.dtype == torch.float32
-    assert index1.dtype == torch.int32
-    assert index2.dtype == torch.int32
-
-    assert histogram.device.type == "cuda"
-    assert index1.device.type == "cuda"
-    assert index2.device.type == "cuda"
-    assert source.device.type == "cuda"
-
-    maxdim1 = ct.c_int32(histogram.shape[0])
-    n = ct.c_int32(index1.numel())
-    is_on_gpu([histogram, index1, index2, source])
-    lib.chistogram_scatter_add_2d(get_ptr(histogram), get_ptr(index1), get_ptr(index2), get_ptr(source), maxdim1, n)
-
-
 def check_matmul(A, B, out, transposed_A, transposed_B, expected_type=torch.int8):
     if not torch.cuda.is_initialized():
         torch.cuda.init()
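The removed histogram_scatter_add_2d wrapped a dedicated CUDA kernel; an equivalent 2D scatter-add can be expressed with Tensor.index_put_(..., accumulate=True), which works on CPU and CUDA tensors alike. A minimal sketch, not part of this commit and with an illustrative function name:

import torch

def scatter_add_2d(histogram: torch.Tensor, index1: torch.Tensor, index2: torch.Tensor, source: torch.Tensor) -> None:
    # histogram[index1[i], index2[i]] += source[i], accumulating duplicate index pairs.
    histogram.index_put_((index1.long(), index2.long()), source, accumulate=True)

hist = torch.zeros(4, 4)
i1 = torch.tensor([0, 0, 3], dtype=torch.int32)
i2 = torch.tensor([1, 1, 2], dtype=torch.int32)
scatter_add_2d(hist, i1, i2, torch.tensor([1.0, 2.0, 5.0]))  # hist[0, 1] == 3.0, hist[3, 2] == 5.0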
@@ -2426,118 +2322,6 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
 C = 127.0
 
 
-@deprecated(
-    "This function is deprecated and will be removed in a future release. "
-    "Consider using `int8_vectorwise_quant` instead.",
-    category=FutureWarning,
-)
-def vectorwise_quant(x, dim=1, quant_type="vector"):
-    if quant_type == "linear":
-        max1 = torch.abs(x).max().float()
-        xq = torch.round(x / max1 * 127).to(torch.int8)
-        return xq, max1
-    elif quant_type in ["vector", "row"]:
-        max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
-        xq = torch.round(x * (C / max1)).to(torch.int8)
-        return xq, max1
-    elif quant_type == "zeropoint":
-        dtype = x.dtype
-        x = x.float()
-        dyna = x.max() - x.min()
-        if dyna == 0:
-            dyna = 1
-        qx = 255.0 / dyna
-        minx = x.min()
-        zpx = torch.round(minx * qx)
-        x = torch.round(qx * x - zpx) + zpx
-        return x, qx
-    elif quant_type in ["vector-zeropoint", "row-zeropoint"]:
-        dtype = x.dtype
-        x = x.float()
-        dyna = torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True)
-        dyna[dyna == 0] = 1
-        qx = 255.0 / dyna
-        minx = torch.amin(x, dim=dim, keepdim=True)
-        zpx = torch.round(minx * qx)
-        x = torch.round(qx * x - zpx) + zpx
-        return x, qx
-    elif quant_type == "truncated-vector":
-        with torch.no_grad():
-            absx = torch.abs(x)
-            max1 = torch.amax(absx, dim=dim, keepdim=True)
-            max1 = max1 * 0.7
-            idx = absx > max1.expand_as(absx)
-            sign = torch.sign(x[idx])
-            x[idx] = max1.expand_as(absx)[idx] * sign
-            xq = torch.round(x / max1 * C).to(torch.int8)
-        return xq, max1
-    else:
-        return None
-
-
-@deprecated(
-    "This function is deprecated and will be removed in a future release.",
-    category=FutureWarning,
-)
-def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
-    if quant_type == "linear":
-        norm = S1 * S2 / (C * C)
-        # double cast needed to prevent overflows
-        return (xq.float() * norm).to(dtype)
-    elif quant_type == "zeropoint":
-        norm = 1.0 / (S1 * S2)
-        return (xq.float() * norm).to(dtype)
-    elif quant_type == "row-zeropoint":
-        norm = 1.0 / (S1 * S2)
-        x = xq.float()
-        if len(S1.shape) == 3 and len(x.shape) == 2:
-            S1 = S1.squeeze(0)
-        if len(S2.shape) == 3 and len(x.shape) == 2:
-            S2 = S2.squeeze(0)
-        if len(S1.shape) == 2:
-            x *= norm
-        else:
-            x *= norm
-        return x.to(dtype)
-    elif quant_type == "vector-zeropoint":
-        x = xq.float()
-        if len(S1.shape) == 3 and len(x.shape) == 2:
-            S1 = S1.squeeze(0)
-        if len(S2.shape) == 3 and len(x.shape) == 2:
-            S2 = S2.squeeze(0)
-        if len(S1.shape) == 2:
-            x *= 1.0 / S1
-        else:
-            x *= 1.0 / S1
-        x *= 1.0 / S2.t()
-        return x.to(dtype)
-    elif quant_type == "row":
-        x = xq.float()
-        if len(S1.shape) == 3 and len(x.shape) == 2:
-            S1 = S1.squeeze(0)
-        if len(S2.shape) == 3 and len(x.shape) == 2:
-            S2 = S2.squeeze(0)
-        if len(S1.shape) == 2:
-            x *= S1 * S2 / (C * C)
-        else:
-            x *= S1 * S2 / (C * C)
-        return x.to(dtype)
-    elif quant_type in ["truncated-vector", "vector"]:
-        x = xq.float()
-        if len(S1.shape) == 3 and len(x.shape) == 2:
-            S1 = S1.squeeze(0)
-        if len(S2.shape) == 3 and len(x.shape) == 2:
-            S2 = S2.squeeze(0)
-        if len(S1.shape) == 2:
-            x *= S1 / C
-        else:
-            x *= S1 / C
-        x *= S2 / C
-        return x.to(dtype)
-    else:
-        return None
-
-
 def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor):
     quant_state = linear.weight.quant_state

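The deprecation notice on vectorwise_quant already points callers at int8_vectorwise_quant. For reference, the default "vector"/"row" path of the removed pair amounts to row-wise absmax int8 quantization; the sketch below reproduces that round trip in plain PyTorch, assuming S1 has shape (m, 1) and S2 has shape (1, n). The function names are illustrative, not library API.

import torch

C = 127.0  # int8 absmax scale, as in the removed code

def absmax_quant(x: torch.Tensor, dim: int):
    # vectorwise_quant(quant_type="vector"/"row"): scale each slice by its absmax, round to int8.
    scale = torch.amax(torch.abs(x), dim=dim, keepdim=True)
    return torch.round(x * (C / scale)).to(torch.int8), scale

def absmax_mm_dequant(acc: torch.Tensor, S1: torch.Tensor, S2: torch.Tensor, dtype=torch.half):
    # vectorwise_mm_dequant(quant_type="vector"): undo both absmax scales after the matmul.
    return (acc.float() * (S1 / C) * (S2 / C)).to(dtype)

A, B = torch.randn(4, 8), torch.randn(8, 3)
Aq, SA = absmax_quant(A, dim=1)   # SA: (4, 1)
Bq, SB = absmax_quant(B, dim=0)   # SB: (1, 3)
out = absmax_mm_dequant(Aq.float() @ Bq.float(), SA, SB)  # approximates A @ B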
csrc/kernels.cu

Lines changed: 0 additions & 89 deletions
@@ -357,92 +357,6 @@ __device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadran
 }
 }
 
-
-__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n)
-{
-  const int tid = threadIdx.x + (blockDim.x*blockIdx.x);
-  const int numThreads = blockDim.x*gridDim.x;
-
-  for(int i = tid; i < n; i+=numThreads)
-  {
-      int idx = (index1[i]*maxidx1) + index2[i];
-      atomicAdd(&histogram[idx], src[i]);
-  }
-}
-
-#define THREADS_ESTIMATE 512
-#define NUM_ESTIMATE 8
-#define BLOCK_ESTIMATE 4096
-
-template<typename T>
-__launch_bounds__(THREADS_ESTIMATE, 1)
-__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n)
-{
-  const int n_full = (BLOCK_ESTIMATE*(n/BLOCK_ESTIMATE)) + (n % BLOCK_ESTIMATE == 0 ? 0 : BLOCK_ESTIMATE);
-  int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*BLOCK_ESTIMATE) : BLOCK_ESTIMATE;
-  const int base_idx = (blockIdx.x * BLOCK_ESTIMATE);
-  const float reciprocal_num_blocks = 1.0f/(n < 4096 ? 1.0f : (n/BLOCK_ESTIMATE));
-
-  T vals[NUM_ESTIMATE];
-
-  typedef cub::BlockRadixSort<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::NullType, 4, true, cub::BLOCK_SCAN_RAKING> BlockRadixSort;
-  typedef cub::BlockLoad<T, THREADS_ESTIMATE, NUM_ESTIMATE, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
-
-  __shared__ union {
-      typename LoadFloat::TempStorage loadf;
-      typename BlockRadixSort::TempStorage sort;
-      int smem_qidx[BLOCK_ESTIMATE];
-  } temp_storage;
-
-  for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_ESTIMATE)
-  {
-      valid_items = n - i > BLOCK_ESTIMATE ? BLOCK_ESTIMATE : n - i;
-
-      // do not process half-blocks
-      if(valid_items < BLOCK_ESTIMATE && n > BLOCK_ESTIMATE){ continue; }
-
-      #pragma unroll 4
-      for(int j = 0; j < NUM_ESTIMATE; j++)
-          vals[j] = max_val;
-
-      __syncthreads();
-      LoadFloat(temp_storage.loadf).Load(&(A[i]), vals, valid_items);
-
-      #pragma unroll 4
-      for(int j = 0; j < NUM_ESTIMATE; j++)
-          vals[j] = ((float)vals[j]) * reciprocal_num_blocks;
-
-
-      __syncthreads();
-      // sort into striped pattern to mitigate bank conflicts
-      // striped pattern index for thread 0 [0, 1024, 2048, 3096]
-      // striped pattern index for thread 1 [1, 1025, 2049, 3097]
-      BlockRadixSort(temp_storage.sort).SortBlockedToStriped(vals);
-
-      __syncthreads();
-      for(int j = threadIdx.x; j < BLOCK_ESTIMATE; j+=blockDim.x)
-          temp_storage.smem_qidx[j] = -1;
-
-      __syncthreads();
-
-      if(threadIdx.x < 256)
-      {
-          float q_interval = (1.0f-(2.0f*offset))/255.0f;
-          int local_idx = round(((offset+(threadIdx.x*q_interval))*(valid_items-1)));
-          temp_storage.smem_qidx[local_idx] = threadIdx.x;
-      }
-
-      __syncthreads();
-
-      for(int i = threadIdx.x; i < BLOCK_ESTIMATE; i+=blockDim.x)
-      {
-          if(temp_storage.smem_qidx[i] != -1)
-              atomicAdd(&code[temp_storage.smem_qidx[i]], vals[i/THREADS_ESTIMATE]);
-      }
-  }
-}
-
-
 __launch_bounds__(TH, 4)
 __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n)
 {
@@ -2998,9 +2912,6 @@ template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const
 template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
 template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);
 
-template __global__ void kEstimateQuantiles(float *__restrict__ const A, float *code, const float offset, const float max_val, const int n);
-template __global__ void kEstimateQuantiles(half *__restrict__ const A, float *code, const float offset, const half max_val, const int n);
-
 #define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
 template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
                 float* state1, float *unorm, \

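For context on what the removed kEstimateQuantiles kernel computed: each 4096-element block is sorted, 256 positions spaced evenly in (offset, 1 - offset) are sampled from it, and the samples are accumulated with a 1/num_blocks weight, so the result is an average of per-block quantile estimates. A rough PyTorch sketch of that scheme (illustrative only, not the kernel's exact numerics):

import torch

def estimate_quantiles_blockwise(A: torch.Tensor, offset: float = 1 / 512, block_size: int = 4096) -> torch.Tensor:
    # Average of per-block quantile estimates, mirroring the structure of the removed kernel.
    A = A.float().flatten()
    # The kernel skips partial blocks when n > block_size ("do not process half-blocks").
    blocks = [b for b in A.split(block_size) if b.numel() == block_size or A.numel() <= block_size]
    probs = torch.linspace(offset, 1 - offset, 256, device=A.device)  # the kernel's q_interval grid
    code = torch.zeros(256, device=A.device)
    for block in blocks:
        vals, _ = block.sort()
        idx = (probs * (block.numel() - 1)).round().long()
        code += vals[idx] / len(blocks)
    return code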
csrc/kernels.cuh

Lines changed: 0 additions & 6 deletions
@@ -10,8 +10,6 @@
 #define kernels
 
 
-template<typename T>__global__ void kEstimateQuantiles(T *__restrict__ const A, float *code, const float offset, const T max_val, const int n);
-
 __global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n);
 __global__ void kDequantize(float *code, unsigned char *A, float *out, const int n);

@@ -106,10 +104,6 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
 
 template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n);
 
-
-__global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
-
-
 template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
 
 template <int ITEMS_PER_THREAD, int THREADS>__global__ void kdequant_mm_int32_fp16(
