
Commit 509724d

Randy Shuai authored and facebook-github-bot committed
Enrich auto-tune shapes for OC OBA model (#4368)
Summary:
Pull Request resolved: #4368
X-link: facebookresearch/FBGEMM#1436

Adds two shapes for the OC OBA model to improve the performance of the Triton fp8 non-persistent kernel, as suggested by the [log](https://www.internalfb.com/intern/paste/P1843910441/). With the added shapes, the Triton kernel is nearly on par with the PyTorch rowwise kernel:

| fp8 kernel | Flops | Time per iter | QPS |
|---|---|---|---|
| pytorch rowwise | 304.07 | 35.23 ms | 87205.84 |
| triton (without added shapes) | 292.83 | 36.58 ms | 83982.15 |
| triton (with added shapes) | 302.63 | 35.39 ms | 86793.45 |

Reviewed By: njriasan, karthik-man

Differential Revision: D76631650

fbshipit-source-id: 4a1324302f5ca635d801242d0cca205a33f41c94
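A quick sanity check on the table (my own arithmetic, not part of the commit): QPS should scale inversely with time per iteration, and the two columns agree to within a fraction of a percent:

```python
# Sanity check on the benchmark table: QPS scales inversely with
# time per iteration, so both columns should imply the same speedup.
times_ms = {"pytorch rowwise": 35.23, "triton (old)": 36.58, "triton (new)": 35.39}
qps = {"pytorch rowwise": 87205.84, "triton (old)": 83982.15, "triton (new)": 86793.45}

speedup_time = times_ms["triton (old)"] / times_ms["triton (new)"]  # ~1.034
speedup_qps = qps["triton (new)"] / qps["triton (old)"]             # ~1.033
print(f"time-based speedup: {speedup_time:.3f}, QPS-based speedup: {speedup_qps:.3f}")
# The added shapes recover ~3.3%, closing most of the gap to pytorch
# rowwise (35.39 ms vs 35.23 ms per iteration).
```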
1 parent: d04b7fd

1 file changed: +28 −0 lines

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

```diff
@@ -3418,6 +3418,34 @@ def get_full_non_persistent_tuning_space():
             num_warps=8,
             num_stages=2,
         ),
+        triton.Config(
+            {
+                "BLOCK_M": 128,
+                "BLOCK_N": 128,
+                "BLOCK_K": 64,
+                "GROUP_M": 4,
+                "SPLIT_K": 1,
+                "waves_per_eu": 2,
+                "matrix_instr_nonkdim": 16,
+                "kpack": 2,
+            },
+            num_warps=4,
+            num_stages=2,
+        ),
+        triton.Config(
+            {
+                "BLOCK_M": 128,
+                "BLOCK_N": 64,
+                "BLOCK_K": 64,
+                "GROUP_M": 4,
+                "SPLIT_K": 1,
+                "waves_per_eu": 0,
+                "matrix_instr_nonkdim": 16,
+                "kpack": 2,
+            },
+            num_warps=4,
+            num_stages=2,
+        ),
     ]

     # Set this to enable full autotuning for proper benchmarking.
```
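For context, below is a minimal, hypothetical sketch (not FBGEMM's actual kernel) of how `triton.Config` entries like the two added above are consumed. `@triton.autotune` benchmarks each candidate config the first time a new autotune key (here the GEMM shape) is seen and caches the winner; the kernel body is a placeholder, and the ROCm-specific hints (`waves_per_eu`, `matrix_instr_nonkdim`, `kpack`) are omitted for portability.

```python
# Hypothetical sketch of Triton autotuning over a config list like the
# one extended in this commit. The kernel body is a placeholder, not
# FBGEMM's fp8 GEMM.
import triton
import triton.language as tl

_SKETCH_CONFIGS = [
    triton.Config(
        {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 4, "SPLIT_K": 1},
        num_warps=4,
        num_stages=2,
    ),
    triton.Config(
        {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_M": 4, "SPLIT_K": 1},
        num_warps=4,
        num_stages=2,
    ),
]

# The autotuner times every config for each new (M, N, K) it sees and
# caches the fastest; the chosen config's kwargs are bound to the
# kernel's constexpr meta-parameters below.
@triton.autotune(configs=_SKETCH_CONFIGS, key=["M", "N", "K"])
@triton.jit
def _sketch_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr,
):
    # Placeholder body: a real kernel would tile the GEMM using the
    # BLOCK_* meta-parameters selected by the autotuner.
    pass
```

Enriching the config list only widens this search space: problem shapes that previously settled on a suboptimal tiling can now pick the new 128×128×64 or 128×64×64 blocks.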
