This repository contains the code for training a document-summarizing agent using reinforcement learning. The agent is based on Qwen2.5-3B-Instruct and learns from its own experience through the ART reinforcement learning framework. While training runs are non-deterministic, the agent usually achieves SOTA performance within a few hours of training time on an H100 GPU.
The goal of your agent is to summarize a document in 350 words or fewer while maximizing the number of questions that the summary can answer. We'll use documents from the Repliqa dataset and pass each summary generated by the agent to Gemini 2.5 Flash, along with 5 questions that can be answered from the original document. Importantly, the agent does not see these questions when it writes the summary.
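To make the scoring concrete, here is a minimal sketch of the two quantities involved: whether a summary respects the word limit, and what fraction of the judge's questions it let Gemini 2.5 Flash answer. The function names and the whitespace-based word count are assumptions for illustration. This README does not say how the repository counts words or how over-length summaries are scored, and the call to the judge model is not shown.

```python
MAX_SUMMARY_WORDS = 350  # word limit stated in this README


def word_count(text: str) -> int:
    """Count whitespace-separated tokens (an assumed definition of a 'word')."""
    return len(text.split())


def within_word_limit(summary: str, limit: int = MAX_SUMMARY_WORDS) -> bool:
    """Return True if the summary respects the word limit."""
    return word_count(summary) <= limit


def fraction_answered(correct_flags: list) -> float:
    """Share of the judge's questions (5 per document here) answered correctly."""
    if not correct_flags:
        raise ValueError("need at least one question")
    return sum(correct_flags) / len(correct_flags)
```

For example, `fraction_answered([True, False, True, True, False])` returns `0.6`, meaning the summary allowed the judge to answer 3 of the 5 questions.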
There are 3591 documents in the Repliqa dataset. By default, your agent will train on 3500 documents and validate on the remaining 91. You can change these numbers by setting the `TRAIN_SIZE` and `VAL_SIZE` environment variables.
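A minimal sketch of how these two variables could be read, with the defaults above as fallbacks. The helper names are hypothetical, and both the check that the two sizes fit inside the 3591-document pool and the simple slicing (no shuffling) are assumptions rather than a description of the repository's code.

```python
import os
from typing import List, Tuple

TOTAL_DOCUMENTS = 3591  # size of the Repliqa pool described above


def split_sizes() -> Tuple[int, int]:
    """Read TRAIN_SIZE / VAL_SIZE, falling back to the README defaults."""
    train_size = int(os.environ.get("TRAIN_SIZE", "3500"))
    val_size = int(os.environ.get("VAL_SIZE", "91"))
    if train_size < 0 or val_size < 0:
        raise ValueError("sizes must be non-negative")
    # Assumed constraint: the two splits must not overlap.
    if train_size + val_size > TOTAL_DOCUMENTS:
        raise ValueError(f"TRAIN_SIZE + VAL_SIZE exceeds {TOTAL_DOCUMENTS} documents")
    return train_size, val_size


def split_documents(docs: List, train_size: int, val_size: int) -> Tuple[List, List]:
    """Take disjoint, contiguous train and validation slices."""
    return docs[:train_size], docs[train_size:train_size + val_size]
```

For example, running with `TRAIN_SIZE=100 VAL_SIZE=10` would make `split_sizes()` return `(100, 10)`.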
We've evaluated several SOTA models on this task. The scores below represent the average percentage of questions that each model's summary allowed Gemini 2.5 Flash to answer correctly.
- GPT-4o: 35%
- GPT-4.1: 46%
- o4-mini: 60%
- gemini-2.5-pro-preview: 40%
Your goal is to train a model that outperforms all of them.
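The scores above are averages over documents. Here is a minimal sketch of that aggregation, assuming the judge's verdicts are available as one list of correct/incorrect flags per document (the input shape and function name are hypothetical). With a fixed 5 questions per document, averaging per-document percentages gives the same number as pooling all questions together.

```python
from statistics import mean
from typing import List


def average_percentage(per_document_correct: List[List[bool]]) -> float:
    """Mean, over documents, of the percentage of questions answered correctly."""
    if not per_document_correct or any(len(flags) == 0 for flags in per_document_correct):
        raise ValueError("every document needs at least one question")
    return mean(100 * sum(flags) / len(flags) for flags in per_document_correct)
```

For example, one document with all 5 questions answered and one with none gives `50.0`.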
If you haven't already, install uv by following the official uv installation instructions. Then install the project dependencies by running `uv sync`.
We'll be using `SkyPilotBackend` to manage the GPU that your model will be trained on. For the backend to work, you'll need to have SkyPilot installed on your machine and give it the credentials to spin up machines on at least one infrastructure provider.
We recommend RunPod for its ease of use, but any infrastructure provider that SkyPilot supports will work.
Follow RunPod's Getting Started guide. You'll have to provide a credit card to use RunPod, but you'll only pay for the time your GPUs are running.
In a new `.env` file at the root of the repository, set the following optional environment variables:
- `WANDB_API_KEY`: Enables metric logging to Weights & Biases.
- `OPENPIPE_API_KEY`: Enables chat completion logging to OpenPipe.
- `OPENAI_API_KEY`: Required for the comparison benchmarks later on, but not used for training.
To enable model and logging backups to S3, you'll also need to provide AWS credentials. These are necessary for generating the benchmarks in the `benchmarks` directory, but not for training itself. If you don't already have AWS credentials with create/read/write permissions for S3 buckets, follow AWS's instructions for creating an access key.
- `AWS_ACCESS_KEY_ID`: Your AWS access key ID, which should have create/read/write permissions for S3 buckets.
- `AWS_SECRET_ACCESS_KEY`: The matching secret access key.
- `AWS_REGION`: The region of the S3 bucket.
- `BACKUP_BUCKET`: The name of the S3 bucket in which to store model checkpoints and logging data. This can be a new bucket or an existing one.
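Before launching a run, it can help to check which of these optional integrations your `.env` file actually enables. The following is a hypothetical preflight sketch using only the standard library: the parser handles plain `KEY=value` lines only, the feature names are made up, and treating S3 backup as usable only when all four AWS variables are set is an assumption based on the list above, not how the training script validates them.

```python
import os
from pathlib import Path
from typing import Dict, Mapping

AWS_BACKUP_VARS = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_REGION", "BACKUP_BUCKET")


def parse_env_text(text: str) -> Dict[str, str]:
    """Parse simple KEY=value lines, skipping blank lines and # comments."""
    values: Dict[str, str] = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        value = value.strip()
        # Drop one pair of matching surrounding quotes, if present.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        values[key.strip()] = value
    return values


def enabled_features(env: Mapping[str, str]) -> Dict[str, bool]:
    """Report which optional integrations from this README have their variables set."""
    def present(name: str) -> bool:
        return bool(env.get(name))

    return {
        "wandb_logging": present("WANDB_API_KEY"),
        "openpipe_logging": present("OPENPIPE_API_KEY"),
        "openai_benchmarks": present("OPENAI_API_KEY"),
        # Assumed: S3 backup needs all four AWS-related variables together.
        "s3_backup": all(present(name) for name in AWS_BACKUP_VARS),
    }


def features_from_dotenv(path: str = ".env") -> Dict[str, bool]:
    """Combine the .env file (if any) with the process environment."""
    env_path = Path(path)
    file_values = parse_env_text(env_path.read_text()) if env_path.exists() else {}
    # Assumed precedence: values already in the process environment win.
    return enabled_features({**file_values, **os.environ})
```

Calling `features_from_dotenv()` from the repository root would then show, for example, whether `s3_backup` is `False` because one of the four AWS variables is missing.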
To start training, run:
```bash
uv run python src/summarizer/train.py
```
The following steps run when a training run begins on a new cluster:
- Spin up a cluster with 1 H100 GPU.
  - This usually takes about 10 minutes, but RunPod occasionally has network throughput issues that can stretch provisioning to as long as 30 minutes. Once the cluster is provisioned, it can be reused for subsequent training runs without repeating this process.
- Register the model with ART.
  - This usually takes less than 5 minutes, though it can take up to 30 minutes if RunPod experiences network issues.
- Download the model checkpoint from S3.
  - This usually takes a few seconds.
- Train the model for a specified number of steps.
  - Each step takes less than a minute, so total training time depends mainly on how many steps you run. The model checkpoint is saved to S3 after each step.
- Upload the final model checkpoint to S3.
  - This usually takes a few seconds.
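Using the timings above, you can roughly budget a run on a new cluster. This hypothetical helper simply adds up the quoted figures: about 10 minutes (up to 30) for provisioning, under 5 minutes (up to 30) for registration, and under a minute per training step. Checkpoint transfers take seconds and are ignored. Because provisioning is quoted as "about" 10 minutes, the usual figure is an approximation rather than a guaranteed bound.

```python
from typing import Tuple

PROVISION_MINUTES = (10, 30)  # (usual, worst case) cluster spin-up
REGISTER_MINUTES = (5, 30)    # (usual, worst case) model registration with ART
MINUTES_PER_STEP = 1          # each training step takes less than a minute


def estimate_new_cluster_minutes(num_steps: int) -> Tuple[int, int]:
    """Return rough (usual, worst-case) wall-clock minutes for a run on a new cluster."""
    if num_steps < 0:
        raise ValueError("num_steps must be non-negative")
    training = num_steps * MINUTES_PER_STEP
    usual = PROVISION_MINUTES[0] + REGISTER_MINUTES[0] + training
    worst = PROVISION_MINUTES[1] + REGISTER_MINUTES[1] + training
    return usual, worst
```

For a 20-step run, `estimate_new_cluster_minutes(20)` returns `(35, 80)`: roughly half an hour normally, and under an hour and a half when RunPod's network is slow.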
When you're done training and running benchmarks, you can shut down the cluster in two ways:
Through the CLI:
```bash
uv run sky down <cluster-name>
```
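or through code, at the end of the async training function:
```python
DESTROY_AFTER_RUN = True

# `backend` is the SkyPilotBackend used for training; `await` must run inside an async function.
if DESTROY_AFTER_RUN:
    await backend.down()
```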
However, since spinning up clusters is a time-intensive process, we recommend keeping clusters alive until you're sure you won't be using them in the near future.
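Because `await backend.down()` has to run inside a coroutine, here is a minimal sketch of one way to structure that. `StubBackend`, `run_training`, and `example_train` are hypothetical names; only the `down()` call mirrors the snippet above, and the stub exists so the example runs without a real cluster. The `try`/`finally` makes teardown happen even if training raises.

```python
import asyncio


class StubBackend:
    """Hypothetical stand-in for the real backend; only `down()` is mirrored."""

    def __init__(self) -> None:
        self.is_down = False

    async def down(self) -> None:
        self.is_down = True


async def example_train(backend) -> None:
    """Placeholder for the real training loop."""
    await asyncio.sleep(0)


async def run_training(backend, destroy_after_run: bool, train) -> None:
    """Run `train`, then shut the cluster down if requested, even on failure."""
    try:
        await train(backend)
    finally:
        if destroy_after_run:
            await backend.down()


if __name__ == "__main__":
    asyncio.run(run_training(StubBackend(), destroy_after_run=True, train=example_train))
```

Tearing down on failure is a trade-off: given the 10 to 30 minutes it takes to provision a new cluster, you may prefer to leave `destroy_after_run` off while you are still iterating and only enable it for a final run.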
The `benchmark_models.py` script compares the performance of the trained model to `gpt-4o`, `gpt-4.1`, `o4-mini`, and `gemini-2.5-pro-preview`.
Before running the benchmark script, make sure you've provided a valid `OPENROUTER_API_KEY` as well as the AWS credentials described above. The AWS credentials are needed for the script to upload the benchmark results to S3.
```bash
uv run python benchmarks/benchmark_models.py
```
This script will:
- Run each benchmarked model through each document in the validation set.
- Record the percentage of questions that each model's summary allowed Gemini 2.5 Flash to answer correctly.
- Upload the results to S3.
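The format of the uploaded results isn't described here. As a hypothetical sketch, per-model average percentages could be serialized to JSON before upload (for example with boto3's S3 client `put_object`, which is not shown). The field names below are made up and may not match what `benchmark_models.py` writes.

```python
import json
from datetime import datetime, timezone
from typing import Dict


def build_benchmark_record(scores: Dict[str, float], num_documents: int) -> str:
    """Serialize per-model average percentages into a JSON document."""
    record = {
        "created_at": datetime.now(timezone.utc).isoformat(),
        "num_documents": num_documents,
        # Sort models from best to worst for easier reading.
        "scores": dict(sorted(scores.items(), key=lambda item: item[1], reverse=True)),
    }
    return json.dumps(record, indent=2)
```

Sorting by score keeps the stored record readable on its own, before it is loaded into the notebook.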
Once the benchmark script has finished running, you can view the results and generate charts by opening `benchmarks/display_benchmarks.ipynb` and running its cells. After running all the cells, you should see something like the following:
The percentage of questions that each model's summary allowed Gemini 2.5 Flash to answer correctly at each training step. By step 5 of this training run, the trained model outperforms every other model.
A side-by-side comparison of the percentage of questions that each model's summary allowed Gemini 2.5 Flash to answer correctly. The trained model started at 12%, and by the final step its summaries allowed Gemini 2.5 Flash to answer 24% of the questions correctly.
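The per-step chart is essentially answering "at which step does the trained model first beat every baseline?". A minimal sketch of that check, using the baseline scores quoted at the top of this README as defaults; the per-step trained scores would come from your own benchmark results, and the function name is hypothetical.

```python
from typing import Dict, Optional

BASELINE_SCORES = {  # percentages quoted earlier in this README
    "gpt-4o": 35.0,
    "gpt-4.1": 46.0,
    "o4-mini": 60.0,
    "gemini-2.5-pro-preview": 40.0,
}


def first_step_beating_all(
    trained_by_step: Dict[int, float],
    baselines: Dict[str, float] = BASELINE_SCORES,
) -> Optional[int]:
    """Return the earliest step whose score exceeds every baseline, or None."""
    best_baseline = max(baselines.values())
    for step in sorted(trained_by_step):
        if trained_by_step[step] > best_baseline:
            return step
    return None
```

Comparing against the single best baseline is enough, since beating the maximum means beating every model; a strict `>` means ties do not count as outperforming.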