Accelerate LLM preference tuning via prefix sharing with a single line of code. Applicable to Direct Preference Optimization (DPO), Reward Modeling (RM), Group Relative Policy Optimization (GRPO), etc.
Install the stable version from PyPI:
pip install flash-pref
Or install the latest version from GitHub:
pip install git+https://github.com/li-plus/flash-preference.git@main
All you have to do is to add a shared_prefix
context wrapping the model forward and backward passes. The common prefixes of the input sequences will be automatically detected and shared, reducing computation and memory footprint without loss of accuracy.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from flash_pref import shared_prefix
model_id = "Qwen/Qwen2.5-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="right")
model = AutoModelForCausalLM.from_pretrained(
model_id, attn_implementation="flash_attention_2", use_cache=False, torch_dtype=torch.bfloat16, device_map="cuda"
)
prompt = "What are the next 10 numbers of this sequence: " + ", ".join(str(x) for x in range(500))
chosen_response = ", ".join(str(x) for x in range(500, 500 + 10))
rejected_response = ", ".join(str(x) for x in range(500, 500 + 10, 2))
conversations = [
[{"role": "user", "content": prompt}, {"role": "assistant", "content": chosen_response}],
[{"role": "user", "content": prompt}, {"role": "assistant", "content": rejected_response}],
]
inputs = tokenizer.apply_chat_template(
conversations, tokenize=True, padding=True, return_tensors="pt", return_dict=True
).to("cuda")
# ===== MAGIC HERE =====
with shared_prefix(model, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask):
output = model(**inputs)
output.logits.backward(torch.randn_like(output.logits))
For huggingface/trl users, a drop-in replacement for trl trainer is also available. Check out the end-to-end training examples below.
Algorithm | Original Trainer | Accelerated Trainer with Prefix Sharing | Example |
---|---|---|---|
Direct Preference Optimization | from trl import DPOTrainer |
from flash_pref import FlashDPOTrainer |
examples/dpo_trl.py |
Reward Modeling | from trl import RewardTrainer |
from flash_pref import FlashRewardTrainer |
examples/reward_modeling_trl.py |
The performance speedup and memory saved relative to the baseline:
Benchmark settings are as below. Please refer to tests/benchmark.py for more details.
- Model: Qwen/Qwen2.5-7B-Instruct with gradient checkpointing, Liger-Kernel and FlashAttention-2 enabled.
- Data: mocked pairwise preference data where prompt and response lengths vary from 64 to 16k.
- Computation: 1 forward pass and then 1 backward pass.
- Hardware: 1x NVIDIA A800-SXM4-80GB GPU.
Unit Tests
Currently tested for LLaMA, Gemma, Gemma2, Qwen2, Qwen2VL and Qwen2.5VL architectures. At least 2 GPUs are required for unit tests. To run the unit tests, type:
make test
Code Format
To format the code, type:
make lint
This project is under MIT License.