[WIP] 20250722 benchmark sweep #347

Draft · wants to merge 14 commits into main
210 changes: 155 additions & 55 deletions benchmarks/run.py
@@ -26,55 +26,64 @@
from typing import Callable

# Maps tritonbench op names to Helion kernel examples
KERNEL_MAPPINGS: dict[str, tuple[str, str, str]] = {
# <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
"vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
"embedding": (
"tritonbench.operators.embedding.operator",
"examples.embedding",
"embedding_tritonbench",
),
"vector_exp": (
"tritonbench.operators.vector_exp.operator",
"examples.exp",
"exp_tritonbench",
),
"rms_norm": (
"tritonbench.operators.rms_norm.operator",
"examples.rms_norm",
"rms_norm_tritonbench",
),
"sum": ("tritonbench.operators.sum.operator", "examples.sum", "sum_tritonbench"),
"softmax": (
"tritonbench.operators.softmax.operator",
"examples.softmax",
"softmax",
),
"jagged_mean": (
"tritonbench.operators.jagged_mean.operator",
"examples.jagged_mean",
"jagged_mean_tritonbench",
# Structure: {tritonbench_op_name: (tritonbench_module, helion_module, helion_func) or (tritonbench_module, [(helion_module, helion_func), ...])}
KERNEL_MAPPINGS: dict[str, tuple[str, str, str] | tuple[str, list[tuple[str, str]]]] = {
# Single kernel mapping: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
# "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
# "embedding": (
# "tritonbench.operators.embedding.operator",
# "examples.embedding",
# "embedding_tritonbench",
# ),
# "vector_exp": (
# "tritonbench.operators.vector_exp.operator",
# "examples.exp",
# "exp_tritonbench",
# ),
# "rms_norm": (
# "tritonbench.operators.rms_norm.operator",
# "examples.rms_norm",
# "rms_norm_tritonbench",
# ),
# "sum": ("tritonbench.operators.sum.operator", "examples.sum", "sum_tritonbench"),
# "softmax": (
# "tritonbench.operators.softmax.operator",
# "examples.softmax",
# "softmax",
# ),
# "cross_entropy": (
# "tritonbench.operators.cross_entropy.operator",
# "examples.cross_entropy",
# "cross_entropy",
# ),
# "jagged_mean": (
# "tritonbench.operators.jagged_mean.operator",
# "examples.jagged_mean",
# "jagged_mean_tritonbench",
# ),
# Multiple kernel mappings: (<tritonbench_module_path>, [(<helion_module>, <helion_func>), ...])
"gemm": (
"tritonbench.operators.gemm.operator",
[
("examples.matmul", "matmul"),
("examples.matmul_split_k", "matmul_split_k"),
],
),
"fp8_gemm": (
"tritonbench.operators.fp8_gemm.fp8_gemm",
"examples.fp8_gemm",
"fp8_gemm_tritonbench",
),
"flash_attention": (
"tritonbench.operators.flash_attention.operator",
"examples.attention",
"attention",
),
"cross_entropy": (
"tritonbench.operators.cross_entropy.operator",
"examples.cross_entropy",
"cross_entropy",
),
"fp8_attention": (
"tritonbench.operators.fp8_attention.operator",
"examples.fp8_attention",
"fp8_attention_tritonbench",
),
# "flash_attention": (
# "tritonbench.operators.flash_attention.operator",
# "examples.attention",
# "attention",
# ),
# "fp8_attention": (
# "tritonbench.operators.fp8_attention.operator",
# "examples.fp8_attention",
# "fp8_attention_tritonbench",
# ),
}


@@ -195,7 +204,11 @@ def check_and_setup_tritonbench() -> None:
sys.exit(1)


def run_kernel(kernel_name: str, tritonbench_args: list[str]) -> None:
def run_kernel(
kernel_name: str,
tritonbench_args: list[str],
input_shard_info: tuple[int, int] | None = None,
) -> None:
"""Run a single kernel benchmark."""
# Check if kernel is in the mapping table
if kernel_name not in KERNEL_MAPPINGS:
@@ -204,8 +217,30 @@ def run_kernel(kernel_name: str, tritonbench_args: list[str]) -> None:
f"Available kernels: {', '.join(KERNEL_MAPPINGS.keys())}", file=sys.stderr
)
sys.exit(1)

mapping = KERNEL_MAPPINGS[kernel_name]

# Check if it's multiple variants or a single kernel
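# For example, "gemm" in KERNEL_MAPPINGS maps one tritonbench operator to two
# Helion variants (examples.matmul and examples.matmul_split_k), each of which
# is registered and benchmarked in turn.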
if len(mapping) == 2 and isinstance(mapping[1], list):
# Multiple variants with shared tritonbench module
tritonbench_module = mapping[0]
variants = mapping[1]
for i, (module_path, func_name) in enumerate(variants):
# Use the kernel function name as the variant name for display
variant_name = func_name
if i > 0:
print(f"\n{'=' * 60}", file=sys.stderr)
print(f"Kernel: {kernel_name} (variant: {variant_name})", file=sys.stderr)
print(f"{'=' * 60}\n", file=sys.stderr)
run_single_kernel_variant(
    kernel_name,
    tritonbench_module,
    module_path,
    func_name,
    tritonbench_args.copy(),
    variant_name,
    input_shard_info,
)
else:
# Single kernel with full mapping
tritonbench_module, module_path, func_name = mapping
run_single_kernel_variant(
    kernel_name,
    tritonbench_module,
    module_path,
    func_name,
    tritonbench_args,
    None,
    input_shard_info,
)


tritonbench_module, module_path, func_name = KERNEL_MAPPINGS[kernel_name]
def run_single_kernel_variant(
    kernel_name: str,
    tritonbench_module: str,
    module_path: str,
    func_name: str,
    tritonbench_args: list[str],
    variant_name: str | None = None,
    input_shard_info: tuple[int, int] | None = None,
) -> None:
"""Run a single kernel variant."""

# Import from the mapped module
try:
@@ -305,7 +340,10 @@ def _inner() -> Callable[..., Any] | object:
return _inner

# Method name for the benchmark
helion_method_name = f"helion_{kernel_name}"
if variant_name:
helion_method_name = f"helion_{kernel_name}_{variant_name}"
else:
helion_method_name = f"helion_{kernel_name}"

# Import register_benchmark API
from tritonbench.utils.triton_op import ( # pyright: ignore[reportMissingImports]
@@ -325,14 +363,50 @@ def _inner() -> Callable[..., Any] | object:
# Set the decorated method on the Operator class
setattr(Operator, helion_method_name, decorated_method)

print(
f"Running {operator_name} benchmark with Helion implementation...\n",
file=sys.stderr,
)
if variant_name:
print(
f"Running {operator_name} benchmark with Helion implementation (variant: {variant_name})...\n",
file=sys.stderr,
)
else:
print(
f"Running {operator_name} benchmark with Helion implementation...\n",
file=sys.stderr,
)

# Create and run the operator with unknown args
op = Operator(tb_args=tb_args, extra_args=unknown_args)

# Handle input sharding if requested
if input_shard_info:
shard_idx, total_shards = input_shard_info

# Get the actual number of inputs for this operator
total_inputs = op._available_num_inputs

# Calculate shard boundaries
inputs_per_shard = total_inputs // total_shards
extra_inputs = total_inputs % total_shards
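# Illustrative example: 10 inputs across 3 shards gives inputs_per_shard = 3 and
# extra_inputs = 1, so shard 1 covers inputs 0-3 (4 inputs) while shards 2 and 3
# cover inputs 4-6 and 7-9 (3 inputs each); the remainder goes to the lowest-numbered shards.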

if shard_idx <= extra_inputs:
start_idx = (shard_idx - 1) * (inputs_per_shard + 1)
shard_size = inputs_per_shard + 1
else:
start_idx = (
extra_inputs * (inputs_per_shard + 1)
+ (shard_idx - 1 - extra_inputs) * inputs_per_shard
)
shard_size = inputs_per_shard

# Override the operator's input range
op._input_id = start_idx
op._num_inputs = shard_size

print(
f"Running input shard {shard_idx}/{total_shards}: inputs {start_idx} to {start_idx + shard_size - 1} (of {total_inputs} total)",
file=sys.stderr,
)

# Run with proper parameters
warmup = int(getattr(tb_args, "warmup", 25))
rep = int(getattr(tb_args, "iter", 100))
@@ -359,13 +433,37 @@ def main() -> None:
type=str,
help="Name(s) of the Helion kernel module(s) to run. Can be a single kernel or comma-separated list (e.g., vector_add or vector_add,rms_norm). If not specified, runs all kernels.",
)
parser.add_argument(
"--input-shard",
type=str,
help="Run only a subset of inputs for each kernel. Format: M/N where M is the shard number (1-indexed) and N is the total number of shards. For example, --input-shard 1/3 runs the first third of inputs for each kernel.",
)
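# Example invocation (illustrative): run only the gemm kernel on the second of
# four input shards:
#   python benchmarks/run.py --kernel gemm --input-shard 2/4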

# Parse known args to get the kernel name, pass rest to tritonbench
args, tritonbench_args = parser.parse_known_args()

# Check and setup tritonbench if needed
check_and_setup_tritonbench()

# Store input-shard info for later processing
input_shard_info = None
if args.input_shard:
try:
shard_idx, total_shards = map(int, args.input_shard.split("/"))
if shard_idx < 1 or shard_idx > total_shards:
print(
f"Error: Shard number {shard_idx} must be between 1 and {total_shards}",
file=sys.stderr,
)
sys.exit(1)
input_shard_info = (shard_idx, total_shards)
except ValueError:
print(
f"Error: Invalid input-shard format '{args.input_shard}'. Expected format: M/N (e.g., 1/3)",
file=sys.stderr,
)
sys.exit(1)

if args.kernel:
# Parse comma-separated kernel names
kernel_names = [k.strip() for k in args.kernel.split(",")]
@@ -385,7 +483,7 @@ def main() -> None:

# Run specified kernels
if len(kernel_names) == 1:
run_kernel(kernel_names[0], tritonbench_args)
run_kernel(kernel_names[0], tritonbench_args, input_shard_info)
else:
print(
f"Running {len(kernel_names)} kernels: {', '.join(kernel_names)}...\n",
Expand All @@ -395,15 +493,17 @@ def main() -> None:
print(f"\n{'=' * 60}", file=sys.stderr)
print(f"Kernel: {kernel_name}", file=sys.stderr)
print(f"{'=' * 60}\n", file=sys.stderr)
run_kernel(kernel_name, tritonbench_args.copy())
run_kernel(kernel_name, tritonbench_args.copy(), input_shard_info)
else:
# Run all kernels
print(f"Running all {len(KERNEL_MAPPINGS)} kernels...\n", file=sys.stderr)
for kernel_name in KERNEL_MAPPINGS:
all_kernels = list(KERNEL_MAPPINGS.keys())

print(f"Running {len(all_kernels)} kernels...\n", file=sys.stderr)
for kernel_name in all_kernels:
print(f"\n{'=' * 60}", file=sys.stderr)
print(f"Kernel: {kernel_name}", file=sys.stderr)
print(f"{'=' * 60}\n", file=sys.stderr)
run_kernel(kernel_name, tritonbench_args.copy())
run_kernel(kernel_name, tritonbench_args.copy(), input_shard_info)


if __name__ == "__main__":
4 changes: 4 additions & 0 deletions benchmarks/run_input_shard.sh
@@ -0,0 +1,4 @@
[[ -z "$RANK_OFFSET" ]] && { echo "Error: RANK_OFFSET is not set"; exit 1; }
[[ -z "$SHARD" ]] && { echo "Error: SHARD is not set"; exit 1; }
[[ -z "$WORLD_SIZE" ]] && { echo "Error: WORLD_SIZE is not set"; exit 1; }
CUDA_VISIBLE_DEVICES=$((RANK_OFFSET+SHARD)) python benchmarks/run.py --input-shard $((SHARD+1))/${WORLD_SIZE} >benchmarks_autotune_$(date +%s)_input_shard_$((SHARD+1))_of_${WORLD_SIZE}.txt 2>&1
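# Illustrative usage (assumes 8 GPUs starting at device 0): launch one shard per GPU
# in parallel, then wait for all of them to finish.
#   for SHARD in $(seq 0 7); do
#     RANK_OFFSET=0 SHARD=$SHARD WORLD_SIZE=8 bash benchmarks/run_input_shard.sh &
#   done
#   wait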