diff --git a/sota-implementations/grpo/README.md b/sota-implementations/grpo/README.md
index 58f2a4fe633..e526338f16e 100644
--- a/sota-implementations/grpo/README.md
+++ b/sota-implementations/grpo/README.md
@@ -134,6 +134,21 @@ The async mode offers better performance by:
 - Better throughput
 - More flexible buffer management
 
+### Running GRPO on More Than One Node with SLURM
+
+GRPO can be run across more than one node using SLURM, enabling distributed training for moderately scaled workloads.
+
+Two scripts are provided for launching multi-node runs:
+
+- `grpo-sync-multi-node.sbatch`: SLURM job script that launches sync GRPO across multiple nodes using Ray.
+- `grpo-async-multi-node.sbatch`: SLURM job script that launches async GRPO across multiple nodes using Ray.
+
+Example usage (the async script is launched the same way):
+
+```bash
+sbatch sota-implementations/grpo/grpo-sync-multi-node.sbatch
+```
+
 ### KL Divergences in PPO: Reference vs Inference
 
 KL divergence is a key regularization term in policy optimization algorithms like PPO and in LLM post-training. It measures how much the updated policy diverges from a baseline or reference policy, helping to prevent the new policy from drifting too far and ensuring stable learning.
diff --git a/sota-implementations/grpo/grpo-async-multi-node.sbatch b/sota-implementations/grpo/grpo-async-multi-node.sbatch
new file mode 100644
index 00000000000..5abb5d2b167
--- /dev/null
+++ b/sota-implementations/grpo/grpo-async-multi-node.sbatch
@@ -0,0 +1,25 @@
+#!/bin/bash
+#SBATCH --job-name=grpo-async-multi-node
+#SBATCH --nodes=2
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=96
+#SBATCH --exclusive
+#SBATCH --output=logs/%x.job%j.out
+#SBATCH --time=24:00:00
+
+# Exit on any error
+set -euo pipefail
+
+# Ensure logs directory exists
+mkdir -p logs
+
+# Environment variables
+export LIST_TO_STACK=1
+export VLLM_USE_V1=0
+export RAY_CLUSTER_MANAGED_EXTERNALLY=1
+
+# Run command in Ray cluster
+CMD="python grpo-async.py mode=async train_model.num_devices=8 ref_model.num_devices=4 inference_model.num_devices=4"
+srun bash run_in_ray_cluster.sh "$CMD"
+
+echo "Job completed"
diff --git a/sota-implementations/grpo/grpo-async.py b/sota-implementations/grpo/grpo-async.py
index 6ea882cd5dc..da29876786f 100644
--- a/sota-implementations/grpo/grpo-async.py
+++ b/sota-implementations/grpo/grpo-async.py
@@ -159,8 +159,11 @@ def train(
     model_metadata = vLLMUpdater.get_model_metadata(policy_training)
 
     # Create weight updater with remote LLM
+    ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
     weight_updater: vLLMUpdater = make_weight_updater(
-        master_address="localhost",  # Since we're running locally
+        master_address="localhost"
+        if not ray_managed_externally
+        else ray.util.get_node_ip_address(),
         master_port=None,  # Will auto-assign an open port
         model_metadata=model_metadata,
         vllm_tp_size=cfg.inference_model.num_devices
@@ -422,7 +425,11 @@ def main(cfg):
         ray_init_config["runtime_env"]["env_vars"]
     )
     torchrl_logger.info(f"Ray init config: {ray_init_config=}")
-    ray.init(**ray_init_config)
+    ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
+    if ray_managed_externally:
+        ray.init(address="auto")
+    else:
+        ray.init(**ray_init_config)
 
     # Check if num_devices is set
     if cfg.inference_model.num_devices is None:
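The `grpo-async.py` change above is the core of the multi-node support: when `RAY_CLUSTER_MANAGED_EXTERNALLY` is set (both sbatch files export it), the trainer attaches to the SLURM-launched Ray cluster with `ray.init(address="auto")` and advertises the node's routable IP for weight synchronization instead of `localhost`. A condensed, standalone sketch of that pattern is shown below; the helper name is illustrative, and the real script passes the address to `make_weight_updater` and otherwise falls back to the full `ray_init_config`.

```python
import os

import ray


def init_ray_and_master_address() -> str:
    """Attach to Ray and pick the weight-sync master address (illustrative helper)."""
    if os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY"):
        # A launcher such as run_in_ray_cluster.sh already started the cluster
        # under SLURM; just connect to it.
        ray.init(address="auto")
        # "localhost" is meaningless to remote workers, so advertise this
        # node's routable IP for the weight-update rendezvous instead.
        return ray.util.get_node_ip_address()
    # Single-node run: start a local Ray instance and use loopback.
    ray.init()
    return "localhost"


if __name__ == "__main__":
    print(f"master_address = {init_ray_and_master_address()}")
```

Keeping the switch behind an environment variable leaves the single-node workflow untouched: without the variable, the script initializes Ray locally exactly as before. The same pattern is applied to `grpo-sync.py` below.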
diff --git a/sota-implementations/grpo/grpo-sync-multi-node.sbatch b/sota-implementations/grpo/grpo-sync-multi-node.sbatch
new file mode 100644
index 00000000000..b3044279c42
--- /dev/null
+++ b/sota-implementations/grpo/grpo-sync-multi-node.sbatch
@@ -0,0 +1,25 @@
+#!/bin/bash
+#SBATCH --job-name=grpo-sync-multi-node
+#SBATCH --nodes=2
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=96
+#SBATCH --exclusive
+#SBATCH --output=logs/%x.job%j.out
+#SBATCH --time=24:00:00
+
+# Exit on any error
+set -euo pipefail
+
+# Ensure logs directory exists
+mkdir -p logs
+
+# Environment variables
+export LIST_TO_STACK=1
+export VLLM_USE_V1=0
+export RAY_CLUSTER_MANAGED_EXTERNALLY=1
+
+# Run command in Ray cluster
+CMD="python grpo-sync.py mode=sync train_model.num_devices=8 ref_model.num_devices=4 inference_model.num_devices=4"
+srun bash run_in_ray_cluster.sh "$CMD"
+
+echo "Job completed"
diff --git a/sota-implementations/grpo/grpo-sync.py b/sota-implementations/grpo/grpo-sync.py
index bd88bfd6be2..c100688b20e 100644
--- a/sota-implementations/grpo/grpo-sync.py
+++ b/sota-implementations/grpo/grpo-sync.py
@@ -160,8 +160,11 @@ def train(
     model_metadata = vLLMUpdater.get_model_metadata(policy_training)
 
     # Create weight updater with remote LLM
+    ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
     weight_updater: vLLMUpdater = make_weight_updater(
-        master_address="localhost",  # Since we're running locally
+        master_address="localhost"
+        if not ray_managed_externally
+        else ray.util.get_node_ip_address(),
         master_port=None,  # Will auto-assign an open port
         model_metadata=model_metadata,
         vllm_tp_size=cfg.inference_model.num_devices
@@ -436,7 +439,11 @@ def main(cfg):
         ray_init_config["runtime_env"]["env_vars"]
     )
     torchrl_logger.info(f"Ray init config: {ray_init_config=}")
-    ray.init(**ray_init_config)
+    ray_managed_externally = os.environ.get("RAY_CLUSTER_MANAGED_EXTERNALLY")
+    if ray_managed_externally:
+        ray.init(address="auto")
+    else:
+        ray.init(**ray_init_config)
 
     # Check if num_devices is set
     if cfg.inference_model.num_devices is None:
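`grpo-sync.py` receives the same two changes as `grpo-async.py`. Before committing to a long run, it can be worth confirming that both nodes actually joined the cluster and that the GPU count is at least the `8 + 4 + 4` devices the example command asks for. The check below is a minimal sketch, not part of the PR; it assumes it is executed on the head node of an already-started cluster, for example as the `$CMD` handed to `run_in_ray_cluster.sh`.

```python
import ray

# Attach to the cluster started by run_in_ray_cluster.sh on this node.
ray.init(address="auto")

alive_nodes = [node for node in ray.nodes() if node["Alive"]]
total_gpus = ray.cluster_resources().get("GPU", 0)
print(f"Alive nodes: {len(alive_nodes)}")
print(f"Total GPUs:  {total_gpus}")

# The example command requests train=8, ref=4 and inference=4 devices.
expected_gpus = 8 + 4 + 4
assert total_gpus >= expected_gpus, (
    f"Expected at least {expected_gpus} GPUs, found {total_gpus}"
)
```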
diff --git a/sota-implementations/grpo/run_in_ray_cluster.sh b/sota-implementations/grpo/run_in_ray_cluster.sh
new file mode 100644
index 00000000000..5325e26737c
--- /dev/null
+++ b/sota-implementations/grpo/run_in_ray_cluster.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+
+set -euo pipefail
+
+# Get command from argument
+CMD="$1"
+
+# Set up Ray cluster configuration
+HEAD_NODE=$(scontrol show hostname "$SLURM_NODELIST" | head -n 1)
+RAY_PORT=6379
+
+# Get current node name
+CURRENT_NODE=$(hostname | cut -d. -f1)
+
+# Get HEAD_NODE_IP
+if [ "$SLURM_NODEID" -eq 0 ]; then
+    # We're on the head node, get our own IP
+    HEAD_NODE_IP=$(hostname -I | awk '{print $1}')
+else
+    # We're on a worker, resolve the head node's IP using DNS
+    HEAD_NODE_IP=$(getent hosts "$HEAD_NODE" | awk '{print $1}')
+fi
+
+# Set up cleanup function
+cleanup() {
+    if command -v ray &>/dev/null; then
+        echo "Stopping Ray on node $CURRENT_NODE"
+        ray stop || true
+    fi
+}
+trap cleanup EXIT
+
+# Start Ray based on node role
+if [ "$SLURM_NODEID" -eq 0 ]; then
+    echo "Starting Ray head node on $CURRENT_NODE"
+    ray start --head --disable-usage-stats --port=$RAY_PORT
+    echo "Ray head node started at $HEAD_NODE_IP:$RAY_PORT"
+else
+    echo "Waiting for head node to be ready..."
+    sleep 10
+    echo "Starting Ray worker on node $CURRENT_NODE (ID: $SLURM_NODEID)"
+    ray start --disable-usage-stats --address="$HEAD_NODE_IP:$RAY_PORT"
+fi
+
+# Ensure Ray cluster is ready
+sleep 2
+
+# Only head node runs the training command
+if [ "$SLURM_NODEID" -eq 0 ]; then
+    echo "Starting training process on head node $CURRENT_NODE"
+    bash -c "$CMD"
+else
+    # Worker nodes just wait for the head to finish
+    while ray status --address="$HEAD_NODE_IP:$RAY_PORT" &>/dev/null; do
+        sleep 10
+    done
+fi
+
+echo "Node $CURRENT_NODE: Done"
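If a worker fails to join the cluster, the usual suspects are head-node name resolution and the Ray port being unreachable. The diagnostic below is a sketch, not part of the PR: it mirrors the script's resolution logic in Python and probes the Ray port (`6379`, matching `RAY_PORT` above), and it assumes it runs inside the SLURM allocation so that `SLURM_NODELIST` is set.

```python
import os
import socket
import subprocess

RAY_PORT = 6379  # matches RAY_PORT in run_in_ray_cluster.sh

# Mirror `scontrol show hostname "$SLURM_NODELIST" | head -n 1`.
nodelist = os.environ["SLURM_NODELIST"]
head_node = subprocess.run(
    ["scontrol", "show", "hostname", nodelist],
    capture_output=True, text=True, check=True,
).stdout.splitlines()[0]

# Mirror `getent hosts "$HEAD_NODE"` (DNS/hosts lookup of the head node).
head_node_ip = socket.gethostbyname(head_node)
print(f"Head node: {head_node} -> {head_node_ip}")

# Probe the port workers use for `ray start --address`.
try:
    with socket.create_connection((head_node_ip, RAY_PORT), timeout=5):
        print(f"Ray head is reachable at {head_node_ip}:{RAY_PORT}")
except OSError as exc:
    print(f"Cannot reach {head_node_ip}:{RAY_PORT}: {exc}")
```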