From c27dd08e2f4853df1dfd0e5a39d874d31111b509 Mon Sep 17 00:00:00 2001
From: Mandar Deshpande
Date: Thu, 15 May 2025 14:45:16 -0700
Subject: [PATCH] Allow TMA benchmarks for flex-attention kernel

Summary:
This diff adds a new `--use-tma` argument to `operator.py` in the
`flex_attention` directory of the `tritonbench` repository. The flag lets
users enable the Tensor Memory Accelerator (TMA) in the kernel options used
for flex-attention benchmarks.

**Changes:**

* Added a `--use-tma` flag to `parse_op_args` in `operator.py`
* Stored the parsed value as `self.use_tma` in the operator's `__init__`
* Changed `get_kernel_options` from a `@staticmethod` to an instance method
  so it can read `self.use_tma` and merge `"USE_TMA": True` into the kernel
  options when the flag is set

Differential Revision: D74839480
---
 .../operators/flex_attention/operator.py | 25 +++++++++++++++----
 1 file changed, 20 insertions(+), 5 deletions(-)

diff --git a/tritonbench/operators/flex_attention/operator.py b/tritonbench/operators/flex_attention/operator.py
index 0a83266b..0f7e3a83 100644
--- a/tritonbench/operators/flex_attention/operator.py
+++ b/tritonbench/operators/flex_attention/operator.py
@@ -88,6 +88,9 @@ def parse_op_args(args: List[str]):
         "--sliding-window-size", type=int, default=128, help="sliding window size"
     )
     parser.add_argument("--prefix-length", type=int, default=128, help="prefix length")
+    parser.add_argument(
+        "--use-tma", action="store_true", help="Whether to enable TMA in kernel options"
+    )
     return parser.parse_args(args)
 
 
@@ -109,6 +112,7 @@ def __init__(
         self.mod_type = args.mod_type
         self.sliding_window_size = args.sliding_window_size
         self.prefix_length = args.prefix_length
+        self.use_tma = args.use_tma
 
     def get_input_iter(self) -> Generator:
         """Generate a single input configuration for benchmarking."""
@@ -327,12 +331,15 @@ def get_full_shape(
                 "Inputs must be either two torch.Tensor objects or two shape tuples"
             )
 
-    @staticmethod
-    def get_kernel_options(attn_type: str, shape: FullShape):
+    def get_kernel_options(self, attn_type: str, shape: FullShape):
+        """
+        Get kernel options for the given attention type and shape.
+
+        Adds "USE_TMA" to the returned options when --use-tma is set.
+        """
         B, Hq, M, Hkv, N, D = shape
         is_decoding = M == 1
-        # TODO add ways to specify TMA and warp spec
-        # "ENABLE_TMA": True,
+        # Get base kernel options
         kernel_opt_training_dict = {
             "noop": None,
             "causal": None,
@@ -379,12 +386,20 @@ def get_default_split_k(B: int, H: int, Mk: int) -> int:
             "softcap": {"SPLIT_KV": get_default_split_k(B, Hkv, N) * 2},
         }
 
-        return (
+        # Get base options
+        base_options = (
             kernel_opt_decoding_dict[attn_type]
             if is_decoding
             else kernel_opt_training_dict[attn_type]
         )
 
+        # Enable TMA when --use-tma was passed and this attention type
+        # has explicit kernel options to extend.
+        if getattr(self, "use_tma", False) and base_options is not None:
+            base_options = {**base_options, "USE_TMA": True}
+
+        return base_options
+
     @staticmethod
     def generate_block_mask(
         attn_type: str, shape: FullShape, sliding_window_size: int, prefix_length: int
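
For context, below is a minimal sketch (not part of the patch) of how a kernel-options dict carrying `"USE_TMA": True` typically reaches PyTorch's flex_attention through its `kernel_options` parameter. The shapes, the causal `mask_mod`, and the assumed benchmark invocation `python run.py --op flex_attention --use-tma` are illustrative, not code from this diff; running it requires a CUDA build of PyTorch whose flex-attention kernels recognize the `USE_TMA` option. On Hopper-class GPUs, TMA offloads global-to-shared-memory tile transfers to dedicated hardware, which is why it is exposed here as an opt-in kernel option.

```python
# Illustrative sketch only -- assumes a CUDA-enabled PyTorch whose
# flex-attention Triton kernels understand the "USE_TMA" kernel option.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal(b, h, q_idx, kv_idx):
    # Standard causal mask_mod: a query may only attend to earlier positions.
    return q_idx >= kv_idx


B, H, S, D = 4, 8, 1024, 64  # hypothetical benchmark shape
q, k, v = (
    torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3)
)

# Broadcast the causal mask over batch and heads.
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S)

# A kernel-options dict of the kind get_kernel_options returns; with
# --use-tma the patch merges {"USE_TMA": True} into it.
kernel_options = {"USE_TMA": True}

compiled = torch.compile(flex_attention)
out = compiled(q, k, v, block_mask=block_mask, kernel_options=kernel_options)
print(out.shape)  # torch.Size([4, 8, 1024, 64])
```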