diff --git a/tritonbench/operators/flex_attention/operator.py b/tritonbench/operators/flex_attention/operator.py
index 0a83266b..0f7e3a83 100644
--- a/tritonbench/operators/flex_attention/operator.py
+++ b/tritonbench/operators/flex_attention/operator.py
@@ -88,6 +88,9 @@ def parse_op_args(args: List[str]):
         "--sliding-window-size", type=int, default=128, help="sliding window size"
     )
     parser.add_argument("--prefix-length", type=int, default=128, help="prefix length")
+    parser.add_argument(
+        "--use-tma", action="store_true", help="Whether to enable TMA in kernel options"
+    )
     return parser.parse_args(args)


@@ -109,6 +112,7 @@ def __init__(
         self.mod_type = args.mod_type
         self.sliding_window_size = args.sliding_window_size
         self.prefix_length = args.prefix_length
+        self.use_tma = args.use_tma

     def get_input_iter(self) -> Generator:
         """Generate a single input configuration for benchmarking."""
@@ -327,12 +331,15 @@ def get_full_shape(
             "Inputs must be either two torch.Tensor objects or two shape tuples"
         )

-    @staticmethod
-    def get_kernel_options(attn_type: str, shape: FullShape):
+    @classmethod
+    def get_kernel_options(cls, attn_type: str, shape: FullShape):
+        """
+        Get kernel options for the given attention type and shape.
+        This method can be called as either a class method or an instance method.
+        """
         B, Hq, M, Hkv, N, D = shape
         is_decoding = M == 1
-        # TODO add ways to specify TMA and warp spec
-        # "ENABLE_TMA": True,
+        # Get base kernel options
         kernel_opt_training_dict = {
             "noop": None,
             "causal": None,
@@ -379,12 +386,20 @@ def get_default_split_k(B: int, H: int, Mk: int) -> int:
             "softcap": {"SPLIT_KV": get_default_split_k(B, Hkv, N) * 2},
         }

-        return (
+        # Get base options
+        base_options = (
             kernel_opt_decoding_dict[attn_type]
             if is_decoding
             else kernel_opt_training_dict[attn_type]
         )
+        # Add USE_TMA if enabled and base_options exists
+        # Check if this is being called from an instance or the class
+        if hasattr(cls, "use_tma") and cls.use_tma and base_options is not None:
+            base_options = {**base_options, "USE_TMA": True}
+
+        return base_options
+

     @staticmethod
     def generate_block_mask(
         attn_type: str, shape: FullShape, sliding_window_size: int, prefix_length: int
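For reference, the sketch below isolates the conditional dict-merge pattern the patch applies in `get_kernel_options`: the selected base option dict is extended with `"USE_TMA": True` only when the flag is set and a base dict exists, while a `None` entry (meaning "no kernel options") passes through unchanged. This is a standalone illustration under those assumptions; the `merge_kernel_options` helper is hypothetical and not part of tritonbench.

```python
# Standalone sketch (not part of the patch): the conditional option-merge
# pattern used by get_kernel_options after this change. The helper name is
# a hypothetical illustration, not a tritonbench API.
from typing import Dict, Optional


def merge_kernel_options(
    base_options: Optional[Dict[str, object]], use_tma: bool
) -> Optional[Dict[str, object]]:
    """Extend the selected kernel-option dict with USE_TMA when requested.

    A None base means "no kernel options" and is passed through unchanged,
    mirroring the `base_options is not None` guard in the diff.
    """
    if use_tma and base_options is not None:
        # Dict unpacking builds a new dict, so the shared defaults are not mutated.
        return {**base_options, "USE_TMA": True}
    return base_options


if __name__ == "__main__":
    print(merge_kernel_options({"SPLIT_KV": 4}, use_tma=True))  # {'SPLIT_KV': 4, 'USE_TMA': True}
    print(merge_kernel_options(None, use_tma=True))             # None
```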