Skip to content

Commit 89cc4cf

Browse files
njriasan and facebook-github-bot
authored and committed
Include blackwell warpspec persistent matmul in TritonBench
Summary: Adds the various blackwell matmuls that can support warpspec from https://github.com/triton-lang/triton/blob/main/python/tutorials/09-persistent-matmul.py to TritonBench so we benchmark them on OmniFm shapes. Note: At this point I made the explicit choice to avoid making warp_spec an autotune configuration. The reasoning here is to both simplify benchmarking as we won't need to "determine" if/when warp spec is useful/how useful and because there might be bugs in warpspec. Reviewed By: PaulZhang12 Differential Revision: D77053488 fbshipit-source-id: 9a3c3135e4646b71192b6e61d815d7043c7ee885
1 parent 095b6b1 commit 89cc4cf

File tree

2 files changed

+521
-1
lines changed

2 files changed

+521
-1
lines changed

tritonbench/operators/gemm/operator.py

Lines changed: 81 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -11,8 +11,18 @@
1111
from tritonbench.operators.gemm.kernels import matmul as kernels
1212
from tritonbench.operators.gemm.partition_k import matmul_partition_k
1313
from tritonbench.operators.gemm.stream_k import streamk_matmul
14+
from tritonbench.operators.gemm.warp_spec_persistent_matmul import (
15+
blackwell_matmul_descriptor_persistent,
16+
blackwell_matmul_tma,
17+
blackwell_matmul_tma_persistent,
18+
)
1419
from tritonbench.utils.data_utils import get_production_shapes
15-
from tritonbench.utils.env_utils import is_cuda, is_fbcode, supports_tma
20+
from tritonbench.utils.env_utils import (
21+
get_nvidia_gpu_model,
22+
is_cuda,
23+
is_fbcode,
24+
supports_tma,
25+
)
1626

1727
from tritonbench.utils.path_utils import REPO_PATH
1828

@@ -94,6 +104,8 @@
94104
for k in [4096 * i for i in range(1, 9)]
95105
]
96106

107+
IS_B200 = is_cuda() and get_nvidia_gpu_model() == "NVIDIA B200"
108+
97109

98110
@contextlib.contextmanager
99111
def set_env_variable(key, value):
@@ -350,6 +362,74 @@ def decompose_func(a_in, b_in):
350362
else:
351363
return lambda: compiled_decompose_k(a, b)
352364

365+
if IS_B200:
366+
367+
@register_benchmark(enabled=False)
def triton_blackwell_warpspec_persistent_matmul(self, a, b, bias) -> Callable:
    # Persistent TMA matmul with warp specialization turned on; bias is
    # added to the product when one is supplied.
    def _run():
        out = blackwell_matmul_tma_persistent(a, b, warp_specialize=True)
        return out if bias is None else out + bias

    return _run
378+
379+
@register_benchmark(enabled=False)
def triton_blackwell_persistent_matmul(self, a, b, bias) -> Callable:
    # Persistent TMA matmul without warp specialization; bias is added
    # to the product when one is supplied.
    def _run():
        out = blackwell_matmul_tma_persistent(a, b, warp_specialize=False)
        return out if bias is None else out + bias

    return _run
390+
391+
@register_benchmark(enabled=False)
def triton_blackwell_warpspec_tma_matmul(self, a, b, bias) -> Callable:
    # Non-persistent TMA matmul with warp specialization turned on;
    # bias is added to the product when one is supplied.
    def _run():
        out = blackwell_matmul_tma(a, b, warp_specialize=True)
        return out if bias is None else out + bias

    return _run
397+
398+
@register_benchmark(enabled=False)
def triton_blackwell_tma_matmul(self, a, b, bias) -> Callable:
    # Non-persistent TMA matmul without warp specialization; bias is
    # added to the product when one is supplied.
    def _run():
        out = blackwell_matmul_tma(a, b, warp_specialize=False)
        return out if bias is None else out + bias

    return _run
404+
405+
@register_benchmark(enabled=False)
def triton_blackwell_warpspec_descriptor_matmul(self, a, b, bias) -> Callable:
    # Descriptor-based persistent matmul with warp specialization turned
    # on; bias is added to the product when one is supplied.
    def _run():
        out = blackwell_matmul_descriptor_persistent(a, b, warp_specialize=True)
        return out if bias is None else out + bias

    return _run
418+
419+
@register_benchmark(enabled=False)
def triton_blackwell_descriptor_matmul(self, a, b, bias) -> Callable:
    # Descriptor-based persistent matmul without warp specialization;
    # bias is added to the product when one is supplied.
    def _run():
        out = blackwell_matmul_descriptor_persistent(a, b, warp_specialize=False)
        return out if bias is None else out + bias

    return _run
432+
353433
@register_x_val(label="(M, N, K)")
354434
def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
355435
# x-value: computation intensity

0 commit comments

Comments
 (0)