
Allow TMA benchmarks for flex-attention kernel #225

Open · wants to merge 1 commit into base: main
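Summary: this patch adds a --use-tma flag to the flex_attention operator's argument parser, stores it on the operator in __init__, and has get_kernel_options (now a classmethod) merge "USE_TMA": True into the kernel options it returns when the flag is set and base options exist. A minimal sketch of the flag parsing, assuming tritonbench is installed under the module path shown in the diff and that all other operator flags have defaults (as the visible ones do); a sketch of how the merged options can be consumed follows after the diff:

    # Sketch only: not part of this PR, just exercises the new flag.
    from tritonbench.operators.flex_attention.operator import parse_op_args

    args = parse_op_args([])
    assert args.use_tma is False          # store_true flag defaults to False

    args = parse_op_args(["--use-tma"])
    assert args.use_tma is True           # flag present -> TMA requested in kernel options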
25 changes: 20 additions & 5 deletions tritonbench/operators/flex_attention/operator.py
@@ -88,6 +88,9 @@ def parse_op_args(args: List[str]):
"--sliding-window-size", type=int, default=128, help="sliding window size"
)
parser.add_argument("--prefix-length", type=int, default=128, help="prefix length")
parser.add_argument(
"--use-tma", action="store_true", help="Whether to enable TMA in kernel options"
)
return parser.parse_args(args)


@@ -109,6 +112,7 @@ def __init__(
self.mod_type = args.mod_type
self.sliding_window_size = args.sliding_window_size
self.prefix_length = args.prefix_length
self.use_tma = args.use_tma

def get_input_iter(self) -> Generator:
"""Generate a single input configuration for benchmarking."""
@@ -327,12 +331,15 @@ def get_full_shape(
"Inputs must be either two torch.Tensor objects or two shape tuples"
)

@staticmethod
def get_kernel_options(attn_type: str, shape: FullShape):
@classmethod
def get_kernel_options(cls, attn_type: str, shape: FullShape):
"""
Get kernel options for the given attention type and shape.
This method can be called as either a class method or an instance method.
"""
B, Hq, M, Hkv, N, D = shape
is_decoding = M == 1
# TODO add ways to specify TMA and warp spec
# "ENABLE_TMA": True,
# Get base kernel options
kernel_opt_training_dict = {
"noop": None,
"causal": None,
@@ -379,12 +386,20 @@ def get_default_split_k(B: int, H: int, Mk: int) -> int:
"softcap": {"SPLIT_KV": get_default_split_k(B, Hkv, N) * 2},
}

return (
# Get base options
base_options = (
kernel_opt_decoding_dict[attn_type]
if is_decoding
else kernel_opt_training_dict[attn_type]
)

# Add USE_TMA if enabled and base_options exists
# Check if this is being called from an instance or the class
if hasattr(cls, "use_tma") and cls.use_tma and base_options is not None:
base_options = {**base_options, "USE_TMA": True}

return base_options

@staticmethod
def generate_block_mask(
attn_type: str, shape: FullShape, sliding_window_size: int, prefix_length: int
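For context on how the dictionary returned by get_kernel_options is consumed: because the method is now a classmethod, it can be invoked either on the class or on an operator instance, and torch's flex_attention accepts a kernel_options dict that such a result can be forwarded to. A rough, illustrative sketch under those assumptions; the shapes, dtype, and the compile call are arbitrary and not taken from this file, and "USE_TMA" is shown hard-coded where the operator would merge it in only when constructed with --use-tma:

    # Sketch only: assumes a CUDA build of PyTorch that ships
    # torch.nn.attention.flex_attention (2.5+) and a CUDA device.
    import torch
    from torch.nn.attention.flex_attention import flex_attention

    B, H, S, D = 4, 8, 1024, 64  # arbitrary benchmark-like shapes
    q, k, v = (
        torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
        for _ in range(3)
    )

    # Stand-in for the dict produced by get_kernel_options with --use-tma set.
    kernel_options = {"USE_TMA": True}

    compiled = torch.compile(flex_attention)
    out = compiled(q, k, v, kernel_options=kernel_options)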