`docs/source/design/v1/torch_compile.md` (7 additions, 3 deletions)
@@ -99,7 +99,9 @@ This time, Inductor compilation is completely bypassed, and we will load from di
The above example only uses Inductor to compile for a general shape (i.e. a symbolic shape). We can also use Inductor to compile for specific shapes, for example:
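For illustration, a possible invocation might look like the sketch below. The model name is a placeholder, and the `compile_sizes` key is an assumption of this sketch rather than something shown in this excerpt:

```bash
# Sketch only: ask Inductor to additionally compile for specific batch sizes.
# The model name is illustrative; `compile_sizes` is an assumed config key.
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --compilation-config '{"compile_sizes": [1, 2, 4, 8]}'
```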
vLLM will then also compile a specialized kernel for each of the batch sizes `1, 2, 4, 8`. In this case, all of the shapes in the computation graph are static and known, and we turn on auto-tuning to tune for maximum performance. The first run can be slow, but subsequent runs bypass the tuning and run the tuned kernels directly.
@@ -134,12 +136,14 @@ The cudagraphs are captured and managed by the compiler backend, and replayed wh
By default, vLLM tries to determine a set of sizes for cudagraph capture. You can override this set using the config `cudagraph_capture_sizes`:
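For example, an override might look like the following sketch (the model name is a placeholder):

```bash
# Sketch only: restrict cudagraph capture to the listed batch sizes.
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --compilation-config '{"cudagraph_capture_sizes": [1, 2, 4, 8]}'
```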
vLLM will then only capture cudagraphs for the specified sizes, which gives you fine-grained control over cudagraph capture.
### Full Cudagraph capture
It is possible to include attention as part of the cudagraph when using a cudagraph-compatible attention backend. This can improve performance in some cases, such as decode speed for smaller models. Enable this using `--compilation-config '{"full_cuda_graph": true}'`.
Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled.
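Putting it together, a full-cudagraph serving command might look like this sketch (the model name is a placeholder, and it assumes FlashAttention 3 is the selected backend):

```bash
# Sketch only: enable full cudagraph capture, including attention.
# Requires a cudagraph-compatible backend (currently FlashAttention 3).
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --compilation-config '{"full_cuda_graph": true}'
```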