# FSDP-QLoRA

FSDP-QLoRA combines data parallelism (FSDP shards model parameters, optimizer states, and gradients across GPUs), 4-bit quantization, and LoRA to train LLMs with up to 70B parameters on a dual 24GB GPU system. This technique was released by [Answer.AI](https://www.answer.ai/posts/2024-03-06-fsdp-qlora) in collaboration with bitsandbytes to make training LLMs more efficient and accessible for everyone.

This guide gives a brief overview of how bitsandbytes supports storing quantized weights to enable FSDP-QLoRA, and how to run training with the Hugging Face libraries.

> [!TIP]
> Other changes required for bitsandbytes to support FSDP-QLoRA, such as reconstructing the weights from the quantization metadata and preventing already quantized weights from being re-quantized when they're moved from CPU to GPU, are documented in this [Pull Request](https://github.com/TimDettmers/bitsandbytes/pull/970) and described in the [Enabling 70B Finetuning on Consumer GPUs](https://www.answer.ai/posts/2024-03-14-fsdp-qlora-deep-dive) blog post. We highly recommend reading these resources for a better understanding of FSDP-QLoRA!

## Quantized data storage

FSDP only supports sharding float data types, which can be problematic because quantized weights are typically stored as integer data types (uint8). bitsandbytes doesn't have this problem because it uses `StoreChar` to read and write quantized weights regardless of the storage data type. This made it simple to add a `quant_storage` parameter to the [`~nn.Linear4bit`] and [`~nn.Params4bit`] classes, which defaults to `torch.uint8` to maintain backward compatibility with the codebase.

```py
import torch
import bitsandbytes as bnb

# Example dimensions; use the in/out feature sizes of the layer you are replacing.
input_features, output_features = 4096, 4096

model = bnb.nn.Linear4bit(
    input_features,
    output_features,
    quant_type="fp4",
    quant_storage=torch.uint8,  # the default storage dtype
)
```

With the `quant_storage` parameter, you can select any of the FSDP-supported data types, such as bfloat16, float16, or float32, to shard [`~nn.Linear4bit`] with.

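For example, the following sketch, a minimal illustration rather than an official recipe (the 4096x4096 shape and the use of `"cuda"` are assumptions for demonstration), stores the quantized weights in bfloat16. After the layer is moved to the GPU, the packed weight tensor carries the chosen storage dtype, while the quantization metadata needed to reconstruct the original weights is kept on `weight.quant_state`.

```py
import torch
import bitsandbytes as bnb

# Minimal sketch: store the packed 4-bit weights in bfloat16, one of the float
# dtypes FSDP can shard. The 4096x4096 shape is only an example.
layer = bnb.nn.Linear4bit(4096, 4096, quant_type="nf4", quant_storage=torch.bfloat16)
layer = layer.to("cuda")  # quantization happens when the weights move to the GPU

print(layer.weight.dtype)        # torch.bfloat16, the storage dtype chosen above
print(layer.weight.quant_state)  # quantization metadata used to reconstruct the weights
```

Because the parameter now looks like a bfloat16 tensor from FSDP's point of view, it can be flattened and sharded like any other float parameter.
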
## Training

bitsandbytes is deeply integrated with the Hugging Face ecosystem, making it easy to use with libraries like [Transformers](https://hf.co/docs/transformers), [PEFT](https://hf.co/docs/peft), and [TRL](https://hf.co/docs/trl).

Before you begin, make sure you have the latest libraries installed.

```bash
pip install -U bitsandbytes accelerate transformers peft trl
```

> [!TIP]
> PEFT provides a configuration file ([fsdp_config_qlora.yaml](https://github.com/huggingface/peft/blob/main/examples/sft/configs/fsdp_config_qlora.yaml)), launch command ([run_peft_qlora_fsdp.sh](https://github.com/huggingface/peft/blob/main/examples/sft/run_peft_qlora_fsdp.sh)), and training script ([train.py](https://github.com/huggingface/peft/blob/main/examples/sft/train.py)) for FSDP-QLoRA. To learn more, check out the [Use PEFT QLoRA and FSDP for finetuning large models on multiple GPUs](https://huggingface.co/docs/peft/main/en/accelerate/fsdp#use-peft-qlora-and-fsdp-for-finetuning-large-models-on-multiple-gpus) documentation.

The important change that enables FSDP-QLoRA training is the `bnb_4bit_quant_storage` parameter in the [`~transformers.BitsAndBytesConfig`] class. This allows you to set the storage data type of the quantized weights to a float data type.

```py
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16,
)
```

Pass the [`~transformers.BitsAndBytesConfig`] to a model to set it up for FSDP-QLoRA. You should set the `torch_dtype` parameter to match `bnb_4bit_quant_storage` so that the [`~nn.Linear4bit`] layers are wrapped identically to the `Linear` layers. If the storage types do not match, then each [`~nn.Linear4bit`] layer is wrapped individually.

```py
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
```

Configure the [`~peft.LoraConfig`] class for QLoRA training by setting `target_modules="all-linear"`.

```py
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",
)
```

Now you can pass everything to the [`~trl.SFTTrainer`] for training.

```py
from trl import SFTTrainer

# `dataset`, `max_seq_length`, `tokenizer`, and `training_arguments` are assumed
# to be defined earlier in your training script.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
)
trainer.train()
```

## Resources

To learn more about FSDP and QLoRA, check out the following resources:

- The [AnswerDotAI/fsdp_qlora](https://github.com/AnswerDotAI/fsdp_qlora) repository.
- The introductory [You can now train a 70b language model at home](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html) blog post by Answer.AI.
- For an introduction to FSDP, read the [Introducing PyTorch Fully Sharded Data Parallel (FSDP) API](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api) blog post.
- For more details about QLoRA, take a look at the [Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA](https://huggingface.co/blog/4bit-transformers-bitsandbytes) blog post.