
Commit 0a82d3d

add persistent variant (#77)

manman-ren authored and facebook-github-bot committed
Summary:
Hongtao identified the performance issue with the initial implementation and updated the assignments of tiles to each SM.

Performance with warp specialization:

(Batch, Heads, SeqLen, Dhead)  triton_tutorial_flash_v2_tma_ws_persistent-tflops  triton_tutorial_flash_v2_tma_ws-tflops  triton_tutorial_flash_v2-tflops
-----------------------------  --------------------------------------------------  --------------------------------------  -------------------------------
(8, 16, 8192, 128)             516.164                                             490.451                                 423.905

Pull Request resolved: #77

Reviewed By: xuzhao9, htyu

Differential Revision: D66463179

Pulled By: manman-ren

fbshipit-source-id: 14fecc1a1449828bfd82600bd161596349da3084
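The "persistent" scheduling referenced above (changing how tiles are assigned to each SM) follows the usual Triton persistent-kernel pattern: launch roughly one program per SM and have each program stride through the tile space, instead of launching one program per tile. Below is a minimal, hypothetical sketch of that launch pattern only; the kernel name, the vector-scaling workload, and the BLOCK/NUM_PROGRAMS values are illustrative and are not the flash attention kernel added in this commit.

import torch
import triton
import triton.language as tl


@triton.jit
def _persistent_scale_kernel(x_ptr, out_ptr, n_elements, scale,
                             BLOCK: tl.constexpr, NUM_PROGRAMS: tl.constexpr):
    pid = tl.program_id(axis=0)
    num_tiles = tl.cdiv(n_elements, BLOCK)
    # Persistent loop: tiles pid, pid + NUM_PROGRAMS, pid + 2*NUM_PROGRAMS, ...
    # are all handled by this one program, so the tile-to-SM assignment is
    # controlled here rather than by the launch grid.
    for tile_id in range(pid, num_tiles, NUM_PROGRAMS):
        offs = tile_id * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        x = tl.load(x_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x * scale, mask=mask)


def persistent_scale(x: torch.Tensor, scale: float) -> torch.Tensor:
    out = torch.empty_like(x)
    # One program per SM instead of one program per tile.
    num_sms = torch.cuda.get_device_properties(x.device).multi_processor_count
    _persistent_scale_kernel[(num_sms,)](
        x, out, x.numel(), scale, BLOCK=1024, NUM_PROGRAMS=num_sms
    )
    return out

In the commit's kernel the work items would be attention tiles rather than vector blocks, but the "assignments of tiles to each SM" mentioned in the summary refers to this kind of program-to-tile mapping.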
1 parent 7e1f269 · commit 0a82d3d

4 files changed: +306, -10 lines


test/test_gpu/skip_tests_h100_pytorch.yaml (1 addition, 0 deletions)

@@ -16,6 +16,7 @@ flash_attention:
   - triton_tutorial_flash_v2_tma
   - triton_tutorial_flash_v2_ws
   - triton_tutorial_flash_v2_tma_ws
+  - triton_tutorial_flash_v2_tma_ws_persistent
 fp8_attention:
   - colfax_fmha
   # triton_flash_v2 requires triton-main

test/test_gpu/skip_tests_h100_triton_main.yaml (1 addition, 0 deletions)

@@ -10,6 +10,7 @@ flash_attention:
   # _ws kernels require Triton with warp specialization
   - triton_tutorial_flash_v2_ws
   - triton_tutorial_flash_v2_tma_ws
+  - triton_tutorial_flash_v2_tma_ws_persistent
 fp8_attention:
   # fb-only kernel
   - colfax_fmha
