@@ -158,12 +158,13 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
 
         self.aot_schedule = (get_flash_attn_version() == 3)
         self.use_full_cuda_graph = compilation_config.full_cuda_graph
-        if self.use_full_cuda_graph and not self.aot_schedule:
-            raise ValueError("Full CUDA graph mode requires AOT scheduling, "
-                             "which requires FlashAttention 3.")
-        self.scheduler_metadata = torch.zeros(self.runner.max_num_reqs + 1,
-                                              dtype=torch.int32,
-                                              device=self.runner.device)
+        if self.use_full_cuda_graph:
+            # NOTE(lucas): AOT scheduling not supported in full cuda graph mode
+            # yet. This is because the scheduler and kernel need to always use
+            # the same num_splits (which acts as an upper bound with the
+            # dynamic split scheduler) which is currently heuristically decided
+            # by the kernel launching code.
+            self.aot_schedule = False
 
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
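
The net effect of this first hunk: a configuration that previously raised a ValueError (full CUDA graph without FlashAttention 3's AOT scheduler) now quietly prefers the CUDA graph and drops AOT scheduling, since the scheduler and the kernel launcher must agree on num_splits and the launcher currently picks it heuristically. A minimal sketch of that fallback control flow; the class and parameter names here are illustrative, not the vLLM API:

```python
# Hedged sketch, not vLLM source: names below are assumptions.
class FlashAttnMetadataBuilderSketch:
    def __init__(self, fa_version: int, full_cuda_graph: bool):
        # AOT scheduling is only available with FlashAttention 3.
        self.aot_schedule = (fa_version == 3)
        self.use_full_cuda_graph = full_cuda_graph
        if self.use_full_cuda_graph:
            # Fall back instead of raising, as the replaced code did:
            # the AOT scheduler's num_splits must match the value the
            # kernel launch heuristic picks, which is not yet coordinated.
            self.aot_schedule = False
```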
@@ -299,18 +300,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                                       max_seq_len=max_seq_len,
                                       causal=True)
 
-        if self.use_full_cuda_graph:
-            assert scheduler_metadata is not None
-            n = scheduler_metadata.shape[0]
-            self.scheduler_metadata[:n].copy_(scheduler_metadata,
-                                              non_blocking=True)
-            # NOTE(woosuk): We should zero out the rest of the scheduler
-            # metadata to guarantee the correctness. Otherwise, some thread
-            # blocks may use the invalid scheduler metadata and overwrite the
-            # output buffer.
-            self.scheduler_metadata[n:] = 0
-            scheduler_metadata = self.scheduler_metadata[:n]
-
         attn_metadata = FlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
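
The block removed above implemented the usual staging pattern for CUDA graph capture: copy variable-length metadata into a preallocated, fixed-address buffer and zero the unused tail so that, per the NOTE(woosuk) comment, no thread block reads stale entries and overwrites the output buffer. With AOT scheduling now disabled under full CUDA graph (first hunk), the `assert scheduler_metadata is not None` could no longer hold in that mode, so the staging step was removed along with the buffer allocated in `__init__`. A rough standalone sketch of the pattern for reference, with the buffer size and device handling assumed for illustration:

```python
import torch

# Illustrative constants; the original sized the buffer from
# self.runner.max_num_reqs + 1 and used the runner's device.
MAX_NUM_REQS = 256
device = "cuda" if torch.cuda.is_available() else "cpu"
persistent_buf = torch.zeros(MAX_NUM_REQS + 1, dtype=torch.int32, device=device)

def stage_scheduler_metadata(scheduler_metadata: torch.Tensor) -> torch.Tensor:
    """Copy fresh metadata into the persistent buffer and zero the tail."""
    n = scheduler_metadata.shape[0]
    persistent_buf[:n].copy_(scheduler_metadata, non_blocking=True)
    # Zero the remainder so no thread block consumes stale metadata and
    # overwrites the output buffer (the correctness note in the diff).
    persistent_buf[n:] = 0
    return persistent_buf[:n]
```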