Commit 16cc220

Merge branch 'TimDettmers:main' into galore
2 parents eceed12 + 0c64a0d commit 16cc220

14 files changed: +170 -24 lines

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
```diff
@@ -357,6 +357,10 @@ Bug fixes:
 - Addressed a race condition in kEstimateQuantiles, enhancing the reliability of quantile estimation in concurrent environments (@pnunna93, #1061).
 - Fixed various minor issues, including typos in code comments and documentation, to improve code clarity and prevent potential confusion (@Brian Vaughan, #1063).
 
+#### Backwards Compatibility
+- After upgrading from `v0.42` to `v0.43`, when using 4bit quantization, models may generate slightly different outputs (approximately up to the 2nd decimal place) due to a fix in the code. For anyone interested in the details, [see this comment](https://github.com/TimDettmers/bitsandbytes/discussions/1094#discussioncomment-8984069).
+
+
 #### Internal and Build System Enhancements:
 - Implemented several enhancements to the internal and build systems, including adjustments to the CI workflows, portability improvements, and build artifact management. These changes contribute to a more robust and flexible development process, ensuring the library's ongoing quality and maintainability (@rickardp, @akx, @wkpark, @matthewdouglas; #949, #1053, #1045, #1037).
```

README.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -1,5 +1,7 @@
 # `bitsandbytes`
 
+[![Downloads](https://static.pepy.tech/badge/bitsandbytes)](https://pepy.tech/project/bitsandbytes) [![Downloads](https://static.pepy.tech/badge/bitsandbytes/month)](https://pepy.tech/project/bitsandbytes) [![Downloads](https://static.pepy.tech/badge/bitsandbytes/week)](https://pepy.tech/project/bitsandbytes)
+
 The `bitsandbytes` library is a lightweight Python wrapper around CUDA custom functions, in particular 8-bit optimizers, matrix multiplication (LLM.int8()), and 8 & 4-bit quantization functions.
 
 The library includes quantization primitives for 8-bit & 4-bit operations, through `bitsandbytes.nn.Linear8bitLt` and `bitsandbytes.nn.Linear4bit` and 8-bit optimizers through `bitsandbytes.optim` module.
```

bitsandbytes/diagnostics/cuda.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -59,7 +59,7 @@ def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path
             for pth in dir.glob(lib_pattern):
                 if pth.is_file():
                     yield pth
-        except PermissionError:
+        except (OSError, PermissionError):
             pass
```
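
For context: `PermissionError` is already a subclass of `OSError`, so the widened clause mostly documents intent while also covering other OS-level errors that can surface when scanning unusual path entries. A minimal sketch of the pattern, with hypothetical `dirs` and `lib_pattern` values rather than the real diagnostics inputs:

```py
from pathlib import Path

def find_libs(dirs, lib_pattern="libcudart*.so*"):
    # Yield matching library files, skipping entries the filesystem refuses to scan.
    for d in dirs:
        try:
            for pth in Path(d).glob(lib_pattern):
                if pth.is_file():
                    yield pth
        except (OSError, PermissionError):
            # Skip unreadable or otherwise invalid path entries instead of aborting.
            pass

print(list(find_libs(["/usr/local/cuda/lib64", "/definitely/not/a/dir"])))
```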

bitsandbytes/functional.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -1087,11 +1087,12 @@ def get_4bit_type(typename, device=None, blocksize=64):
     if data is None:
         raise NotImplementedError(f"Typename {typename} not supported")
 
-    data = Tensor(data)
-    data /= data.abs().max()
+    data = torch.tensor(data, device=device)
+    data.div_(data.abs().max())
+
     assert data.numel() == 16
 
-    return data.to(device)
+    return data
 
 
 def quantize_fp4(
```
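
The new construction builds the 16-entry codebook directly on the target device and normalizes it in place, rather than creating a CPU `Tensor`, dividing it, and copying it to the device at the end. A rough sketch of the idea (the 16 values below are placeholders for illustration, not the actual FP4/NF4 codebooks):

```py
import torch

def make_codebook(values, device=None):
    # Allocate on the target device up front and normalize in place,
    # so no intermediate CPU tensor or extra copy is created.
    code = torch.tensor(values, device=device)
    code.div_(code.abs().max())
    assert code.numel() == 16
    return code

# Placeholder 16-value codebook, normalized to [-1, 1]:
vals = [-8.0, -6.0, -5.0, -4.0, -3.0, -2.0, -1.0, 0.0,
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
print(make_codebook(vals))
```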

csrc/ops.cu

Lines changed: 1 addition & 1 deletion
```diff
@@ -58,7 +58,7 @@ template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(floa
   num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
 
   if(blocksize == 4096)
-    kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, 0><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
+    kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, DATA_TYPE><<<num_blocks, 1024>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 2048)
     kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, rand, rand_offset, n);
   else if(blocksize == 1024)
```
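
This change forwards the `DATA_TYPE` template argument on the `blocksize == 4096` branch instead of the hard-coded `0` (the general blockwise code path), so FP4/NF4 quantization at that blocksize now uses the intended 4-bit data type; this is presumably the fix behind the backwards-compatibility note in the CHANGELOG above. A quick way to exercise the path from Python, a sketch using the same public `quantize_4bit`/`dequantize_4bit` functions the updated test below calls (requires a CUDA GPU):

```py
import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

# blocksize=4096 hits the kernel branch changed in this commit.
packed, state = F.quantize_4bit(A, blocksize=4096, quant_type="fp4")
A_hat = F.dequantize_4bit(packed, state, blocksize=4096, quant_type="fp4")

# Reconstruction is lossy; expect a small but nonzero mean absolute error.
print((A - A_hat).abs().float().mean().item())
```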

docs/source/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -12,6 +12,8 @@
     title: 8-bit optimizers
   - local: algorithms
     title: Algorithms
+  - local: fsdp_qlora
+    title: FSDP-QLoRA
   - local: integrations
     title: Integrations
   - local: errors
```

docs/source/fsdp_qlora.md

Lines changed: 106 additions & 0 deletions
````diff
@@ -0,0 +1,106 @@
+# FSDP-QLoRA
+
+FSDP-QLoRA combines data parallelism (FSDP shards model parameters, optimizer states, and gradients across GPUs), 4-bit quantization, and LoRA to train LLMs up to 70B parameters on a dual 24GB GPU system. This technique was released by [Answer.AI](https://www.answer.ai/posts/2024-03-06-fsdp-qlora) in collaboration with bitsandbytes to make training LLMs more efficient and accessible for everyone.
+
+This guide gives a brief overview of how bitsandbytes supports storing quantized weights to enable FSDP-QLoRA, and how to run training with the Hugging Face libraries.
+
+> [!TIP]
+> Other changes required for bitsandbytes to support FSDP-QLoRA, such as reconstructing the weights from the quantization metadata and preventing already quantized weights from being quantized again when they're moved from a CPU to GPU, are documented in this [Pull Request](https://github.com/TimDettmers/bitsandbytes/pull/970) and described in the [Enabling 70B Finetuning on Consumer GPUs](https://www.answer.ai/posts/2024-03-14-fsdp-qlora-deep-dive) blog post. We highly recommend reading these resources for a better understanding of FSDP-QLoRA!
+
+## Quantized data storage
+
+FSDP only supports sharding float data types, which can be problematic because quantized weights are typically stored as integer data types (uint8). bitsandbytes doesn't have this problem because it uses `StoreChar` to read and write quantized weights regardless of the data type storage. This makes it simple to add a `quant_storage` parameter to the [`~nn.Linear4bit`] and [`~nn.Params4bit`] classes and set it to `torch.uint8` to maintain backward compatibility with the codebase.
+
+```py
+import torch
+import bitsandbytes as bnb
+
+model = bnb.nn.Linear4bit(
+    input_features,
+    output_features,
+    quant_type="fp4",
+    quant_storage=torch.uint8,
+)
+```
+
+With the `quant_storage` parameter, you can select any of the data types supported by FSDP to shard [`~nn.Linear4bit`] with, such as bfloat16, float16, or float32.
+
+## Training
+
+bitsandbytes is deeply integrated with the Hugging Face ecosystem, making it easy to use with libraries like [Transformers](https://hf.co/docs/transformers), [PEFT](https://hf.co/docs/peft), and [TRL](https://hf.co/docs/trl).
+
+Before you begin, make sure you have the latest libraries installed.
+
+```bash
+pip install -U bitsandbytes accelerate transformers peft trl
+```
+
+> [!TIP]
+> PEFT provides a configuration file ([fsdp_config_qlora.yaml](https://github.com/huggingface/peft/blob/main/examples/sft/configs/fsdp_config_qlora.yaml)), launch command ([run_peft_qlora_fsdp.sh](https://github.com/huggingface/peft/blob/main/examples/sft/run_peft_qlora_fsdp.sh)), and training script ([train.py](https://github.com/huggingface/peft/blob/main/examples/sft/train.py)) for FSDP-QLoRA. To learn more, check out the [Use PEFT QLoRA and FSDP for finetuning large models on multiple GPUs](https://huggingface.co/docs/peft/main/en/accelerate/fsdp#use-peft-qlora-and-fsdp-for-finetuning-large-models-on-multiple-gpus) documentation.
+
+The important change that enables FSDP-QLoRA training is the `bnb_4bit_quant_storage` parameter in the [`~transformers.BitsAndBytesConfig`] class. This allows you to set the storage data type of the quantized weights to a float data type.
+
+```py
+from transformers import BitsAndBytesConfig
+
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_compute_dtype=torch.bfloat16,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_storage=torch.bfloat16,
+)
+```
+
+Pass the [`~transformers.BitsAndBytesConfig`] to a model to set it up for FSDP-QLoRA. You should set the `torch_dtype` parameter to match `bnb_4bit_quant_storage` so that the [`~nn.Linear4bit`] layers are wrapped identically to the `Linear` layers. If the storage types do not match, then each [`~nn.Linear4bit`] layer is wrapped individually.
+
+```py
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-2-70b",
+    quantization_config=bnb_config,
+    torch_dtype=torch.bfloat16,
+)
+```
+
+Configure the [`~peft.LoraConfig`] class for QLoRA training by setting `target_modules="all-linear"`.
+
+```py
+from peft import LoraConfig
+
+peft_config = LoraConfig(
+    lora_alpha=16,
+    lora_dropout=0.1,
+    r=64,
+    bias="none",
+    task_type="CAUSAL_LM",
+    target_modules="all-linear",
+)
+```
+
+Now you can pass everything to the [`~trl.SFTTrainer`] for training.
+
+```py
+from trl import SFTTrainer
+
+trainer = SFTTrainer(
+    model=model,
+    train_dataset=dataset,
+    peft_config=peft_config,
+    dataset_text_field="text",
+    max_seq_length=max_seq_length,
+    tokenizer=tokenizer,
+    args=training_arguments,
+)
+trainer.train()
+```
+
+## Resources
+
+To learn more about FSDP and QLoRA, check out the following resources:
+
+- The [AnswerDotAI/fsdp_qlora](https://github.com/AnswerDotAI/fsdp_qlora) repository.
+- The introductory [You can now train a 70b language model at home](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) blog post by Answer.AI.
+- For an introduction to FSDP, read the [Introducing PyTorch Fully Sharded Data Parallel (FSDP) API](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api) blog post.
+- For more details about QLoRA, take a look at the [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) blog post.
````
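
As a small complement to the `quant_storage` example in the new guide: the same parameter accepts a float storage type, which is what FSDP needs in order to shard the packed 4-bit weights alongside the other parameters. A minimal sketch (the layer sizes are placeholders):

```py
import torch
import bitsandbytes as bnb

# Pack the 4-bit weights into bfloat16 storage so FSDP can shard them
# together with the rest of the bfloat16 parameters.
layer = bnb.nn.Linear4bit(
    4096,  # in_features (placeholder)
    4096,  # out_features (placeholder)
    quant_type="nf4",
    quant_storage=torch.bfloat16,
)
```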

docs/source/installation.mdx

Lines changed: 1 addition & 1 deletion
````diff
@@ -84,7 +84,7 @@ Then locally install the CUDA version you need with this script from bitsandbyte
 ```bash
 wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/install_cuda.sh
 # Syntax cuda_install CUDA_VERSION INSTALL_PREFIX EXPORT_TO_BASH
-# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122, 123}
+# CUDA_VERSION in {110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122, 123, 124}
 # EXPORT_TO_BASH in {0, 1} with 0=False and 1=True
 
 # For example, the following installs CUDA 11.7 to ~/local/cuda-11.7 and exports the path to your .bashrc
````

install_cuda.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -17,6 +17,7 @@
     "121": "https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda_12.1.1_530.30.02_linux.run",
     "122": "https://developer.download.nvidia.com/compute/cuda/12.2.2/local_installers/cuda_12.2.2_535.104.05_linux.run",
     "123": "https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run",
+    "124": "https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run",
 }
 
```

install_cuda.sh

Lines changed: 5 additions & 2 deletions
```diff
@@ -11,7 +11,7 @@ URL120=https://developer.download.nvidia.com/compute/cuda/12.0.1/local_installer
 URL121=https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda_12.1.1_530.30.02_linux.run
 URL122=https://developer.download.nvidia.com/compute/cuda/12.2.2/local_installers/cuda_12.2.2_535.104.05_linux.run
 URL123=https://developer.download.nvidia.com/compute/cuda/12.3.2/local_installers/cuda_12.3.2_545.23.08_linux.run
-
+URL124=https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run
 
 CUDA_VERSION=$1
 BASE_PATH=$2
@@ -57,8 +57,11 @@ if [[ -n "$CUDA_VERSION" ]]; then
   elif [[ "$CUDA_VERSION" -eq "123" ]]; then
     URL=$URL123
     FOLDER=cuda-12.3
+  elif [[ "$CUDA_VERSION" -eq "124" ]]; then
+    URL=$URL124
+    FOLDER=cuda-12.4
   else
-    echo "argument error: No cuda version passed as input. Choose among versions 92 to 123"
+    echo "argument error: No cuda version passed as input. Choose among versions 110 to 124"
   fi
 else
   echo "argument error: No cuda version passed as input. Choose among versions 92 to 123"
```

requirements-ci.txt

Lines changed: 4 additions & 4 deletions
```diff
@@ -1,6 +1,6 @@
 # Requirements used for GitHub actions
-pytest==7.2.2
-einops==0.6.0
-lion-pytorch==0.0.6
+pytest==8.1.1
+einops==0.7.0
+lion-pytorch==0.1.4
 scipy==1.10.1; python_version < "3.9"
-scipy==1.11.4; python_version >= "3.9"
+scipy==1.12.0; python_version >= "3.9"
```

requirements-dev.txt

Lines changed: 7 additions & 7 deletions
```diff
@@ -1,9 +1,9 @@
 # Requirements used for local development
 setuptools>=63
-pytest~=7.2.2
-einops~=0.6.0
-wheel~=0.40.0
-lion-pytorch~=0.0.6
-scipy~=1.11.4
-pandas~=2.2.0
-matplotlib~=3.8.2
+pytest~=8.1.1
+einops~=0.7.0
+wheel~=0.43.0
+lion-pytorch~=0.1.4
+scipy~=1.12.0
+pandas~=2.2.1
+matplotlib~=3.8.3
```

tests/conftest.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -1,3 +1,5 @@
+import gc
+
 import pytest
 import torch
 
@@ -20,6 +22,13 @@ def pytest_runtest_call(item):
         raise
 
 
+@pytest.hookimpl(trylast=True)
+def pytest_runtest_teardown(item, nextitem):
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
 @pytest.fixture(scope="session")
 def requires_cuda() -> bool:
     cuda_available = torch.cuda.is_available()
```

tests/test_functional.py

Lines changed: 23 additions & 5 deletions
```diff
@@ -1928,7 +1928,9 @@ def test_bench_dequantization():
 
 
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
-def test_fp4_quant(dtype):
+@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
+@pytest.mark.parametrize("blocksize", [64, 128, 256, 512, 1024, 2048, 4096])
+def test_4bit_quant(dtype, quant_type, blocksize):
     vals = list(product([0, 1], repeat=4))
 
     code = {}
@@ -1953,17 +1955,33 @@ def test_fp4_quant(dtype):
         code[idx] = result
 
     A1 = torch.randn(1024, 1024, device="cuda", dtype=dtype)
-    qa, SA = F.quantize_fp4(A1, blocksize=64)
-    A2 = F.dequantize_fp4(qa, SA)
+    qa, SA = F.quantize_4bit(A1, blocksize=blocksize, quant_type=quant_type)
+    A2 = F.dequantize_4bit(qa, SA, blocksize=blocksize, quant_type=quant_type)
 
     err = (A1 - A2).abs().float()
     relerr = (err / (A1.abs().float() + 1e-8)).mean()
     idx = err > 1.0
     err = err.mean()
 
     assert A2.dtype == dtype
-    assert err.item() < 0.1
-    assert relerr.item() < 0.28
+
+    # With larger block sizes, we can expect this to blow up.
+    # At blocksize>=1024, don't even bother looking at relerr.
+    if blocksize <= 64:
+        assert err.item() < 0.1
+        assert relerr.item() < 0.28
+    elif blocksize <= 256:
+        assert err.item() < 0.11
+        assert relerr.item() < 0.30
+    elif blocksize <= 512:
+        assert err.item() < 0.12
+        assert relerr.item() < 0.31
+    elif quant_type == "fp4":
+        # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
+        assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
+    else:
+        # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
+        assert err.item() < math.log2(blocksize) * 8e-2
 
 
 @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
```
