# Flash Preference


Accelerate LLM preference tuning via prefix sharing with a single line of code. Applicable to Direct Preference Optimization (DPO), Reward Modeling (RM), Group Relative Policy Optimization (GRPO), etc.

## Getting Started

Install the stable version from PyPI:

```sh
pip install flash-pref
```

Or install the latest version from GitHub:

```sh
pip install git+https://github.com/li-plus/flash-preference.git@main
```

All you have to do is add a `shared_prefix` context wrapping the model's forward and backward passes. The common prefix of the input sequences is automatically detected and shared, reducing computation and memory footprint without any loss of accuracy.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from flash_pref import shared_prefix

model_id = "Qwen/Qwen2.5-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="right")
model = AutoModelForCausalLM.from_pretrained(
    model_id, attn_implementation="flash_attention_2", use_cache=False, torch_dtype=torch.bfloat16, device_map="cuda"
)

prompt = "What are the next 10 numbers of this sequence: " + ", ".join(str(x) for x in range(500))
chosen_response = ", ".join(str(x) for x in range(500, 500 + 10))
rejected_response = ", ".join(str(x) for x in range(500, 500 + 10, 2))

conversations = [
    [{"role": "user", "content": prompt}, {"role": "assistant", "content": chosen_response}],
    [{"role": "user", "content": prompt}, {"role": "assistant", "content": rejected_response}],
]
inputs = tokenizer.apply_chat_template(
    conversations, tokenize=True, padding=True, return_tensors="pt", return_dict=True
).to("cuda")

# ===== MAGIC HERE =====
with shared_prefix(model, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask):
    output = model(**inputs)
    output.logits.backward(torch.randn_like(output.logits))
```
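
Conceptually, the savings come from running the shared prompt tokens through the model once instead of once per sequence. The toy function below illustrates how a shared prefix can be detected from batched `input_ids`; it is a sketch for intuition only, not the library's actual implementation, and it ignores padding for simplicity.

```python
import torch

def common_prefix_len(input_ids: torch.Tensor) -> int:
    """Toy sketch: length of the longest token prefix shared by all rows."""
    # Compare every row against the first row; a position belongs to the
    # shared prefix only if all sequences agree on it.
    same = (input_ids == input_ids[:1]).all(dim=0)  # (seq_len,) bool
    mismatch = (~same).nonzero()
    return int(mismatch[0]) if mismatch.numel() else input_ids.shape[1]

ids = torch.tensor([[1, 2, 3, 4, 5], [1, 2, 3, 9, 9]])
print(common_prefix_len(ids))  # 3 -- the shared prompt tokens
```

In the chosen/rejected pair above, both sequences are identical up to the end of the prompt, so that portion's forward and backward computation can be done once and reused for both.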

For huggingface/trl users, drop-in replacements for the trl trainers are also available. Check out the end-to-end training examples below.

| Algorithm | Original Trainer | Accelerated Trainer with Prefix Sharing | Example |
| --- | --- | --- | --- |
| Direct Preference Optimization | `from trl import DPOTrainer` | `from flash_pref import FlashDPOTrainer` | `examples/dpo_trl.py` |
| Reward Modeling | `from trl import RewardTrainer` | `from flash_pref import FlashRewardTrainer` | `examples/reward_modeling_trl.py` |
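
As an illustration, a DPO training script could be switched over by changing only the trainer import. This is a minimal sketch assuming `FlashDPOTrainer` accepts the same constructor arguments as trl's `DPOTrainer`; the dataset name and config values here are placeholders, not taken from this repository.

```python
from datasets import load_dataset
from trl import DPOConfig

from flash_pref import FlashDPOTrainer  # drop-in for: from trl import DPOTrainer

# Placeholder preference dataset with "prompt"/"chosen"/"rejected" columns.
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

trainer = FlashDPOTrainer(
    model="Qwen/Qwen2.5-7B-Instruct",
    args=DPOConfig(output_dir="dpo-output", bf16=True),
    train_dataset=train_dataset,
)
trainer.train()
```

See `examples/dpo_trl.py` in the repository for the authoritative end-to-end script.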

## Benchmark

Flash Preference is benchmarked for performance speedup and memory savings relative to the baseline (the same forward and backward passes without prefix sharing).

The benchmark settings are listed below; refer to `tests/benchmark.py` for the full details, and see the measurement sketch after the list for how such a comparison can be timed.

- Model: Qwen/Qwen2.5-7B-Instruct with gradient checkpointing, Liger-Kernel, and FlashAttention-2 enabled.
- Data: mocked pairwise preference data where prompt and response lengths vary from 64 to 16k tokens.
- Computation: 1 forward pass followed by 1 backward pass.
- Hardware: 1x NVIDIA A800-SXM4-80GB GPU.
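
For intuition, such a comparison can be timed with CUDA events and peak-memory statistics. The helper below is a sketch under the same setup as the Getting Started example (a `model` and tokenized `inputs` already on the GPU); it is not the actual `tests/benchmark.py`.

```python
from contextlib import nullcontext

import torch

from flash_pref import shared_prefix

def measure(model, inputs, use_shared_prefix: bool):
    """Return (elapsed_ms, peak_bytes) for 1 forward + 1 backward pass."""
    torch.cuda.reset_peak_memory_stats()
    ctx = (
        shared_prefix(model, input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
        if use_shared_prefix
        else nullcontext()  # baseline: no prefix sharing
    )
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    with ctx:
        output = model(**inputs)
        output.logits.backward(torch.randn_like(output.logits))
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end), torch.cuda.max_memory_allocated()
```

Comparing `measure(model, inputs, True)` against `measure(model, inputs, False)` gives the speedup and memory savings for a given batch.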

## Developing

### Unit Tests

Currently tested for the LLaMA, Gemma, Gemma2, Qwen2, Qwen2VL, and Qwen2.5VL architectures. At least 2 GPUs are required for the unit tests. To run them, type:

```sh
make test
```

### Code Format

To format the code, type:

```sh
make lint
```

## License

This project is under the MIT License.
