
Commit ccd96e8

Toggle for v1 attention
1 parent 1466c79

4 files changed: 15 additions and 6 deletions

vllm/attention/ops/chunked_prefill_paged_decode.py

Lines changed: 2 additions & 2 deletions

@@ -269,8 +269,8 @@ def chunked_prefill_paged_decode(
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert key_cache.dtype == torch.uint8
-        assert value_cache.dtype == torch.uint8
+        assert key_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert value_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
 
     if kv_cache_dtype in ("fp8", "fp8_e4m3"):
         target_dtype = current_platform.fp8_dtype()

vllm/attention/ops/prefix_prefill.py

Lines changed: 2 additions & 2 deletions

@@ -749,8 +749,8 @@ def context_attention_fwd(q,
     # Conversion of FP8 Tensor from uint8 storage to
     # appropriate torch.dtype for interpretation by Triton
     if "fp8" in kv_cache_dtype:
-        assert (k_cache.dtype == torch.uint8)
-        assert (v_cache.dtype == torch.uint8)
+        assert k_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
+        assert v_cache.dtype in [torch.uint8, current_platform.fp8_dtype()]
 
     if kv_cache_dtype in ("fp8", "fp8_e4m3"):
         target_dtype = current_platform.fp8_dtype()
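
Both hunks above relax the same check: an FP8 KV cache may arrive either as raw uint8 storage or already tagged with the platform's native fp8 dtype. A minimal standalone sketch of the relaxed assertion (the helper name and the e4m3 dtype choice are illustrative, not vLLM code; assumes a PyTorch build with float8_e4m3fn support):

import torch

def check_fp8_kv_cache(kv_cache: torch.Tensor, fp8_dtype: torch.dtype) -> None:
    # The same bytes can be viewed as uint8 or as the platform's fp8 dtype;
    # the relaxed assertion accepts either representation.
    allowed = (torch.uint8, fp8_dtype)
    assert kv_cache.dtype in allowed, (
        f"expected one of {allowed}, got {kv_cache.dtype}")

cache = torch.zeros(16, dtype=torch.float8_e4m3fn)
check_fp8_kv_cache(cache, torch.float8_e4m3fn)                    # native fp8 view
check_fp8_kv_cache(cache.view(torch.uint8), torch.float8_e4m3fn)  # uint8 storage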

vllm/envs.py

Lines changed: 8 additions & 0 deletions

@@ -20,6 +20,7 @@
     VLLM_USE_TRITON_FLASH_ATTN: bool = True
     VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True
     VLLM_USE_ROCM_FP8_FLASH_ATTN: bool = False
+    VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
     VLLM_FLASH_ATTN_VERSION: Optional[int] = None
     LOCAL_RANK: int = 0
     CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -331,6 +332,13 @@ def get_vllm_port() -> Optional[int]:
     lambda: (os.getenv("VLLM_USE_ROCM_FP8_FLASH_ATTN", "False").lower() in
              ("true", "1")),
 
+    # Use separate prefill and decode kernels for V1 attention instead of
+    # the unified triton kernel.
+    "VLLM_V1_USE_PREFILL_DECODE_ATTENTION":
+    lambda:
+    (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
+     ("true", "1")),
+
     # Internal flag to enable/disable Inductor standalone compile
     "VLLM_TEST_STANDALONE_COMPILE":
     lambda: os.environ.get("VLLM_TEST_STANDALONE_COMPILE", "0") != "0",
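
The new entry follows the same parsing convention as the other boolean flags in envs.py: only "true" or "1" (case-insensitive) enables it. A standalone sketch of setting the flag and how the lambda above evaluates it (the parsing expression mirrors the diff; everything else is illustrative):

import os

# Enable the toggle before vLLM reads its environment, e.g. from the shell:
#   VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 <your vLLM command>
os.environ["VLLM_V1_USE_PREFILL_DECODE_ATTENTION"] = "1"

# Same parsing rule as the lambda added in this commit.
enabled = (os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower()
           in ("true", "1"))
print(enabled)  # True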

vllm/v1/attention/backends/triton_attn.py

Lines changed: 3 additions & 2 deletions

@@ -5,6 +5,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.ops.chunked_prefill_paged_decode import (
@@ -167,8 +168,8 @@ def forward(
         # performance to make sure it does not introduce any overhead.
 
         num_queries_per_kv = query.shape[1] // key.shape[1]
-        use_prefill_decode_attn = (num_queries_per_kv &
-                                   (num_queries_per_kv - 1)) != 0
+        use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION or (
+            (num_queries_per_kv & (num_queries_per_kv - 1)) != 0)
 
         num_actual_tokens = attn_metadata.num_actual_tokens
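
The dispatch condition now ORs the environment toggle with the existing power-of-two check on the query-per-KV-head group size (n is a power of two exactly when n & (n - 1) == 0). A standalone sketch of the resulting logic, with an illustrative helper name:

def use_prefill_decode_attn(num_queries_per_kv: int, force_toggle: bool) -> bool:
    # Fall back to the split prefill/decode kernels when the group size is not
    # a power of two, or when VLLM_V1_USE_PREFILL_DECODE_ATTENTION forces it.
    not_power_of_two = (num_queries_per_kv & (num_queries_per_kv - 1)) != 0
    return force_toggle or not_power_of_two

assert use_prefill_decode_attn(8, force_toggle=False) is False  # unified kernel
assert use_prefill_decode_attn(6, force_toggle=False) is True   # split kernels
assert use_prefill_decode_attn(8, force_toggle=True) is True    # toggle wins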
