4 files changed, +15 -6 lines changed
@@ -269,8 +269,8 @@ def chunked_prefill_paged_decode(
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert key_cache.dtype == torch.uint8
-        assert value_cache.dtype == torch.uint8
+        assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
 
     if kv_cache_dtype in ("fp8", "fp8_e4m3"):
         target_dtype = current_platform.fp8_dtype()
@@ -749,8 +749,8 @@ def context_attention_fwd(q,
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert (k_cache.dtype == torch.uint8)
-        assert (v_cache.dtype == torch.uint8)
+        assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
 
     if kv_cache_dtype in ("fp8", "fp8_e4m3"):
         target_dtype = current_platform.fp8_dtype()
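Both hunks above relax the same precondition: with an fp8 KV-cache dtype, the cache tensors may now arrive either as raw uint8 storage or already carrying the platform's native fp8 dtype. A minimal standalone sketch of the relaxed check follows; it is illustrative only, and torch.float8_e4m3fn is assumed here as a stand-in for whatever current_platform.fp8_dtype() returns on a given platform.

import torch

# Illustrative helper (not from the diff) mirroring the relaxed assertion.
# torch.float8_e4m3fn stands in for current_platform.fp8_dtype().
def check_kv_cache_dtype(key_cache: torch.Tensor,
                         value_cache: torch.Tensor,
                         kv_cache_dtype: str) -> None:
    if "fp8" in kv_cache_dtype:
        allowed = (torch.uint8, torch.float8_e4m3fn)
        assert key_cache.dtype in allowed
        assert value_cache.dtype in allowed

# A cache that is already viewed as the native fp8 dtype now passes,
# where previously only uint8 storage was accepted.
key_cache = torch.zeros(4, 16, dtype=torch.uint8).view(torch.float8_e4m3fn)
value_cache = torch.zeros(4, 16, dtype=torch.uint8).view(torch.float8_e4m3fn)
check_kv_cache_dtype(key_cache, value_cache, "fp8")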
@@ -20,6 +20,7 @@
     VLLM_USE_TRITON_FLASH_ATTN: bool = True
     VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True
     VLLM_USE_ROCM_FP8_FLASH_ATTN: bool = False
+    VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
     VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -331,6 +332,13 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.getenv("VLLM_USE_ROCM_FP8_FLASH_ATTN", "False").lower() in
              ("true", "1")),
 
+    # Use separate prefill and decode kernels for V1 attention instead of
+    # the unified triton kernel.
+    "VLLM_V1_USE_PREFILL_DECODE_ATTENTION":
+    lambda:
+    (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
+     ("true", "1")),
+
     # Internal flag to enable/disable Inductor standalone compile
     "VLLM_TEST_STANDALONE_COMPILE":
     lambda: os.environ.get("VLLM_TEST_STANDALONE_COMPILE", "0") != "0",
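The new entry follows the same string-to-boolean convention as the neighbouring flags: only "true" or "1" (case-insensitive) turns it on, and it defaults to off. A small sketch of that parsing is below, using a hypothetical helper name; in practice the flag would simply be set in the environment, e.g. VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1, before starting vLLM.

import os

# Hypothetical helper reproducing the lambda's parsing convention:
# unset, "False", "0", or any other unrecognised value means disabled.
def parse_bool_env(name: str, default: str = "False") -> bool:
    return os.getenv(name, default).lower() in ("true", "1")

os.environ["VLLM_V1_USE_PREFILL_DECODE_ATTENTION"] = "1"
assert parse_bool_env("VLLM_V1_USE_PREFILL_DECODE_ATTENTION") is True

os.environ["VLLM_V1_USE_PREFILL_DECODE_ATTENTION"] = "no"
assert parse_bool_env("VLLM_V1_USE_PREFILL_DECODE_ATTENTION") is False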
@@ -5,6 +5,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.chunked_prefill_paged_decode import (
@@ -167,8 +168,8 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_queries_per_kv = query.shape[1] // key.shape[1]
-        use_prefill_decode_attn = (num_queries_per_kv &
-                                   (num_queries_per_kv - 1)) != 0
+        use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or (
+            (num_queries_per_kv & (num_queries_per_kv - 1)) != 0)
 
         num_actual_tokens = attn_metadata.num_actual_tokens
 
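The backend previously took the split prefill/decode path only when num_queries_per_kv was not a power of two (detected by the (n & (n - 1)) != 0 bit trick); the change lets the environment flag force that path as well. A standalone sketch of the selection logic, with an illustrative function name not taken from the diff:

# Illustrative reimplementation of the selection logic above; the function
# name and standalone form are assumptions for the example.
def wants_prefill_decode_attn(num_queries_per_kv: int, force: bool) -> bool:
    not_power_of_two = (num_queries_per_kv & (num_queries_per_kv - 1)) != 0
    return force or not_power_of_two

# 8 queries per KV head is a power of two: unified kernel unless forced.
assert wants_prefill_decode_attn(8, force=False) is False
assert wants_prefill_decode_attn(8, force=True) is True
# 6 is not a power of two: split prefill/decode kernels regardless of the flag.
assert wants_prefill_decode_attn(6, force=False) is True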