
RL4LMS: Reinforcement Learning for Language Model Supervision

License: MIT Python Version Code style: black GitHub stars

RL4LMS is a flexible library for fine-tuning large language models (LLMs) with reinforcement learning, centered on the GRPO (Group Relative Policy Optimization) algorithm. It gives researchers and practitioners a framework for implementing custom reward functions, environments, and training loops to optimize language models for specific tasks.
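At the core of GRPO is a group-relative advantage: for each prompt, the policy samples a group of completions, scores them with the reward function, and normalizes every reward against the group's mean and standard deviation, so no learned value function is needed. The snippet below is a minimal illustration of that normalization step, not code from this library:

import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Illustrative GRPO advantage: normalize rewards within each group.

    rewards has shape (num_prompts, group_size), one row per prompt.
    """
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)

# Example: 2 prompts with 4 sampled completions each
rewards = torch.tensor([[1.0, 0.0, 0.5, 0.0],
                        [0.2, 0.9, 0.9, 0.4]])
advantages = group_relative_advantages(rewards)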

Table of Contents

  • Features
  • Installation
  • Quick Start
  • Project Structure
  • Custom Reward Functions
  • Documentation
  • Contributing
  • Contact
  • License
  • Acknowledgments

Features

RL4LMS comes packed with powerful features designed to streamline the process of fine-tuning language models:

  • 🔄 Flexible Reward Function API: Intuitive interface for defining custom reward functions tailored to your specific task
  • 🤗 HuggingFace Integration: Seamless compatibility with all HuggingFace Transformers models
  • ⚡ Efficient Training: Optimized for both single- and multi-GPU training with minimal setup
  • 🧩 Extensible Architecture: Modular design that makes it easy to add new components and environments
  • 📊 Built-in Evaluation: Comprehensive tools for monitoring and evaluating model performance
  • 🎮 Wordle Environment: Built-in Wordle game environment for RL training and experimentation (see the example after this list)
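For instance, the built-in Wordle reward function can be tried on its own before any training. The sketch below assumes WordleRewardFunction takes no required constructor arguments (as in the Quick Start example) and follows the RewardFunction interface documented under Custom Reward Functions; the prompt and feedback formats shown are illustrative only:

from rl4lms.reward_functions.wordle import WordleRewardFunction

# Score a small batch of Wordle guesses against their prompts (illustrative data)
reward_fn = WordleRewardFunction()
prompts = ["Guess the 5-letter word. Feedback so far: C_A_E",
           "Guess the 5-letter word. Feedback so far: _R__N"]
guesses = ["crane", "brine"]

# One reward per generated text, per the RewardFunction interface
rewards = reward_fn(prompt_texts=prompts, generated_texts=guesses)
print(rewards)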

Installation

RL4LMS can be installed with just a few simple steps:

  1. Clone the repository

    git clone https://github.com/YanCotta/reinforcement-fine-tuning-llms-with-grpo.git
    cd reinforcement-fine-tuning-llms-with-grpo
  2. Set up a virtual environment (recommended):

    # Create and activate virtual environment
    python -m venv venv
    # On Windows:
    .\venv\Scripts\activate
    # On macOS/Linux:
    source venv/bin/activate
  3. Install the package in development mode

    pip install -e .
  4. Install additional dependencies

    pip install -r requirements.txt

Optional: Install with development dependencies

For contributing to the project or running tests:

pip install -e ".[dev]"
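To confirm the installation, check that the trainer imports cleanly (this assumes the package layout shown under Project Structure):

python -c "from rl4lms.trainer import GRPOTrainer; print('RL4LMS is installed')"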

Quick Start

Fine-tuning on Wordle

RL4LMS includes a ready-to-use implementation for fine-tuning language models on the Wordle game. Here's how to get started:

  1. Prepare your environment as described in the Installation section

  2. Run the example script:

    python examples/wordle_finetuning.py

Basic Usage Example

Here's a minimal example showing how to use RL4LMS to fine-tune a model:

from transformers import AutoModelForCausalLM, AutoTokenizer

from rl4lms.trainer import GRPOTrainer
from rl4lms.reward_functions.wordle import WordleRewardFunction

# Initialize the policy model, a frozen reference copy, and the tokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
reward_fn = WordleRewardFunction()

# train_dataset / eval_dataset are your prepared prompt datasets
# (see examples/wordle_finetuning.py for a complete setup)

# Create trainer and start training
trainer = GRPOTrainer(
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    reward_fn=reward_fn,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    batch_size=8,
    num_epochs=3,
    learning_rate=1e-5,
    output_dir="./wordle_grpo_output",
)

trainer.train()
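Assuming the trainer updates the passed-in model in place, the fine-tuned policy is still a standard HuggingFace causal LM, so it can be saved and queried with the usual Transformers API (the output path below is illustrative):

# Save the fine-tuned policy and tokenizer
model.save_pretrained("./wordle_grpo_output/final_model")
tokenizer.save_pretrained("./wordle_grpo_output/final_model")

# Generate a guess with the fine-tuned model
prompt = "Guess the 5-letter word:"
inputs = tokenizer(prompt, return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=10, do_sample=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))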

Project Structure

rl4lms/
├── envs/                  # Environment implementations
│   ├── __init__.py
│   └── wordle_env.py      # Wordle game environment
├── losses/
│   ├── __init__.py
│   └── grpo_loss.py       # GRPO loss implementation
├── models/                # Model architectures
│   └── __init__.py
├── reward_functions/      # Reward function implementations
│   ├── __init__.py
│   ├── base.py            # Base reward function class
│   └── wordle.py          # Wordle-specific reward functions
├── trainer/
│   ├── __init__.py
│   └── grpo_trainer.py    # Training loop implementation
└── utils/                 # Utility functions
    └── __init__.py

examples/                  # Example scripts
└── wordle_finetuning.py   # Wordle fine-tuning example

tests/                     # Unit tests
└── test_reward_functions.py

Custom Reward Functions

To create a custom reward function, inherit from the RewardFunction base class and implement the __call__ method:

from rl4lms.reward_functions import RewardFunction
import torch

class MyRewardFunction(RewardFunction):
    def __init__(self, **kwargs):
        super().__init__()
        # Initialize any parameters
        
    def __call__(self, prompt_texts, generated_texts, **kwargs):
        """
        Calculate rewards for generated text.
        
        Args:
            prompt_texts: List of input prompts
            generated_texts: List of generated texts to score
            **kwargs: Additional metadata
            
        Returns:
            torch.Tensor: Tensor of rewards for each generated text
        """
        # Calculate rewards here
        rewards = torch.ones(len(generated_texts))  # Example: return 1 for each text
        return rewards
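A reward function written this way can be exercised directly before plugging it into the trainer; the call below only uses the interface defined above:

# Quick check of the custom reward function on a toy batch
reward_fn = MyRewardFunction()
rewards = reward_fn(
    prompt_texts=["Summarize: The cat sat on the mat."],
    generated_texts=["A cat sat on a mat."],
)
print(rewards)  # tensor([1.]) with the placeholder implementation above

# The same object can then be passed to GRPOTrainer via the reward_fn argument,
# just like WordleRewardFunction in the Quick Start example.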

Documentation

For detailed documentation, including API references, advanced usage examples, and tutorials, please visit our documentation site.

Contributing

We welcome contributions from the community! Whether you're fixing bugs, adding new features, or improving documentation, your help is greatly appreciated.

How to Contribute

  1. Fork the repository on GitHub
  2. Clone your fork locally
  3. Create a new branch for your changes
  4. Commit your changes with clear, descriptive messages
  5. Push your changes to your fork
  6. Open a Pull Request with a clear description of your changes

Development Setup

  1. Install development dependencies:

    pip install -e ".[dev]"
  2. Run tests:

    pytest tests/
  3. Format your code:

    black .
    isort .
  4. Check for code style issues:

    flake8 rl4lms tests
    mypy rl4lms

Contact

For questions, suggestions, or support, please reach out:

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments

  • This project was inspired by the course "Reinforcement Fine-Tuning LLMs With GRPO".
  • Built with ❤️ using PyTorch and HuggingFace Transformers.
