@@ -11,8 +11,18 @@
 from tritonbench.operators.gemm.kernels import matmul as kernels
 from tritonbench.operators.gemm.partition_k import matmul_partition_k
 from tritonbench.operators.gemm.stream_k import streamk_matmul
+from tritonbench.operators.gemm.warp_spec_persistent_matmul import (
+    blackwell_matmul_descriptor_persistent,
+    blackwell_matmul_tma,
+    blackwell_matmul_tma_persistent,
+)
 from tritonbench.utils.data_utils import get_production_shapes
-from tritonbench.utils.env_utils import is_cuda, is_fbcode, supports_tma
+from tritonbench.utils.env_utils import (
+    get_nvidia_gpu_model,
+    is_cuda,
+    is_fbcode,
+    supports_tma,
+)

 from tritonbench.utils.path_utils import REPO_PATH
@@ -94,6 +104,8 @@
     for k in [4096 * i for i in range(1, 9)]
 ]

+IS_B200 = is_cuda() and get_nvidia_gpu_model() == "NVIDIA B200"
+

 @contextlib.contextmanager
 def set_env_variable(key, value):
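`IS_B200` is evaluated once at import time and short-circuits on non-CUDA builds before the GPU model is queried. A minimal sketch of an equivalent gate, assuming `get_nvidia_gpu_model()` is a thin wrapper over `torch.cuda.get_device_name()` (an assumption; the tritonbench helper may normalize the name differently):

```python
import torch

def is_b200() -> bool:
    # Check CUDA availability first: querying the device name on a CPU-only
    # build would raise, which is why the diff guards with is_cuda().
    if not torch.cuda.is_available():
        return False
    # get_device_name() reports the marketing name of device 0, e.g. "NVIDIA B200".
    return torch.cuda.get_device_name() == "NVIDIA B200"
```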
@@ -350,6 +362,74 @@ def decompose_func(a_in, b_in):
         else:
             return lambda: compiled_decompose_k(a, b)

+    if IS_B200:
+
+        @register_benchmark(enabled=False)
+        def triton_blackwell_warpspec_persistent_matmul(self, a, b, bias) -> Callable:
+            if bias is not None:
+                return (
+                    lambda: blackwell_matmul_tma_persistent(a, b, warp_specialize=True)
+                    + bias
+                )
+            else:
+                return lambda: blackwell_matmul_tma_persistent(
+                    a, b, warp_specialize=True
+                )
+
+        @register_benchmark(enabled=False)
+        def triton_blackwell_persistent_matmul(self, a, b, bias) -> Callable:
+            if bias is not None:
+                return (
+                    lambda: blackwell_matmul_tma_persistent(a, b, warp_specialize=False)
+                    + bias
+                )
+            else:
+                return lambda: blackwell_matmul_tma_persistent(
+                    a, b, warp_specialize=False
+                )
+
+        @register_benchmark(enabled=False)
+        def triton_blackwell_warpspec_tma_matmul(self, a, b, bias) -> Callable:
+            if bias is not None:
+                return lambda: blackwell_matmul_tma(a, b, warp_specialize=True) + bias
+            else:
+                return lambda: blackwell_matmul_tma(a, b, warp_specialize=True)
+
+        @register_benchmark(enabled=False)
+        def triton_blackwell_tma_matmul(self, a, b, bias) -> Callable:
+            if bias is not None:
+                return lambda: blackwell_matmul_tma(a, b, warp_specialize=False) + bias
+            else:
+                return lambda: blackwell_matmul_tma(a, b, warp_specialize=False)
+
+        @register_benchmark(enabled=False)
+        def triton_blackwell_warpspec_descriptor_matmul(self, a, b, bias) -> Callable:
+            if bias is not None:
+                return (
+                    lambda: blackwell_matmul_descriptor_persistent(
+                        a, b, warp_specialize=True
+                    )
+                    + bias
+                )
+            else:
+                return lambda: blackwell_matmul_descriptor_persistent(
+                    a, b, warp_specialize=True
+                )
+
+        @register_benchmark(enabled=False)
+        def triton_blackwell_descriptor_matmul(self, a, b, bias) -> Callable:
+            if bias is not None:
+                return (
+                    lambda: blackwell_matmul_descriptor_persistent(
+                        a, b, warp_specialize=False
+                    )
+                    + bias
+                )
+            else:
+                return lambda: blackwell_matmul_descriptor_persistent(
+                    a, b, warp_specialize=False
+                )
+
     @register_x_val(label="(M, N, K)")
     def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
         # x-value: computation intensity
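All six variants follow one pattern: close over `a` and `b`, and fold the optional `bias` into the returned callable so the harness times the epilogue add together with the kernel. A hypothetical consolidation of that pattern — `_wrap_with_bias` is illustrative only, not part of this diff:

```python
from functools import partial
from typing import Callable, Optional

import torch

def _wrap_with_bias(
    kernel: Callable[[], torch.Tensor], bias: Optional[torch.Tensor]
) -> Callable[[], torch.Tensor]:
    # Keep the result zero-argument, matching the lambdas in the diff, so the
    # benchmark harness measures the kernel launch plus the optional bias add.
    if bias is None:
        return kernel
    return lambda: kernel() + bias

# Each benchmark body would then reduce to a single line, e.g.:
#     return _wrap_with_bias(partial(blackwell_matmul_tma, a, b, warp_specialize=False), bias)
```

Since every variant is registered with `enabled=False`, none of them run by default; they have to be selected explicitly (via tritonbench's usual benchmark filters, e.g. `--only`, if those conventions apply here).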