Commit 8151cf4

[Test] Fix failing test (#3033)
[Feature] Add thinking prompts to GRPO
1 parent c1322a2 commit 8151cf4

52 files changed: +11037 -2668 lines changed

README.md

Lines changed: 97 additions & 0 deletions
@@ -23,6 +23,57 @@

**TorchRL** is an open-source Reinforcement Learning (RL) library for PyTorch.

## 🚀 What's New

### LLM API - Complete Framework for Language Model Fine-tuning

TorchRL now includes a comprehensive **LLM API** for post-training and fine-tuning of language models! This new framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:

- 🤖 **Unified LLM Wrappers**: Seamless integration with Hugging Face models and vLLM inference engines
- 💬 **Conversation Management**: Advanced `History` class for multi-turn dialogue with automatic chat template detection
- 🛠️ **Tool Integration**: Built-in support for Python code execution, function calling, and custom tool transforms
- 🎯 **Specialized Objectives**: GRPO (Group Relative Policy Optimization) and SFT loss functions optimized for language models
- ⚡ **High-Performance Collectors**: Async data collection with distributed training support
- 🔄 **Flexible Environments**: Transform-based architecture for reward computation, data loading, and conversation augmentation

The LLM API follows TorchRL's modular design principles, allowing you to mix and match components for your specific use case. Check out the [complete documentation](https://pytorch.org/rl/main/reference/llms.html) and [GRPO implementation example](https://github.com/pytorch/rl/tree/main/sota-implementations/grpo) to get started!
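
To make the conversation-management bullet more concrete, the following is a minimal sketch of what rendering a multi-turn history through a chat template involves. It does not use TorchRL's `History` class: `Message` and `render_chatml` are hypothetical names introduced for illustration, and the ChatML-style markers are only one common template layout, assumed here.

```python
from dataclasses import dataclass


@dataclass
class Message:
    # Hypothetical container for one turn; not TorchRL's History class
    role: str
    content: str


def render_chatml(messages, add_generation_prompt=True):
    """Render messages in a ChatML-style layout (illustrative only)."""
    parts = [f"<|im_start|>{m.role}\n{m.content}<|im_end|>\n" for m in messages]
    if add_generation_prompt:
        # Leave the assistant turn open so the model continues from here
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


history = [
    Message("system", "You are an assistant that can execute Python code."),
    Message("user", "Calculate 2+2"),
]
prompt = render_chatml(history)
print(prompt)
```

A real tokenizer usually ships its own template, which is why detecting it automatically is useful: the same list of turns can produce different prompt strings for different model families.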

<details>
  <summary>Quick LLM API Example</summary>

```python
# Assumes `tokenizer`, `model`, `critic` and `optimizer` are already defined.
from torchrl.envs.llm import ChatEnv
from torchrl.envs.llm.transforms import PythonInterpreter
from torchrl.modules.llm import TransformersWrapper
from torchrl.objectives.llm import GRPOLoss
from torchrl.collectors.llm import LLMCollector

# Create environment with Python tool execution
env = ChatEnv(
    tokenizer=tokenizer,
    system_prompt="You are an assistant that can execute Python code.",
    batch_size=[1]
).append_transform(PythonInterpreter())

# Wrap your language model
llm = TransformersWrapper(
    model=model,
    tokenizer=tokenizer,
    input_mode="history"
)

# Set up GRPO training
loss_fn = GRPOLoss(llm, critic, gamma=0.99)
collector = LLMCollector(env, llm, frames_per_batch=100)

# Training loop
for data in collector:
    optimizer.zero_grad()  # clear gradients left over from the previous batch
    loss = loss_fn(data)
    loss.backward()
    optimizer.step()
```

</details>
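
The "group relative" part of GRPO can be illustrated without any library code. This is a minimal sketch, assuming the common formulation in which several completions are sampled for the same prompt and each reward is standardized against its group's mean and standard deviation. `group_relative_advantages` is a hypothetical helper, not part of `GRPOLoss`, whose internals may differ.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-8):
    """Standardize rewards within one group of completions for the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population standard deviation of the group
    # eps avoids division by zero when every completion gets the same reward
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four sampled completions for one prompt, scored by some reward function
rewards = [1.0, 0.0, 0.0, 1.0]
advantages = group_relative_advantages(rewards)
print(advantages)
```

Completions that score above their group's average receive a positive advantage and those below receive a negative one; when all rewards in a group are equal, every advantage is zero.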

## Key features

- 🐍 **Python-first**: Designed with Python as the primary language for ease of use and flexibility
@@ -516,6 +567,39 @@ And it is `functorch` and `torch.compile` compatible!
- various [recipes](https://github.com/pytorch/rl/blob/main/torchrl/trainers/helpers/models.py) to build models that
  correspond to the environment being deployed.

- **LLM API**: Complete framework for language model fine-tuning with unified wrappers for Hugging Face and vLLM backends,
  conversation management with automatic chat template detection, tool integration (Python execution, function calling),
  specialized objectives (GRPO, SFT), and high-performance async collectors. Perfect for RLHF, supervised fine-tuning,
  and tool-augmented training scenarios.
  <details>
    <summary>Code</summary>

  ```python
  # Assumes `tokenizer` and `model` are already defined.
  from tensordict import TensorDict
  from torchrl.envs.llm import ChatEnv
  from torchrl.modules.llm import TransformersWrapper
  from torchrl.envs.llm.transforms import PythonInterpreter

  # Create environment with tool execution
  env = ChatEnv(
      tokenizer=tokenizer,
      system_prompt="You can execute Python code.",
      batch_size=[1]
  ).append_transform(PythonInterpreter())

  # Wrap language model for training
  llm = TransformersWrapper(
      model=model,
      tokenizer=tokenizer,
      input_mode="history"
  )

  # Multi-turn conversation with tool use
  obs = env.reset(TensorDict({"query": "Calculate 2+2"}, batch_size=[1]))
  llm_output = llm(obs)  # Generates response
  obs = env.step(llm_output)  # Environment processes response
  ```
  </details>
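
As a rough illustration of what a Python tool transform does conceptually, the sketch below extracts fenced Python blocks from a model response and captures what they print. It is not the `PythonInterpreter` implementation: `run_python_blocks`, `FENCE` and `CODE_BLOCK` are hypothetical names, and the plain `exec` call is unsandboxed, which a real tool runner would need to avoid.

```python
import contextlib
import io
import re

FENCE = "`" * 3  # built programmatically so this example can sit inside a fence
# Matches fenced Python blocks in a model response
CODE_BLOCK = re.compile(FENCE + r"python\n(.*?)" + FENCE, re.DOTALL)


def run_python_blocks(response):
    """Execute each fenced Python block and return its captured stdout."""
    outputs = []
    for code in CODE_BLOCK.findall(response):
        buffer = io.StringIO()
        with contextlib.redirect_stdout(buffer):
            exec(code, {})  # unsandboxed: for illustration only
        outputs.append(buffer.getvalue())
    return outputs


response = f"Let me compute that.\n{FENCE}python\nprint(2 + 2)\n{FENCE}"
tool_outputs = run_python_blocks(response)
print(tool_outputs)
```

In a multi-turn setup, the captured output would be appended to the conversation as a new turn so the model can read the result before answering.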

If you feel a feature is missing from the library, please submit an issue!
If you would like to contribute to new features, check our [call for contributions](https://github.com/pytorch/rl/issues/509) and our [contribution](https://github.com/pytorch/rl/blob/main/CONTRIBUTING.md) page.

@@ -792,6 +876,18 @@ A series of [State-of-the-Art implementations](https://github.com/pytorch/rl/blo
    <td> NA
    </td>
  </tr>
  <tr>
    <td><a href="https://github.com/pytorch/rl/blob/main/sota-implementations/grpo">LLM API (GRPO)</a>
    </td>
    <td> NA
    </td>
    <td> +
    </td>
    <td> +
    </td>
    <td> NA
    </td>
  </tr>
</table>

** The number indicates expected speed-up compared to eager mode when executed on CPU. Numbers may vary depending on
@@ -800,6 +896,7 @@ A series of [State-of-the-Art implementations](https://github.com/pytorch/rl/blo
and many more to come!

[Code examples](examples/) displaying toy code snippets and training scripts are also available
- [LLM API & GRPO](sota-implementations/grpo) - Complete language model fine-tuning pipeline
- [RLHF](examples/rlhf)
- [Memory-mapped replay buffers](examples/torchrl_features)

docs/source/_static/img/llm-data.svg

Lines changed: 5 additions & 0 deletions

docs/source/_static/img/llm-env.png

577 KB