Commit 235bfd5

[Docs] Improve documentation for RLHF example (#20598)
Signed-off-by: Ricardo Decal <rdecal@anyscale.com>
1 parent 68d28e3 commit 235bfd5

File tree

1 file changed: +49 -36 lines changed


examples/offline_inference/rlhf.py

Lines changed: 49 additions & 36 deletions
@@ -1,17 +1,31 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """
-a simple demonstration of RLHF with vLLM, inspired by
-the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF .
-It follows the design that, training processes and inference processes
-are different, and they live on different GPUs.
-Training processes send prompts to inference processes to generate data,
-and also synchronize the weights of the model by broadcasting the weights
-from the training process to the inference process.
-Note that this is a simple demonstration of one training instance and one
-inference instance. In practice, there could be multiple training instances
-and multiple inference instances. For the full implementation, please refer
-to the OpenRLHF framework.
+Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.
+
+The script separates training and inference workloads onto distinct GPUs
+so that Ray can manage process placement and inter-process communication.
+A Hugging Face Transformer model occupies GPU 0 for training, whereas a
+tensor-parallel vLLM inference engine occupies GPU 1–2.
+
+The example performs the following steps:
+
+* Load the training model on GPU 0.
+* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
+  and Ray placement groups.
+* Generate text from a list of prompts using the inference engine.
+* Update the weights of the training model and broadcast the updated weights
+  to the inference engine by using a Ray collective RPC group. Note that
+  for demonstration purposes we simply zero out the weights.
+
+For a production-ready implementation that supports multiple training and
+inference replicas, see the OpenRLHF framework:
+https://github.com/OpenRLHF/OpenRLHF
+
+This example assumes a single-node cluster with three GPUs, but Ray
+supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
+workloads. Residual GPU activity interferes with vLLM memory profiling and
+causes unexpected behavior.
 """

 import os
@@ -28,40 +42,37 @@


 class MyLLM(LLM):
+    """Configure the vLLM worker for Ray placement group execution."""
+
     def __init__(self, *args, **kwargs):
-        # a hack to make the script work.
-        # stop ray from manipulating CUDA_VISIBLE_DEVICES
-        # at the top-level
+        # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
+        # so that vLLM can manage its own device placement within the worker.
         os.environ.pop("CUDA_VISIBLE_DEVICES", None)
         super().__init__(*args, **kwargs)


-"""
-Start the training process, here we use huggingface transformers
-as an example to hold a model on GPU 0.
-"""
-
+# Load the OPT-125M model onto GPU 0 for the training workload.
 train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
 train_model.to("cuda:0")
-"""
-Start the inference process, here we use vLLM to hold a model on GPU 1 and
-GPU 2. For the details on how to use ray, please refer to the ray
-documentation https://docs.ray.io/en/latest/ .
-"""
+
+# Initialize Ray and set the visible devices. The vLLM engine will
+# be placed on GPUs 1 and 2.
 os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
 ray.init()

+# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
+# Learn more about Ray placement groups:
+# https://docs.ray.io/en/latest/placement-groups.html
 pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
 ray.get(pg_inference.ready())
 scheduling_inference = PlacementGroupSchedulingStrategy(
     placement_group=pg_inference,
     placement_group_capture_child_tasks=True,
     placement_group_bundle_index=0,
 )
-"""
-launch the vLLM inference engine.
-here we use `enforce_eager` to reduce the start time.
-"""
+
+# Launch the vLLM inference engine. The `enforce_eager` flag reduces
+# start-up latency.
 llm = ray.remote(
     num_cpus=0,
     num_gpus=0,
@@ -74,7 +85,7 @@ def __init__(self, *args, **kwargs):
     distributed_executor_backend="ray",
 )

-# Generate texts from the prompts.
+# Generate text from the prompts.
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -93,8 +104,8 @@ def __init__(self, *args, **kwargs):
     print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
     print("-" * 50)

-# set up the communication between the training process
-# and the inference engine.
+# Set up the communication channel between the training process and the
+# inference engine.
 master_address = get_ip()
 master_port = get_open_port()

@@ -107,21 +118,23 @@ def __init__(self, *args, **kwargs):
 )
 ray.get(handle)

-# simulate training, modify the weights of the model.
+# Simulate a training step by zeroing out all model weights.
+# In a real RLHF training loop the weights would be updated using the gradient
+# from an RL objective such as PPO on a reward model.
 for name, p in train_model.named_parameters():
     p.data.zero_()

-# sync weight from the training process to the inference engine.
+# Synchronize the updated weights to the inference engine.
 for name, p in train_model.named_parameters():
     handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
     model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
     ray.get(handle)

-# check if the weights are updated.
+# Verify that the inference weights have been updated.
 assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))

-# use the updated model to generate texts, they will be nonsense
-# because the weights are all zeros.
+# Generate text with the updated model. The output is expected to be nonsense
+# because the weights are zero.
 outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
 print("-" * 50)
 for output in outputs_updated:
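
The placement-group comments added in this commit describe reserving one-GPU bundles and pinning a Ray actor to a specific bundle. Below is a minimal, self-contained sketch of that pattern, assuming a machine with at least two GPUs; the EchoActor class is a hypothetical stand-in for the vLLM engine actor and is not part of the example.

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy


@ray.remote
class EchoActor:
    """Hypothetical stand-in for the tensor-parallel vLLM engine actor."""

    def ping(self) -> str:
        return "pong"


ray.init()

# Reserve two one-GPU bundles, mirroring the GPUs 1-2 reservation in the example.
pg = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg.ready())

# Pin the actor to the first bundle; child tasks inherit the placement group.
strategy = PlacementGroupSchedulingStrategy(
    placement_group=pg,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)
actor = EchoActor.options(
    num_cpus=0, num_gpus=0, scheduling_strategy=strategy
).remote()
print(ray.get(actor.ping.remote()))

In the example the engine actor itself requests zero CPUs and GPUs, presumably so that the tensor-parallel workers vLLM launches inside the placement group can claim the GPU bundles instead.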

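The weight-synchronization loop in the last hunk broadcasts each updated parameter from the training process to the inference engine over a collective group. As a conceptual sketch of the same idea using plain torch.distributed (an assumption made for illustration; the example itself relies on vLLM's collective_rpc and a stateless process group), runnable on CPU with the gloo backend:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from transformers import AutoModelForCausalLM


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Rank 0 plays the trainer, rank 1 plays the inference worker.
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    if rank == 0:
        # "Training" step: zero out the weights, as the example does.
        for _, p in model.named_parameters():
            p.data.zero_()

    # Broadcast every parameter from the trainer to the worker, in place.
    for _, p in model.named_parameters():
        dist.broadcast(p.data, src=0)

    if rank == 1:
        # The worker's copy now matches the trainer's zeroed weights.
        assert all(torch.count_nonzero(p) == 0 for p in model.parameters())

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)

Broadcasting parameter by parameter mirrors the per-tensor update_weight calls in the example, with rank 0 standing in for the trainer and rank 1 for the inference worker.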