Skip to content

Commit 8b6e1d6

Browse files
zzz9990, fsx950223, and charlifu
authored
[Hardware][AMD] integrate aiter chunked prefill into vllm (#18596)
Signed-off-by: fsx950223 <fsx950223@outlook.com> Signed-off-by: charlifu <charlifu@amd.com> Co-authored-by: fsx950223 <fsx950223@outlook.com> Co-authored-by: charlifu <charlifu@amd.com>
1 parent 735a9de commit 8b6e1d6

File tree

3 files changed

+602
-3
lines changed

3 files changed

+602
-3
lines changed

vllm/envs.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,7 @@
8787
VLLM_ROCM_USE_AITER_MOE: bool = True
8888
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
8989
VLLM_ROCM_USE_AITER_MLA: bool = True
90+
VLLM_ROCM_USE_AITER_MHA: bool = True
9091
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
9192
VLLM_ROCM_FP8_PADDING: bool = True
9293
VLLM_ROCM_MOE_PADDING: bool = True
@@ -653,6 +654,13 @@ def get_vllm_port() -> Optional[int]:
653654
"VLLM_ROCM_USE_AITER_MLA":
654655
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MLA", "True").lower() in
655656
("true", "1")),
657+
658+
# Whether to use aiter mha ops.
659+
# By default is enabled.
660+
"VLLM_ROCM_USE_AITER_MHA":
661+
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
662+
("true", "1")),
663+
656664
# use rocm skinny gemms
657665
"VLLM_ROCM_USE_SKINNY_GEMM":
658666
lambda: (os.getenv("VLLM_ROCM_USE_SKINNY_GEMM", "True").lower() in

vllm/platforms/rocm.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -215,9 +215,15 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
215215
selected_backend = _Backend.ROCM_FLASH
216216

217217
if envs.VLLM_USE_V1:
218-
logger.info("Using Triton Attention backend on V1 engine.")
219-
return ("vllm.v1.attention.backends."
220-
"triton_attn.TritonAttentionBackend")
218+
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA \
219+
and on_gfx9():
220+
logger.info("Using Flash Attention backend on V1 engine.")
221+
return ("vllm.v1.attention.backends."
222+
"rocm_aiter_fa.AiterFlashAttentionBackend")
223+
else:
224+
logger.info("Using Triton Attention backend on V1 engine.")
225+
return ("vllm.v1.attention.backends."
226+
"triton_attn.TritonAttentionBackend")
221227
if selected_backend == _Backend.ROCM_FLASH:
222228
if not cls.has_device_capability(90):
223229
# not Instinct series GPUs.

0 commit comments

Comments (0)