Skip to content

Commit 7aa722b

Browse files
authored
[Benchmark] Add initial TritonBench integration and vector_add benchmark example (#247)
ghstack-source-id: 8e1c5ae Pull-Request: #246
1 parent 902741b commit 7aa722b

File tree

3 files changed

+287
-0
lines changed

3 files changed

+287
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,4 @@ venv
8888
CLAUDE.md
8989
triton
9090
torch
91+
benchmark/tritonbench

benchmark/README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
## Benchmarking
2+
3+
Performance comparison between Helion, torch.compile, Triton, and PyTorch eager is done by leveraging [TritonBench](https://github.com/pytorch-labs/tritonbench).
4+
5+
Currently supported kernels for performance comparison are in `benchmark/`.
6+
7+
To run the benchmark:
8+
9+
`$ python benchmark/run.py --metrics speedup,accuracy --kernel <kernel_name>`
10+
11+
e.g. for `vector_add` kernel:
12+
13+
`$ python benchmark/run.py --metrics speedup,accuracy --kernel vector_add`

benchmark/run.py

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
"""Performance comparison between Helion, torch.compile, Triton, and PyTorch eager by leveraging TritonBench.
2+
3+
Currently supported kernels are in `benchmark/`.
4+
5+
Usage:
6+
$ python benchmark/run.py [tritonbench args...] --kernel <kernel_name>
7+
8+
Example usage:
9+
$ python benchmark/run.py --metrics speedup,accuracy --kernel vector_add
10+
"""
11+
12+
from __future__ import annotations
13+
14+
import argparse
15+
import importlib
16+
from pathlib import Path
17+
import subprocess
18+
import sys
19+
from typing import Any
20+
from typing import Callable
21+
22+
# Maps tritonbench op names to Helion kernel examples
23+
KERNEL_MAPPINGS: dict[str, tuple[str, str]] = {
24+
# <tritonbench_op_name>: (<helion_kernel_module_path>, <helion_kernel_function_name>)
25+
"vector_add": ("examples.add", "add"),
26+
}
27+
28+
29+
def get_system_memory_gb() -> float:
30+
"""Get system memory in GB."""
31+
try:
32+
# Try to read from /proc/meminfo on Linux
33+
meminfo_path = Path("/proc/meminfo")
34+
if meminfo_path.exists():
35+
with open(meminfo_path) as f:
36+
for line in f:
37+
if line.startswith("MemTotal:"):
38+
# Extract memory in kB and convert to GB
39+
mem_kb = int(line.split()[1])
40+
return mem_kb / (1024 * 1024)
41+
42+
# Fallback: use psutil if available
43+
try:
44+
import psutil
45+
46+
return psutil.virtual_memory().total / (1024**3)
47+
except ImportError:
48+
pass
49+
50+
except Exception:
51+
pass
52+
53+
# Default to assuming high memory if we can't detect
54+
return 32.0
55+
56+
57+
def check_and_setup_tritonbench() -> None:
58+
"""Check if tritonbench is installed and install it from GitHub if not."""
59+
# Check if tritonbench is already installed
60+
try:
61+
import tritonbench
62+
63+
return # Already installed
64+
except ImportError:
65+
pass
66+
67+
print("Tritonbench not found. Installing...", file=sys.stderr)
68+
69+
# Clone to benchmark/tritonbench
70+
benchmark_dir = Path(__file__).parent
71+
tritonbench_path = benchmark_dir / "tritonbench"
72+
73+
try:
74+
# Clone the repository if it doesn't exist
75+
if not tritonbench_path.exists():
76+
print("Cloning tritonbench repository...", file=sys.stderr)
77+
subprocess.run(
78+
[
79+
"git",
80+
"clone",
81+
"https://github.com/pytorch-labs/tritonbench.git",
82+
str(tritonbench_path),
83+
],
84+
check=True,
85+
)
86+
87+
# Initialize submodules
88+
print("Initializing tritonbench's submodules...", file=sys.stderr)
89+
subprocess.run(
90+
["git", "submodule", "update", "--init", "--recursive"],
91+
cwd=tritonbench_path,
92+
check=True,
93+
)
94+
95+
# Detect system memory and choose install flags.
96+
# Low-memory systems can freeze when building dependencies like flash-attn,
97+
# so we only install the Liger library in that case.
98+
memory_gb = get_system_memory_gb()
99+
install_flag = "--liger" if memory_gb < 16 else "--all"
100+
101+
# Install optional dependencies for tritonbench
102+
print(
103+
f"Running install.py {install_flag} (detected {memory_gb:.1f}GB system RAM)...",
104+
file=sys.stderr,
105+
)
106+
subprocess.run(
107+
[sys.executable, "install.py", install_flag],
108+
cwd=tritonbench_path,
109+
check=True,
110+
)
111+
112+
# Install tritonbench package
113+
print("Installing tritonbench package...", file=sys.stderr)
114+
subprocess.run(
115+
[sys.executable, "-m", "pip", "install", "-e", str(tritonbench_path)],
116+
check=True,
117+
)
118+
119+
# Invalidate import caches to recognize newly installed package
120+
importlib.invalidate_caches()
121+
122+
# Verify installation worked
123+
try:
124+
import tritonbench # noqa: F401
125+
126+
print(
127+
f"Tritonbench installed successfully with {install_flag}.",
128+
file=sys.stderr,
129+
)
130+
except ImportError:
131+
print(
132+
"Error: Tritonbench package installation failed. The package cannot be imported.",
133+
file=sys.stderr,
134+
)
135+
sys.exit(1)
136+
137+
except subprocess.CalledProcessError as e:
138+
print(f"Error installing tritonbench: {e}", file=sys.stderr)
139+
if e.stdout:
140+
print(f"stdout: {e.stdout}", file=sys.stderr)
141+
if e.stderr:
142+
print(f"stderr: {e.stderr}", file=sys.stderr)
143+
sys.exit(1)
144+
145+
146+
def main() -> None:
147+
# Parse command line arguments
148+
parser = argparse.ArgumentParser(description="Run Helion kernels with tritonbench")
149+
parser.add_argument(
150+
"--kernel",
151+
type=str,
152+
required=True,
153+
help="Name of the Helion kernel module (e.g., vector_add)",
154+
)
155+
156+
# Parse known args to get the kernel name, pass rest to tritonbench
157+
args, tritonbench_args = parser.parse_known_args()
158+
159+
# Check and setup tritonbench if needed
160+
check_and_setup_tritonbench()
161+
162+
kernel_name = args.kernel
163+
164+
# Check if kernel is in the mapping table
165+
assert kernel_name in KERNEL_MAPPINGS
166+
module_path, func_name = KERNEL_MAPPINGS[kernel_name]
167+
# Import from the mapped module
168+
try:
169+
module = importlib.import_module(module_path)
170+
if not hasattr(module, func_name):
171+
print(
172+
f"Error: Module '{module_path}' does not have a function named '{func_name}'",
173+
file=sys.stderr,
174+
)
175+
sys.exit(1)
176+
kernel_func = getattr(module, func_name)
177+
except ImportError as e:
178+
print(
179+
f"Error: Could not import {func_name} from {module_path}", file=sys.stderr
180+
)
181+
print(f"Import error: {e}", file=sys.stderr)
182+
sys.exit(1)
183+
return
184+
185+
# Import tritonbench components
186+
try:
187+
from tritonbench.utils.parser import get_parser # pyre-ignore[21]
188+
except ImportError:
189+
print(
190+
"Error: Could not import tritonbench. Make sure it's in the path.",
191+
file=sys.stderr,
192+
)
193+
sys.exit(1)
194+
195+
# Get the tritonbench operator name (assume it's the same as the kernel name)
196+
operator_name = kernel_name
197+
198+
# Parse tritonbench arguments
199+
tb_parser = get_parser()
200+
201+
assert "--op" not in tritonbench_args
202+
tritonbench_args = ["--op", operator_name, *tritonbench_args]
203+
204+
tb_args = tb_parser.parse_args(tritonbench_args)
205+
206+
# Register the Helion kernel with tritonbench BEFORE importing the operator
207+
from tritonbench.utils.triton_op import ( # pyre-ignore[21]
208+
register_benchmark_mannually,
209+
)
210+
211+
# Create the benchmark method
212+
def create_helion_method( # pyre-ignore[3]
213+
kernel_func: Callable[..., Any], # pyre-ignore[2]
214+
) -> Callable[..., Any]:
215+
def helion_method( # pyre-ignore[3]
216+
self: Any, # pyre-ignore[2]
217+
*args: Any,
218+
) -> Callable[..., Any]:
219+
"""Helion implementation."""
220+
221+
def _inner() -> Callable[..., Any]: # pyre-ignore[3]
222+
return kernel_func(*args)
223+
224+
return _inner
225+
226+
return helion_method
227+
228+
# Register it as a benchmark first
229+
helion_method_name = f"helion_{kernel_name}"
230+
register_benchmark_mannually(
231+
operator_name=operator_name,
232+
func_name=helion_method_name,
233+
baseline=False,
234+
enabled=True,
235+
label=helion_method_name,
236+
)
237+
238+
# Import and run the operator
239+
operator_module_name = f"tritonbench.operators.{operator_name}.operator"
240+
try:
241+
operator_module = importlib.import_module(operator_module_name)
242+
Operator = operator_module.Operator
243+
except ImportError:
244+
print(
245+
f"Error: Could not import operator '{operator_name}' from tritonbench",
246+
file=sys.stderr,
247+
)
248+
sys.exit(1)
249+
return
250+
251+
# Monkey-patch the Operator class after import
252+
setattr(Operator, helion_method_name, create_helion_method(kernel_func))
253+
254+
print(
255+
f"Running {operator_name} benchmark with Helion implementation...\n",
256+
file=sys.stderr,
257+
)
258+
259+
# Create and run the operator
260+
op = Operator(tb_args=tb_args, extra_args={})
261+
262+
# Run with proper parameters
263+
warmup = getattr(tb_args, "warmup", 25)
264+
rep = getattr(tb_args, "iter", 100)
265+
op.run(warmup=warmup, rep=rep)
266+
267+
# Print results
268+
print("\nBenchmark Results:", file=sys.stderr)
269+
print(op.output, file=sys.stderr)
270+
271+
272+
if __name__ == "__main__":
273+
main()

0 commit comments

Comments
 (0)