Skip to content

Commit 76885a4

Browse files
Merge pull request #1160 from matthewdouglas/quant4bit-blocksize4096
Fix 4bit quantization with blocksize = 4096
2 parents 2965c76 + a471456 commit 76885a4

File tree

4 files changed

+30
-15
lines changed

4 files changed

+30
-15
lines changed

bitsandbytes/functional.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,11 +1087,12 @@ def get_4bit_type(typename, device=None, blocksize=64):
10871087
if data is None:
10881088
raise NotImplementedError(f"Typename {typename} not supported")
10891089

1090-
data = Tensor(data)
1091-
data /= data.abs().max()
1090+
data = torch.tensor(data, device=device)
1091+
data.div_(data.abs().max())
1092+
10921093
assert data.numel() == 16
10931094

1094-
return data.to(device)
1095+
return data
10951096

10961097

10971098
def quantize_fp4(

csrc/ops.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
5858
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
5959

6060
if(blocksize == 4096)
61-
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, 0><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
61+
kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, DATA_TYPE><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
6262
else if(blocksize == 2048)
6363
kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
6464
else if(blocksize == 1024)

install_cuda.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,7 @@ def main():
7777
download_path = "/tmp" # default download path
7878

7979
if len(sys.argv) < 2:
80-
print(
81-
"Usage: python install_cuda.py <version/all> [user/system] [download_path]"
82-
)
80+
print("Usage: python install_cuda.py <version/all> [user/system] [download_path]")
8381
sys.exit(1)
8482

8583
version = sys.argv[1]
@@ -100,9 +98,7 @@ def main():
10098
elif version in cuda_versions:
10199
install_cuda(version, base_path, download_path)
102100
else:
103-
print(
104-
f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}"
105-
)
101+
print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}")
106102
sys.exit(1)
107103

108104

tests/test_functional.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1928,7 +1928,9 @@ def test_bench_dequantization():
19281928

19291929

19301930
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
1931-
def test_fp4_quant(dtype):
1931+
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
1932+
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
1933+
def test_4bit_quant(dtype, quant_type, blocksize):
19321934
vals = list(product([0, 1], repeat=4))
19331935

19341936
code = {}
@@ -1953,17 +1955,33 @@ def test_fp4_quant(dtype):
19531955
code[idx] = result
19541956

19551957
A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
1956-
qa, SA = F.quantize_fp4(A1, blocksize=64)
1957-
A2 = F.dequantize_fp4(qa, SA)
1958+
qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
1959+
A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
19581960

19591961
err = (A1 - A2).abs().float()
19601962
relerr = (err / (A1.abs().float() + 1e-8)).mean()
19611963
idx = err > 1.0
19621964
err = err.mean()
19631965

19641966
assert A2.dtype == dtype
1965-
assert err.item() < 0.1
1966-
assert relerr.item() < 0.28
1967+
1968+
# With larger block sizes, we can expect this to blow up.
1969+
# At blocksize>=1024, don't even bother looking at relerr.
1970+
if blocksize <= 64:
1971+
assert err.item() < 0.1
1972+
assert relerr.item() < 0.28
1973+
elif blocksize <= 256:
1974+
assert err.item() < 0.11
1975+
assert relerr.item() < 0.30
1976+
elif blocksize <= 512:
1977+
assert err.item() < 0.12
1978+
assert relerr.item() < 0.31
1979+
elif quant_type == "fp4":
1980+
# 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
1981+
assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
1982+
else:
1983+
# 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
1984+
assert err.item() < math.log2(blocksize) * 8e-2
19671985

19681986

19691987
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])

0 commit comments

Comments
 (0)