@@ -37,6 +37,8 @@ class ConcreteSizeEntry:
     # during capture, and check if they are the same during replay
     input_addresses: Optional[list[int]] = None

+    usage_type: Optional[str] = None
+

 class CUDAPiecewiseBackend:

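For context, after this hunk the entry dataclass looks roughly like the sketch below. Only `usage_type` and `input_addresses` appear in this diff; the remaining fields and their defaults are assumptions based on how the surrounding hunks use them.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional

import torch

@dataclass
class ConcreteSizeEntry:
    # Rough sketch; the real class has more fields than this diff shows.
    runtime_shape: int
    need_to_compile: bool = False
    use_cudagraph: bool = False
    runnable: Optional[Callable[..., Any]] = None
    num_finished_warmup: int = 0
    cudagraph: Optional[torch.cuda.CUDAGraph] = None
    output: Optional[Any] = None
    # tensor addresses recorded during capture, re-checked during replay
    input_addresses: Optional[list[int]] = None
    # New in this diff: a human-readable tag such as "piecewise(general)",
    # "general", or "decode"; used only in log messages, never for dispatch.
    usage_type: Optional[str] = None
```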
@@ -96,6 +98,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
                 runtime_shape=shape,
                 need_to_compile=shape in self.compile_sizes,
                 use_cudagraph=shape in self.cudagraph_capture_sizes,
+                usage_type="piecewise(general)",  # for logging only
             )

     def check_for_ending_compilation(self):
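The tag is threaded into the warmup and capture logs in the hunks below, so debug output distinguishes piecewise, full-graph general, and full-graph decode entries. Hypothetical log lines, following the format strings in this diff:

```
Warming up 1/1 of piecewise(general) usage for shape 512
Capturing a cudagraph of decode usage for shape 1
```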
@@ -139,27 +142,32 @@ def __call__(self, *args) -> Any:
                 self.check_for_ending_compilation()

         # Skip CUDA graphs if this entry doesn't use them OR
-        # if we're supposed to skip them globally
-        skip_cuda_graphs = get_forward_context().skip_cuda_graphs
-        if not entry.use_cudagraph or skip_cuda_graphs:
+        # if the piecewise graphs are meant to be treated as one whole graph,
+        # i.e. forward_context.skip_attention_cuda_graphs is False.
+        # In that case, we rely on a wrapper class to capture
+        # the full cudagraph outside the fx graph.
+        skip_attention_cuda_graphs = get_forward_context().skip_attention_cuda_graphs
+        if not entry.use_cudagraph or not skip_attention_cuda_graphs:
             return entry.runnable(*args)

         if entry.cudagraph is None:
             if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
                 entry.num_finished_warmup += 1
                 if self.is_first_graph:
                     logger.debug(
-                        "Warming up %s/%s for shape %s",
+                        "Warming up %s/%s of %s usage for shape %s",
                         entry.num_finished_warmup,
                         self.compilation_config.cudagraph_num_of_warmups,
+                        entry.usage_type,
                         runtime_shape)
                 return entry.runnable(*args)

             if self.is_first_graph:
                 # Since we capture cudagraph for many different shapes and
                 # capturing is fast, we don't need to log it for every shape.
                 # We only log it in the debug mode.
-                logger.debug("Capturing a cudagraph for shape %s",
+                logger.debug("Capturing a cudagraph of %s usage for shape %s",
+                             entry.usage_type,
                              runtime_shape)

             input_addresses = [
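The double negative in the new condition is easy to misread. A minimal standalone sketch of the dispatch rule, where `entry` and the forward context are assumed to carry the attributes used above (this is not an actual vLLM helper):

```python
def piecewise_should_use_own_cudagraph(entry, forward_context) -> bool:
    # The piecewise backend captures/replays its own cudagraph only when
    # this entry opts in AND attention is excluded from capture
    # (skip_attention_cuda_graphs=True). When that flag is False, the
    # FullCudagraphWrapper below owns capture, so the piecewise graphs
    # must run eagerly inside it.
    return entry.use_cudagraph and forward_context.skip_attention_cuda_graphs
```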
@@ -216,3 +224,137 @@ def __call__(self, *args) -> Any:

         entry.cudagraph.replay()
         return entry.output
+
+
+class FullCudagraphWrapper:
+    def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
+                 graph_pool: Any, sym_shape_indices: list[int],
+                 ):
+        self.graph = graph
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
+        self.graph_pool = graph_pool
+        self.sym_shape_indices = sym_shape_indices
+
+        self.separate_attention_routine = vllm_config.compilation_config.separate_attention_routine
+
+        self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
+
+        self.first_run_finished = False
+
+        self.cudagraph_capture_sizes: set[int] = set(
+            self.compilation_config.cudagraph_capture_sizes
+        ) if self.compilation_config.use_cudagraph else set()
+
+        self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
+        self.concrete_size_entries_decode: dict[int, ConcreteSizeEntry] = {}
+
+        for shape in self.cudagraph_capture_sizes:
+            self.concrete_size_entries[shape] = ConcreteSizeEntry(
+                runtime_shape=shape,
+                need_to_compile=False,
+                use_cudagraph=True,
+                usage_type="general",
+            )
+            if self.separate_attention_routine:
+                self.concrete_size_entries_decode[shape] = ConcreteSizeEntry(
+                    runtime_shape=shape,
+                    need_to_compile=False,
+                    use_cudagraph=True,
+                    usage_type="decode",
+                )
+
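+    # NOTE: with separate_attention_routine enabled, one runtime shape can
+    # thus own two cudagraphs: one captured from a pure-decode batch and one
+    # from a general (mixed prefill/decode) batch, since the attention
+    # routine may differ between the two cases.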
+    def __call__(self, *args) -> Any:
+        if not self.first_run_finished:
+            self.first_run_finished = True
+            return self.graph(*args)
+        list_args = list(args)
+        runtime_shape = list_args[self.sym_shape_indices[0]].shape[0]
+        forward_context = get_forward_context()
+
+        if forward_context.skip_attention_cuda_graphs:
+            # fall back to the piecewise cudagraph backend, which is
+            # responsible for capturing and running the piecewise cudagraphs.
+            return self.graph(*args)
+
+        # Otherwise, the fx graph and its sub-graphs run the compiled
+        # graphs eagerly, so the whole forward pass is capturable as a
+        # single cudagraph.
+
+        concrete_size_entries = self.concrete_size_entries  # default: general usage
+        if self.separate_attention_routine and forward_context.is_pure_decoding:
+            concrete_size_entries = self.concrete_size_entries_decode
+
+        if runtime_shape not in concrete_size_entries:
+            # we don't need to do anything for this shape.
+            return self.graph(*args)
+
+        entry = concrete_size_entries[runtime_shape]
+
+        if entry.runnable is None:
+            entry.runnable = self.graph
+
+        if not entry.use_cudagraph:
+            return entry.runnable(*args)
+
+        if entry.cudagraph is None:
+            if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups:  # noqa
+                entry.num_finished_warmup += 1
+                logger.debug(
+                    "Warming up %s/%s of %s usage for shape %s",
+                    entry.num_finished_warmup,
+                    self.compilation_config.cudagraph_num_of_warmups,
+                    entry.usage_type,
+                    runtime_shape)
+                return entry.runnable(*args)
+
+            # Since we capture cudagraph for many different shapes and
+            # capturing is fast, we don't need to log it for every shape.
+            # We only log it in the debug mode.
+            logger.debug("Capturing a cudagraph of %s usage for shape %s",
+                         entry.usage_type,
+                         runtime_shape)
+
+            input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            entry.input_addresses = input_addresses
+            cudagraph = torch.cuda.CUDAGraph()
+
+            with ExitStack() as stack:
+                # mind-exploding: carefully manage the reference and memory.
+                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
+                    # `output` is managed by pytorch's cudagraph pool
+                    output = entry.runnable(*args)
+                    # by converting it to weak ref,
+                    # the original `output` will immediately be released
+                    # to save memory.
+                    output = weak_ref_tensors(output)
+
+            # here we always use weak ref for the output
+            # to save memory
+            entry.output = weak_ref_tensors(output)
+            entry.cudagraph = cudagraph
+
+            compilation_counter.num_cudagraph_captured += 1
+
+            # important: we need to return the output, rather than
+            # the weak ref of the output, so that pytorch can correctly
+            # manage the memory during cuda graph capture
+            return output
+
+        if self.is_debugging_mode:
+            # check if the input addresses are the same
+            new_input_addresses = [
+                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
+            ]
+            assert new_input_addresses == entry.input_addresses, (
+                "Input addresses for cudagraphs are different during replay."
+                f" Expected {entry.input_addresses}, got {new_input_addresses}"
+            )
+
+        entry.cudagraph.replay()
+        return entry.output
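
Taken together, the wrapper is a drop-in callable around the compiled fx graph. This diff does not show the construction site; a hypothetical wiring, with argument names assumed rather than taken from the PR, might look like:

```python
# Hypothetical wiring (the actual call site is outside this diff):
wrapper = FullCudagraphWrapper(
    graph=compiled_graph,            # the compiled fx.GraphModule (assumed name)
    vllm_config=vllm_config,
    graph_pool=graph_pool,
    sym_shape_indices=sym_shape_indices,
)

out = wrapper(*args)  # first call: plain eager run
out = wrapper(*args)  # later calls per shape: warm up, then capture
out = wrapper(*args)  # once captured, same-shape calls replay the cudagraph
```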