SFTTrainer freezes LoRA adapter when using PEFT model as argument #3926

@dvgodoy

Description

Reproduction

See code in Colab

The problem seems to be that the SFTTrainer class calls prepare_model_for_kbit_training() on an instance of PeftModel. The preparation function freezes every parameter, adapters included, and training fails.
If, however, one passes a LoraConfig instance to SFTTrainer instead, get_peft_model() is called after the preparation step and the adapters are trainable, as expected.
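
A minimal sketch of the mechanism in isolation (the small sshleifer/tiny-gpt2 model is used purely for illustration; the freezing itself does not appear to depend on quantization): prepare_model_for_kbit_training() sets requires_grad=False on every parameter of the model it receives, adapters included.

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM

# Tiny model used only to demonstrate the freezing behaviour; the same thing
# happens to the quantized Phi-3 model in the reproduction below.
demo = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
demo = get_peft_model(demo, LoraConfig(r=8, task_type="CAUSAL_LM"))
demo.print_trainable_parameters()   # LoRA parameters are trainable at this point

# Applying the k-bit preparation to a model that already carries adapters
# freezes every parameter, adapters included.
demo = prepare_model_for_kbit_training(demo, use_gradient_checkpointing=False)
demo.print_trainable_parameters()   # trainable params: 0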

IMO, in the _prepare_peft_model() method of the SFTTrainer class, line 609 should be changed from:

if is_qlora and not is_sharded_qlora:

To:

if is_qlora and not is_sharded_qlora and not isinstance(model, PeftModel):
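
A simplified sketch of the intended effect (not TRL's actual _prepare_peft_model() code; is_qlora and is_sharded_qlora are stand-ins for the flags computed inside that method): with the extra check, the k-bit preparation only runs when TRL itself attaches the adapters afterwards via get_peft_model().

from peft import PeftModel, prepare_model_for_kbit_training

def maybe_prepare_for_kbit(model, is_qlora: bool, is_sharded_qlora: bool):
    # Skip the preparation when the user already passed a PeftModel,
    # so the existing adapter weights keep requires_grad=True.
    if is_qlora and not is_sharded_qlora and not isinstance(model, PeftModel):
        model = prepare_model_for_kbit_training(model)
    return model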

The code below reproduces both situations: passing a PeftModel instance (trainer1) and passing a base model together with a LoraConfig instance (trainer2). The difference is easy to see by calling print_trainable_parameters() on both trainers' models.

import torch
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

repo_id = 'microsoft/Phi-3-mini-4k-instruct'

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
)

model_q4 = AutoModelForCausalLM.from_pretrained(repo_id,
                                                device_map='cuda:0',
                                                quantization_config=nf4_config)

model_q4 = prepare_model_for_kbit_training(model_q4,
                                           use_gradient_checkpointing=False,
                                           gradient_checkpointing_kwargs={})

config = LoraConfig(
    r=16,
    target_modules=['o_proj', 'qkv_proj', 'gate_up_proj', 'down_proj'],
    lora_alpha=32,
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model_q4, config)

tokenizer = AutoTokenizer.from_pretrained(repo_id)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.unk_token_id

dataset = load_dataset("dvgodoy/yoda_sentences", split="train")
dataset = dataset.rename_column("sentence", "prompt")
dataset = dataset.rename_column("translation_extra", "completion")
dataset = dataset.remove_columns(["translation"])

def format_dataset(examples):
    if isinstance(examples["prompt"], list):
        output_texts = []
        for i in range(len(examples["prompt"])):
            converted_sample = [
                {"role": "user", "content": examples["prompt"][i]},
                {"role": "assistant", "content": examples["completion"][i]},
            ]
            output_texts.append(converted_sample)
        return {'messages': output_texts}
    else:
        converted_sample = [
            {"role": "user", "content": examples["prompt"]},
            {"role": "assistant", "content": examples["completion"]},
        ]
        return {'messages': converted_sample}

dataset = dataset.map(format_dataset).remove_columns(['prompt', 'completion'])

# Using own PEFT model, LoRA already applied
trainer1 = SFTTrainer(
    model=peft_model,
    processing_class=tokenizer,
    train_dataset=dataset,
    data_collator=None,
    args=SFTConfig(
        output_dir="./future_name_on_the_hub",
        packing=False,
        report_to='none',
        bf16=False,
        max_length=64,
    )
)

# Using base model and LoRA config
base_model = AutoModelForCausalLM.from_pretrained(repo_id,
                                                  device_map='cuda:0',
                                                  quantization_config=nf4_config)

trainer2 = SFTTrainer(
    model=base_model,
    peft_config=config,
    processing_class=tokenizer,
    train_dataset=dataset,
    data_collator=None,
    args=SFTConfig(
        output_dir="./future_name_on_the_hub",
        packing=False,
        report_to='none',
        bf16=False,
        max_length=64,
    )
)

trainer1.model.print_trainable_parameters()
trainer2.model.print_trainable_parameters()

Outputs:

Trainer 1:

trainable params: 0 || all params: 3,846,245,376 || trainable%: 0.0000

Trainer 2:

trainable params: 25,165,824 || all params: 3,846,245,376 || trainable%: 0.6543
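
Until the check is changed, a possible workaround (just a sketch; the "lora_" name filter relies on PEFT's parameter-naming convention and is an assumption here) is to unfreeze the adapter weights of trainer1 again before calling .train():

# Re-enable gradients on the LoRA parameters that were frozen during preparation.
for name, param in trainer1.model.named_parameters():
    if "lora_" in name:
        param.requires_grad = True

trainer1.model.print_trainable_parameters()  # adapters are trainable again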

System Info

  • Platform: Linux-6.1.123+-x86_64-with-glibc2.35
  • Python version: 3.12.11
  • TRL version: 0.21.0
  • PyTorch version: 2.8.0+cu126
  • accelerator(s): Tesla T4
  • Transformers version: 4.55.2
  • Accelerate version: 1.10.0
  • Accelerate config: not found
  • Datasets version: 4.0.0
  • HF Hub version: 0.34.4
  • bitsandbytes version: 0.47.0
  • DeepSpeed version: not installed
  • Diffusers version: 0.34.0
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: 1.99.9
  • PEFT version: 0.17.0
  • vLLM version: not installed

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots, more on code blocks)
  • Any traceback provided is complete

Labels

⚡ PEFT (Related to PEFT) · ❓ question (Seeking clarification or more information)
