# [Feature] Reconsider prompts for GRPO #3030
**Open** · wants to merge 25 commits into `main`
`README.md` (+97 −0)

@@ -23,6 +23,57 @@

**TorchRL** is an open-source Reinforcement Learning (RL) library for PyTorch.

## 🚀 What's New

### LLM API - Complete Framework for Language Model Fine-tuning

TorchRL now includes a comprehensive **LLM API** for post-training and fine-tuning of language models! This new framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:

- 🤖 **Unified LLM Wrappers**: Seamless integration with Hugging Face models and vLLM inference engines
- 💬 **Conversation Management**: Advanced `History` class for multi-turn dialogue with automatic chat template detection (see the `History` sketch after the quick example below)
- 🛠️ **Tool Integration**: Built-in support for Python code execution, function calling, and custom tool transforms
- 🎯 **Specialized Objectives**: GRPO (Group Relative Policy Optimization) and SFT loss functions optimized for language models
- ⚡ **High-Performance Collectors**: Async data collection with distributed training support
- 🔄 **Flexible Environments**: Transform-based architecture for reward computation, data loading, and conversation augmentation

The LLM API follows TorchRL's modular design principles, allowing you to mix and match components for your specific use case. Check out the [complete documentation](https://pytorch.org/rl/main/reference/llms.html) and [GRPO implementation example](https://github.com/pytorch/rl/tree/main/sota-implementations/grpo) to get started!

<details>
<summary>Quick LLM API Example</summary>

```python
from torchrl.envs.llm import ChatEnv
from torchrl.envs.llm.transforms import PythonInterpreter
from torchrl.modules.llm import TransformersWrapper
from torchrl.objectives.llm import GRPOLoss
from torchrl.collectors.llm import LLMCollector

# `tokenizer`, `model`, `critic`, and `optimizer` are assumed to be
# defined elsewhere (e.g. a Hugging Face tokenizer/model pair).

# Create environment with Python tool execution
env = ChatEnv(
    tokenizer=tokenizer,
    system_prompt="You are an assistant that can execute Python code.",
    batch_size=[1],
).append_transform(PythonInterpreter())

# Wrap your language model
llm = TransformersWrapper(
    model=model,
    tokenizer=tokenizer,
    input_mode="history",
)

# Set up GRPO training
loss_fn = GRPOLoss(llm, critic, gamma=0.99)
collector = LLMCollector(env, llm, frames_per_batch=100)

# Training loop
for data in collector:
    loss = loss_fn(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

</details>
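
The `History` class mentioned above deserves a closer look. Below is a minimal sketch of multi-turn conversation management; the exact import path and the `append`/`apply_chat_template` method names are assumptions based on the documentation linked above, so check the LLM API reference for the released signatures.

<details>
<summary>History sketch (assumed API)</summary>

```python
from transformers import AutoTokenizer

# Assumed import path and methods -- verify against the LLM API docs.
from torchrl.data.llm import History

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Build a short dialogue turn by turn
chat = History(role="system", content="You are a helpful assistant.")
chat = chat.append(History(role="user", content="What is 2 + 2?"))

# Render the conversation with the tokenizer's own chat template,
# leaving the assistant turn open for generation
prompt = chat.apply_chat_template(tokenizer=tokenizer, add_generation_prompt=True)
print(prompt)
```

</details>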

## Key features

- 🐍 **Python-first**: Designed with Python as the primary language for ease of use and flexibility
@@ -516,6 +567,39 @@ And it is `functorch` and `torch.compile` compatible!
- various [recipes](https://github.com/pytorch/rl/blob/main/torchrl/trainers/helpers/models.py) to build models that
correspond to the environment being deployed.

- **LLM API**: Complete framework for language model fine-tuning with unified wrappers for Hugging Face and vLLM backends,
conversation management with automatic chat template detection, tool integration (Python execution, function calling),
specialized objectives (GRPO, SFT), and high-performance async collectors. Perfect for RLHF, supervised fine-tuning,
and tool-augmented training scenarios.
<details>
<summary>Code</summary>

```python
from tensordict import TensorDict
from torchrl.envs.llm import ChatEnv
from torchrl.modules.llm import TransformersWrapper
from torchrl.envs.llm.transforms import PythonInterpreter

# `tokenizer` and `model` are assumed to be defined elsewhere.

# Create environment with tool execution
env = ChatEnv(
    tokenizer=tokenizer,
    system_prompt="You can execute Python code.",
    batch_size=[1],
).append_transform(PythonInterpreter())

# Wrap language model for training
llm = TransformersWrapper(
    model=model,
    tokenizer=tokenizer,
    input_mode="history",
)

# Multi-turn conversation with tool use
obs = env.reset(TensorDict({"query": "Calculate 2+2"}, batch_size=[1]))
llm_output = llm(obs)       # generate a response
obs = env.step(llm_output)  # environment processes the response
```
</details>
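
  The objectives mentioned above follow TorchRL's usual loss-module convention: the loss returns a tensordict of named loss terms that you reduce and backpropagate yourself. Here is a hedged sketch of the supervised fine-tuning objective, reusing `llm` from the snippet above; `SFTLoss` and its constructor arguments are assumptions based on the objectives listed in this README, so verify them against the LLM API reference.

  ```python
  # Hypothetical sketch -- `SFTLoss` and its arguments are assumptions.
  from torchrl.objectives.llm import SFTLoss

  loss_fn = SFTLoss(llm)   # wraps the TransformersWrapper policy
  loss_td = loss_fn(data)  # `data` is a batch collected from the environment
  # TorchRL loss modules return a TensorDict of loss terms; sum the
  # entries whose keys start with "loss" before backpropagating.
  loss = sum(v for k, v in loss_td.items() if k.startswith("loss"))
  loss.backward()
  ```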

If you feel a feature is missing from the library, please submit an issue!
If you would like to contribute to new features, check our [call for contributions](https://github.com/pytorch/rl/issues/509) and our [contribution](https://github.com/pytorch/rl/blob/main/CONTRIBUTING.md) page.

@@ -792,6 +876,18 @@ A series of [State-of-the-Art implementations](https://github.com/pytorch/rl/blo
<td> NA
</td>
</tr>
<tr>
<td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/grpo">LLM API (GRPO)</a>
</td>
<td> NA
</td>
<td> +
</td>
<td> +
</td>
<td> NA
</td>
</tr>
</table>

** The number indicates expected speed-up compared to eager mode when executed on CPU. Numbers may vary depending on
@@ -800,6 +896,7 @@ A series of [State-of-the-Art implementations](https://github.com/pytorch/rl/blo
and many more to come!

[Code examples](examples/) displaying toy code snippets and training scripts are also available
- [LLM API & GRPO](sota-implementations/grpo) - Complete language model fine-tuning pipeline
- [RLHF](examples/rlhf)
- [Memory-mapped replay buffers](examples/torchrl_features)

`docs/source/_static/img/llm-data.svg` (+5 −0)

`docs/source/_static/img/llm-env.png` (binary file added)