@@ -46,13 +46,9 @@ class ConcreteSizeEntry:
46
46
# Output buffer of cudagraph
47
47
output_buffer : Optional [paddle .Tensor ] = None
48
48
49
- # for cudagraph debugging, track the input addresses
50
- # during capture, and check if they are the same during replay
51
- input_addresses : Optional [list [int ]] = None
52
-
53
49
54
50
class CudaGraphPiecewiseBackend :
55
- """ """
51
+ """ Manage the capture and replay of CUDA graphs at the subgraph level. """
56
52
57
53
def __init__ (
58
54
self ,
@@ -65,33 +61,31 @@ def __init__(
65
61
self .warm_up_size = fd_config .graph_opt_config .cudagraph_num_of_warmups
66
62
self .batch_size_to_captured_size = fd_config .graph_opt_config .batch_size_to_captured_size
67
63
68
- # runtime_bs -> ConcreteSizeEntry
64
+ # Runtime batch size -> ConcreteSizeEntry
69
65
self .concrete_size_entries : Dict [int , ConcreteSizeEntry ] = {}
70
66
71
67
for shape in self .cudagraph_capture_sizes :
72
68
self .concrete_size_entries [shape ] = ConcreteSizeEntry (
73
69
runtime_bs = shape )
74
70
75
- print ("[CUDA GRAPH] Created all batch size entry " )
71
+ logger . debug ("[CUDA GRAPH] Created all batch size entry " )
76
72
77
73
def __call__ (self , ** kwargs ):
78
74
# Get batch size
79
75
ids_remove_padding : paddle .Tensor = kwargs ["ids_remove_padding" ]
80
76
batch_size = ids_remove_padding .shape [0 ]
81
-
82
77
padding_batch_size = self .batch_size_to_captured_size [batch_size ]
83
- # print(
84
- # f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{batch_size}, ",
85
- # f"The padded batch size is :{padding_batch_size}"
86
- # )
78
+ logger .debug (
79
+ f"[CUDA GRAPH] The actual batch size obtained by CUDAGraph is :{ batch_size } , " ,
80
+ f"The padded batch size is :{ padding_batch_size } " )
87
81
88
82
entry = self .concrete_size_entries .get (padding_batch_size )
89
83
assert entry is not None , f"Batch size:{ padding_batch_size } is not in cuda graph capture list."
90
84
if entry .runnable is None :
91
85
entry .runnable = self .runnable
92
- # print (
93
- # f"[CUDA GRAPH] New entry lazy initialize with batch size {padding_batch_size}"
94
- # )
86
+ logger . debug (
87
+ f"[CUDA GRAPH] New entry lazy initialize with batch size { padding_batch_size } "
88
+ )
95
89
96
90
if not entry .use_cudagraph :
97
91
return entry .runnable (** kwargs )
@@ -102,10 +96,10 @@ def __call__(self, **kwargs):
102
96
for n in range (entry .num_finished_warmup , self .warm_up_size ):
103
97
entry .num_finished_warmup += 1
104
98
entry .runnable (** kwargs )
105
- # print (
106
- # "[CUDA GRAPH] Warm up for batch size ",
107
- # f"{padding_batch_size}, finished ({n+1}/{entry.num_finished_warmup}) times"
108
- # )
99
+ logger . debug (
100
+ "[CUDA GRAPH] Warm up for batch size " ,
101
+ f"{ padding_batch_size } , finished ({ n + 1 } /{ entry .num_finished_warmup } ) times"
102
+ )
109
103
110
104
# Store input addresses for debug
111
105
input_addresses = [
@@ -129,11 +123,13 @@ def __call__(self, **kwargs):
129
123
output ._clear
130
124
131
125
paddle .device .synchronize ()
132
- # print (
133
- # f"[CUDA GRAPH] CUDAGraph captured for batch size {padding_batch_size}"
134
- # )
126
+ logger . debug (
127
+ f"[CUDA GRAPH] CUDAGraph captured for batch size { padding_batch_size } "
128
+ )
135
129
136
130
# Replay
137
131
entry .cuda_graph .replay ()
138
- # print(f"[CUDA GRAPH] CUDAGraph replayed for batch size {padding_batch_size}")
132
+ logger .debug (
133
+ f"[CUDA GRAPH] CUDAGraph replayed for batch size { padding_batch_size } "
134
+ )
139
135
return entry .output_buffer
0 commit comments