Reproduction
The problem seems to be that the `SFTTrainer` class calls `prepare_model_for_kbit_training()` on an instance of a `PeftModel`. The preparation function freezes everything, including the adapters, and training fails.

If, however, one passes an instance of `LoraConfig` to `SFTTrainer`, `get_peft_model()` gets called after preparation, and the adapters are trainable, as expected.
IMO, in the `SFTTrainer` class's `_prepare_peft_model()` method, line 609 should be changed from:

```python
if is_qlora and not is_sharded_qlora:
```

to:

```python
if is_qlora and not is_sharded_qlora and not isinstance(model, PeftModel):
```
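For reference, a quick check (using the `peft_model` and `base_model` from the reproduction script below) that the extra guard distinguishes the two cases:

```python
# Sketch: the proposed guard skips re-preparation only for ready-made PEFT models.
from peft import PeftModel

print(isinstance(peft_model, PeftModel))   # True  -> preparation would be skipped
print(isinstance(base_model, PeftModel))   # False -> preparation still runs
```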
The code below was used to reproduce both situations: using an instance of `PeftModel` (Trainer 1) and using a base model and an instance of `LoraConfig` instead (Trainer 2). One can easily see the difference by calling the `print_trainable_parameters()` method on both models.
```python
import torch
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

repo_id = 'microsoft/Phi-3-mini-4k-instruct'

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
)

model_q4 = AutoModelForCausalLM.from_pretrained(repo_id,
                                                device_map='cuda:0',
                                                quantization_config=nf4_config)
model_q4 = prepare_model_for_kbit_training(model_q4,
                                           use_gradient_checkpointing=False,
                                           gradient_checkpointing_kwargs={})

config = LoraConfig(
    r=16,
    target_modules=['o_proj', 'qkv_proj', 'gate_up_proj', 'down_proj'],
    lora_alpha=32,
    task_type="CAUSAL_LM",
)

peft_model = get_peft_model(model_q4, config)

tokenizer = AutoTokenizer.from_pretrained(repo_id)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.unk_token_id

dataset = load_dataset("dvgodoy/yoda_sentences", split="train")
dataset = dataset.rename_column("sentence", "prompt")
dataset = dataset.rename_column("translation_extra", "completion")
dataset = dataset.remove_columns(["translation"])

def format_dataset(examples):
    if isinstance(examples["prompt"], list):
        output_texts = []
        for i in range(len(examples["prompt"])):
            converted_sample = [
                {"role": "user", "content": examples["prompt"][i]},
                {"role": "assistant", "content": examples["completion"][i]},
            ]
            output_texts.append(converted_sample)
        return {'messages': output_texts}
    else:
        converted_sample = [
            {"role": "user", "content": examples["prompt"]},
            {"role": "assistant", "content": examples["completion"]},
        ]
        return {'messages': converted_sample}

dataset = dataset.map(format_dataset).remove_columns(['prompt', 'completion'])

# Using own PEFT model, LoRA already applied
trainer1 = SFTTrainer(
    model=peft_model,
    processing_class=tokenizer,
    train_dataset=dataset,
    data_collator=None,
    args=SFTConfig(
        output_dir="./future_name_on_the_hub",
        packing=False,
        report_to='none',
        bf16=False,
        max_length=64,
    )
)

# Using base model and LoRA config
base_model = AutoModelForCausalLM.from_pretrained(repo_id,
                                                  device_map='cuda:0',
                                                  quantization_config=nf4_config)
trainer2 = SFTTrainer(
    model=base_model,
    peft_config=config,
    processing_class=tokenizer,
    train_dataset=dataset,
    data_collator=None,
    args=SFTConfig(
        output_dir="./future_name_on_the_hub",
        packing=False,
        report_to='none',
        bf16=False,
        max_length=64,
    )
)

trainer1.model.print_trainable_parameters()
trainer2.model.print_trainable_parameters()
```
Outputs:

```
Trainer 1:
trainable params: 0 || all params: 3,846,245,376 || trainable%: 0.0000

Trainer 2:
trainable params: 25,165,824 || all params: 3,846,245,376 || trainable%: 0.6543
```
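Until this is fixed, a possible workaround (a sketch under the assumption that only the LoRA adapter weights should be trainable) is to unfreeze the adapter parameters after building the trainer and before calling `train()`:

```python
# Hypothetical workaround sketch: re-enable gradients on the LoRA adapter weights
# that prepare_model_for_kbit_training() froze. PEFT names them "lora_A"/"lora_B".
for name, param in trainer1.model.named_parameters():
    if "lora_" in name:
        param.requires_grad = True

trainer1.model.print_trainable_parameters()  # should now match Trainer 2
```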
System Info
- Platform: Linux-6.1.123+-x86_64-with-glibc2.35
- Python version: 3.12.11
- TRL version: 0.21.0
- PyTorch version: 2.8.0+cu126
- accelerator(s): Tesla T4
- Transformers version: 4.55.2
- Accelerate version: 1.10.0
- Accelerate config: not found
- Datasets version: 4.0.0
- HF Hub version: 0.34.4
- bitsandbytes version: 0.47.0
- DeepSpeed version: not installed
- Diffusers version: 0.34.0
- Liger-Kernel version: not installed
- LLM-Blender version: not installed
- OpenAI version: 1.99.9
- PEFT version: 0.17.0
- vLLM version: not installed
Checklist
- I have checked that my issue isn't already filed (see open issues)
- I have included my system information
- Any code provided is minimal, complete, and reproducible (more on MREs)
- Any code provided is properly formatted in code blocks (no screenshot, more on code blocks)
- Any traceback provided is complete