[Feature] Multi-node Ray support for GRPO sota-implementation #3040

Open · wants to merge 15 commits into `main`
14 changes: 14 additions & 0 deletions sota-implementations/grpo/README.md
@@ -134,6 +134,20 @@ The async mode offers better performance by:
- Better throughput
- More flexible buffer management

### Running GRPO on More Than One Node with SLURM

GRPO can be run across more than one node using SLURM, enabling distributed training for moderately scaled workloads.

Two scripts are provided for launching multi-node runs:

- `grpo-sync-multi-node.sbatch`: SLURM job script that launches sync GRPO across multiple nodes using Ray.
- `grpo-async-multi-node.sbatch`: SLURM job script that launches async GRPO across multiple nodes using Ray.

Example Usage:

```bash
sbatch sota-implementations/grpo/grpo-sync-multi-node.sbatch
```
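
The async variant is launched the same way. The snippet below is a minimal sketch of submitting it and following the job output; it assumes the default `logs/%x.job%j.out` output pattern from the sbatch scripts and a hypothetical SLURM job ID of `12345`:

```bash
# Submit the async multi-node job (node count and device split come from the sbatch script)
sbatch sota-implementations/grpo/grpo-async-multi-node.sbatch

# Follow the job output; replace 12345 with the job ID printed by sbatch
tail -f logs/grpo-async-multi-node.job12345.out
```

Both sbatch scripts export `RAY_CLUSTER_MANAGED_EXTERNALLY=1`, so the training script joins the SLURM-managed Ray cluster (via `ray.init(address="auto")`) instead of starting a local one. Adjust `#SBATCH --nodes` and the `num_devices` overrides in the `CMD` line to match your cluster.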

### KL Divergences in PPO: Reference vs Inference

KL divergence is a key regularization term in policy optimization algorithms like PPO and in LLM post-training. It measures how much the updated policy diverges from a baseline or reference policy, helping to prevent the new policy from drifting too far and ensuring stable learning.
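
For reference, the penalized quantity can be written in its standard form (this is the textbook definition, not code from this repository): for a prompt $x$ and responses $y$ sampled from the updated policy $\pi_\theta$, the KL divergence from the reference policy $\pi_{\mathrm{ref}}$ is

$$
D_{\mathrm{KL}}\left(\pi_\theta \,\|\, \pi_{\mathrm{ref}}\right)
= \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\left[\log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}\right].
$$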
25 changes: 25 additions & 0 deletions sota-implementations/grpo/grpo-async-multi-node.sbatch
@@ -0,0 +1,25 @@
#!/bin/bash
#SBATCH --job-name=grpo-async-multi-node
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=96
#SBATCH --exclusive
#SBATCH --output=logs/%x.job%j.out
#SBATCH --time=24:00:00

# Exit on any error
set -euo pipefail

# Ensure logs directory exists
mkdir -p logs

# Environment variables
export LIST_TO_STACK=1
export VLLM_USE_V1=0
export RAY_CLUSTER_MANAGED_EXTERNALLY=1

# Run command in Ray cluster
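# Device split below: 8 (train) + 4 (ref) + 4 (inference) = 16 GPUs across the 2 requested nodes (assumes 8 GPUs per node; adjust to your hardware)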
CMD="python grpo-async.py mode=async train_model.num_devices=8 ref_model.num_devices=4 inference_model.num_devices=4"
srun bash run_in_ray_cluster.sh "$CMD"

echo "Job completed"
11 changes: 9 additions & 2 deletions sota-implementations/grpo/grpo-async.py
@@ -159,8 +159,11 @@ def train(
     model_metadata = vLLMUpdater.get_model_metadata(policy_training)
 
     # Create weight updater with remote LLM
+    ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
     weight_updater: vLLMUpdater = make_weight_updater(
-        master_address="localhost",  # Since we're running locally
+        master_address="localhost"
+        if not ray_managed_externally
+        else ray.util.get_node_ip_address(),
         master_port=None,  # Will auto-assign an open port
         model_metadata=model_metadata,
         vllm_tp_size=cfg.inference_model.num_devices
@@ -422,7 +425,11 @@ def main(cfg):
             ray_init_config["runtime_env"]["env_vars"]
         )
     torchrl_logger.info(f"Ray init config: {ray_init_config=}")
-    ray.init(**ray_init_config)
+    ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
+    if ray_managed_externally:
+        ray.init(address="auto")
+    else:
+        ray.init(**ray_init_config)
 
     # Check if num_devices is set
     if cfg.inference_model.num_devices is None:
25 changes: 25 additions & 0 deletions sota-implementations/grpo/grpo-sync-multi-node.sbatch
@@ -0,0 +1,25 @@
#!/bin/bash
#SBATCH --job-name=grpo-sync-multi-node
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=96
#SBATCH --exclusive
#SBATCH --output=logs/%x.job%j.out
#SBATCH --time=24:00:00

# Exit on any error
set -euo pipefail

# Ensure logs directory exists
mkdir -p logs

# Environment variables
export LIST_TO_STACK=1
export VLLM_USE_V1=0
export RAY_CLUSTER_MANAGED_EXTERNALLY=1

# Run command in Ray cluster
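# Device split below: 8 (train) + 4 (ref) + 4 (inference) = 16 GPUs across the 2 requested nodes (assumes 8 GPUs per node; adjust to your hardware)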
CMD="python grpo-sync.py mode=sync train_model.num_devices=8 ref_model.num_devices=4 inference_model.num_devices=4"
srun bash run_in_ray_cluster.sh "$CMD"

echo "Job completed"
11 changes: 9 additions & 2 deletions sota-implementations/grpo/grpo-sync.py
@@ -160,8 +160,11 @@ def train(
     model_metadata = vLLMUpdater.get_model_metadata(policy_training)
 
     # Create weight updater with remote LLM
+    ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
     weight_updater: vLLMUpdater = make_weight_updater(
-        master_address="localhost",  # Since we're running locally
+        master_address="localhost"
+        if not ray_managed_externally
+        else ray.util.get_node_ip_address(),
         master_port=None,  # Will auto-assign an open port
         model_metadata=model_metadata,
         vllm_tp_size=cfg.inference_model.num_devices
@@ -436,7 +439,11 @@ def main(cfg):
             ray_init_config["runtime_env"]["env_vars"]
         )
     torchrl_logger.info(f"Ray init config: {ray_init_config=}")
-    ray.init(**ray_init_config)
+    ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
+    if ray_managed_externally:
+        ray.init(address="auto")
+    else:
+        ray.init(**ray_init_config)
 
     # Check if num_devices is set
     if cfg.inference_model.num_devices is None:
59 changes: 59 additions & 0 deletions sota-implementations/grpo/run_in_ray_cluster.sh
@@ -0,0 +1,59 @@
#!/bin/bash

set -euo pipefail

# Get command from argument
CMD="$1"

# Set up Ray cluster configuration
HEAD_NODE=$(scontrol show hostname "$SLURM_NODELIST" | head -n 1)
RAY_PORT=6379

# Get current node name
CURRENT_NODE=$(hostname | cut -d. -f1)

# Get HEAD_NODE_IP
if [ "$SLURM_NODEID" -eq 0 ]; then
# We're on the head node, get our own IP
HEAD_NODE_IP=$(hostname -I | awk '{print $1}')
else
# We're on a worker, resolve the head node's IP using DNS
HEAD_NODE_IP=$(getent hosts "$HEAD_NODE" | awk '{print $1}')
fi

# Set up cleanup function
cleanup() {
    if command -v ray &>/dev/null; then
        echo "Stopping Ray on node $CURRENT_NODE"
        ray stop || true
    fi
}
trap cleanup EXIT

# Start Ray based on node role
if [ "$SLURM_NODEID" -eq 0 ]; then
echo "Starting Ray head node on $CURRENT_NODE"
ray start --head --disable-usage-stats --port=$RAY_PORT
echo "Ray head node started at $HEAD_NODE_IP:$RAY_PORT"
else
echo "Waiting for head node to be ready..."
sleep 10
echo "Starting Ray worker on node $CURRENT_NODE (ID: $SLURM_NODEID)"
ray start --disable-usage-stats --address="$HEAD_NODE_IP:$RAY_PORT"
fi

# Ensure Ray cluster is ready
sleep 2

# Only head node runs the training command
if [ "$SLURM_NODEID" -eq 0 ]; then
echo "Starting training process on head node $CURRENT_NODE"
bash -c "$CMD"
else
# Worker nodes just wait for the head to finish
while ray status --address="$HEAD_NODE_IP:$RAY_PORT" &>/dev/null; do
sleep 10
done
fi

echo "Node $CURRENT_NODE: Done"