From d53b72f8b7dcd4bf777be4f7d13aa3b3288fc83b Mon Sep 17 00:00:00 2001
From: Asad Aali
Date: Thu, 24 Apr 2025 17:26:35 -0700
Subject: [PATCH] Add bnb quantization feature to lm_local.py for quantized fine-tuning

Enable optional 4-bit bitsandbytes (bnb) quantization during local SFT.
Passing use_quantization=True in train_kwargs loads the model with an NF4
BitsAndBytesConfig (float16 compute dtype, double quantization). Quantized
models are placed on device at load time and do not support .to(), so the
device move is skipped for them; the default path is unchanged.
---
 dspy/clients/lm_local.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/dspy/clients/lm_local.py b/dspy/clients/lm_local.py
index 8329db9786..9e93642f49 100644
--- a/dspy/clients/lm_local.py
+++ b/dspy/clients/lm_local.py
@@ -211,8 +211,24 @@ def train_sft_locally(model_name, train_data, train_kwargs):
     )
     logger.info(f"Using device: {device}")
 
+    use_quantization = train_kwargs.get("use_quantization", False)
+    quantization_config = None
+
+    if use_quantization:
+        from transformers import BitsAndBytesConfig
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+        )
+
     model = AutoModelForCausalLM.from_pretrained(
-        pretrained_model_name_or_path=model_name
-    ).to(device)
+        pretrained_model_name_or_path=model_name, quantization_config=quantization_config
+    )
+    # bitsandbytes-quantized models are placed on device at load time and
+    # raise a ValueError on .to(), so only move the unquantized model.
+    if not use_quantization:
+        model = model.to(device)
     tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
 
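
A minimal usage sketch of the new flag, assuming train_kwargs is the plain
dict that train_sft_locally already receives; the model name and the single
training example below are illustrative placeholders, not values from this
patch:

    from dspy.clients.lm_local import train_sft_locally

    # Placeholder chat-format example; real SFT data would be larger.
    train_data = [
        {
            "messages": [
                {"role": "user", "content": "What is 2 + 2?"},
                {"role": "assistant", "content": "4"},
            ]
        }
    ]

    train_kwargs = {
        # New flag introduced by this patch; defaults to False, which keeps
        # the original full-precision loading path.
        "use_quantization": True,
    }

    train_sft_locally(
        model_name="meta-llama/Llama-3.2-1B-Instruct",  # placeholder model
        train_data=train_data,
        train_kwargs=train_kwargs,
    )

With use_quantization left unset, quantization_config stays None and
AutoModelForCausalLM.from_pretrained behaves exactly as before the patch.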