# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
- a simple demonstration of RLHF with vLLM, inspired by
- the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF .
- It follows the design that, training processes and inference processes
- are different, and they live on different GPUs.
- Training processes send prompts to inference processes to generate data,
- and also synchronize the weights of the model by broadcasting the weights
- from the training process to the inference process.
- Note that this is a simple demonstration of one training instance and one
- inference instance. In practice, there could be multiple training instances
- and multiple inference instances. For the full implementation, please refer
- to the OpenRLHF framework.
+ Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.
+
+ The script separates training and inference workloads onto distinct GPUs
+ so that Ray can manage process placement and inter-process communication.
+ A Hugging Face Transformers model occupies GPU 0 for training, whereas a
+ tensor-parallel vLLM inference engine occupies GPUs 1 and 2.
+
+ The example performs the following steps:
+
+ * Load the training model on GPU 0.
+ * Split the inference model across GPUs 1 and 2 using vLLM's tensor
+   parallelism and Ray placement groups.
+ * Generate text from a list of prompts using the inference engine.
+ * Update the weights of the training model and broadcast the updated weights
+   to the inference engine through Ray collective RPC calls. Note that for
+   demonstration purposes we simply zero out the weights.
+
+ For a production-ready implementation that supports multiple training and
+ inference replicas, see the OpenRLHF framework:
+ https://github.com/OpenRLHF/OpenRLHF
+
+ This example assumes a single-node cluster with three GPUs, but Ray also
+ supports multi-node clusters. vLLM expects exclusive use of the GPUs while it
+ is running: residual GPU activity interferes with vLLM memory profiling and
+ causes unexpected behavior.
"""

import os

class MyLLM(LLM):
+     """Configure the vLLM worker for Ray placement group execution."""
+
    def __init__(self, *args, **kwargs):
-         # a hack to make the script work.
-         # stop ray from manipulating CUDA_VISIBLE_DEVICES
-         # at the top-level
+         # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
+         # so that vLLM can manage its own device placement within the worker.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        super().__init__(*args, **kwargs)


- """
- Start the training process, here we use huggingface transformers
- as an example to hold a model on GPU 0.
- """
-
+ # Load the OPT-125M model onto GPU 0 for the training workload.
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")
- """
- Start the inference process, here we use vLLM to hold a model on GPU 1 and
- GPU 2. For the details on how to use ray, please refer to the ray
- documentation https://docs.ray.io/en/latest/ .
- """
+
+ # Set the visible devices and initialize Ray. The vLLM engine will be
+ # placed on GPUs 1 and 2.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()

+ # Create a placement group that reserves GPUs 1 and 2 for the vLLM inference engine.
+ # Learn more about Ray placement groups:
+ # https://docs.ray.io/en/latest/placement-groups.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
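+ # The scheduling strategy below pins the MyLLM actor to bundle 0 of the
+ # placement group. Setting placement_group_capture_child_tasks=True makes the
+ # Ray workers spawned by vLLM inherit the same placement group, so they run
+ # on the reserved GPUs.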
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group=pg_inference,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)
- """
- launch the vLLM inference engine.
- here we use `enforce_eager` to reduce the start time.
- """
+
+ # Launch the vLLM inference engine. The `enforce_eager` flag reduces
+ # start-up latency.
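+ # The MyLLM actor itself requests no resources (num_cpus=0, num_gpus=0); the
+ # Ray workers created by vLLM's distributed executor claim the GPUs reserved
+ # by the placement group bundles.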
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
@@ -74,7 +85,7 @@ def __init__(self, *args, **kwargs):
    distributed_executor_backend="ray",
)

- # Generate texts from the prompts.
+ # Generate text from the prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
@@ -93,8 +104,8 @@ def __init__(self, *args, **kwargs):
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

- # set up the communication between the training process
- # and the inference engine.
+ # Set up the communication channel between the training process and the
+ # inference engine.
master_address = get_ip()
master_port = get_open_port()
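+ # The trainer's IP address and an open port are shared with the vLLM workers
+ # so that the training process and every inference worker can join a common
+ # process group used to broadcast weight tensors.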
@@ -107,21 +118,23 @@ def __init__(self, *args, **kwargs):
)
ray.get(handle)

- # simulate training, modify the weights of the model.
+ # Simulate a training step by zeroing out all model weights.
+ # In a real RLHF training loop, the weights would instead be updated with
+ # gradients from an RL objective such as PPO, guided by a reward model.
for name, p in train_model.named_parameters():
    p.data.zero_()

- # sync weight from the training process to the inference engine.
+ # Synchronize the updated weights to the inference engine.
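+ # For each parameter, the asynchronous collective_rpc call tells the vLLM
+ # workers to receive a tensor of the given dtype and shape, the training
+ # process then broadcasts that tensor from rank 0, and ray.get() blocks until
+ # the workers have applied the update.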
for name, p in train_model.named_parameters():
    handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
    model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
    ray.get(handle)

- # check if the weights are updated.
+ # Verify that the inference weights have been updated.
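+ # collective_rpc fans the call out to every vLLM worker and returns one result
+ # per worker, so all() confirms that every worker observed the new weights.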
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))

- # use the updated model to generate texts, they will be nonsense
- # because the weights are all zeros.
+ # Generate text with the updated model. The output is expected to be nonsense
+ # because the weights are zero.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated: