Skip to content

Commit a11c0a6

Browse files
committed
[WIP] Improve autotune infra to catch more error cases
1 parent 41fe6e9 commit a11c0a6

File tree

4 files changed

+107
-12
lines changed

4 files changed

+107
-12
lines changed

benchmarks/run.py

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,13 +359,37 @@ def main() -> None:
359359
type=str,
360360
help="Name(s) of the Helion kernel module(s) to run. Can be a single kernel or comma-separated list (e.g., vector_add or vector_add,rms_norm). If not specified, runs all kernels.",
361361
)
362+
parser.add_argument(
363+
"--split",
364+
type=str,
365+
help="Run only a subset of kernels. Format: M/N where M is the part number (1-indexed) and N is the total number of parts. For example, --split 1/3 runs the first third of kernels.",
366+
)
362367

363368
# Parse known args to get the kernel name, pass rest to tritonbench
364369
args, tritonbench_args = parser.parse_known_args()
365370

366371
# Check and setup tritonbench if needed
367372
check_and_setup_tritonbench()
368373

374+
# Parse split argument if provided
375+
part_num = None
376+
total_parts = None
377+
if args.split:
378+
try:
379+
part_num, total_parts = map(int, args.split.split("/"))
380+
if part_num < 1 or part_num > total_parts:
381+
print(
382+
f"Error: Part number {part_num} must be between 1 and {total_parts}",
383+
file=sys.stderr,
384+
)
385+
sys.exit(1)
386+
except ValueError:
387+
print(
388+
f"Error: Invalid split format '{args.split}'. Expected format: M/N (e.g., 1/3)",
389+
file=sys.stderr,
390+
)
391+
sys.exit(1)
392+
369393
if args.kernel:
370394
# Parse comma-separated kernel names
371395
kernel_names = [k.strip() for k in args.kernel.split(",")]
@@ -383,6 +407,31 @@ def main() -> None:
383407
)
384408
sys.exit(1)
385409

410+
# Apply split filtering if specified
411+
if args.split:
412+
# Calculate which kernels belong to this part
413+
kernels_per_part = len(kernel_names) // total_parts
414+
remainder = len(kernel_names) % total_parts
415+
416+
# Calculate start and end indices for this part
417+
if part_num <= remainder:
418+
# Parts 1 to remainder get one extra kernel
419+
start_idx = (part_num - 1) * (kernels_per_part + 1)
420+
end_idx = start_idx + kernels_per_part + 1
421+
else:
422+
# Remaining parts get the base number of kernels
423+
start_idx = (
424+
remainder * (kernels_per_part + 1)
425+
+ (part_num - remainder - 1) * kernels_per_part
426+
)
427+
end_idx = start_idx + kernels_per_part
428+
429+
kernel_names = kernel_names[start_idx:end_idx]
430+
print(
431+
f"Running part {part_num}/{total_parts}: kernels {start_idx + 1} to {end_idx} of total",
432+
file=sys.stderr,
433+
)
434+
386435
# Run specified kernels
387436
if len(kernel_names) == 1:
388437
run_kernel(kernel_names[0], tritonbench_args)
@@ -398,8 +447,35 @@ def main() -> None:
398447
run_kernel(kernel_name, tritonbench_args.copy())
399448
else:
400449
# Run all kernels
401-
print(f"Running all {len(KERNEL_MAPPINGS)} kernels...\n", file=sys.stderr)
402-
for kernel_name in KERNEL_MAPPINGS:
450+
all_kernels = list(KERNEL_MAPPINGS.keys())
451+
452+
# Apply split filtering if specified
453+
if args.split:
454+
# Calculate which kernels belong to this part
455+
kernels_per_part = len(all_kernels) // total_parts
456+
remainder = len(all_kernels) % total_parts
457+
458+
# Calculate start and end indices for this part
459+
if part_num <= remainder:
460+
# Parts 1 to remainder get one extra kernel
461+
start_idx = (part_num - 1) * (kernels_per_part + 1)
462+
end_idx = start_idx + kernels_per_part + 1
463+
else:
464+
# Remaining parts get the base number of kernels
465+
start_idx = (
466+
remainder * (kernels_per_part + 1)
467+
+ (part_num - remainder - 1) * kernels_per_part
468+
)
469+
end_idx = start_idx + kernels_per_part
470+
471+
all_kernels = all_kernels[start_idx:end_idx]
472+
print(
473+
f"Running part {part_num}/{total_parts}: kernels {start_idx + 1} to {end_idx} of {len(KERNEL_MAPPINGS)} total",
474+
file=sys.stderr,
475+
)
476+
477+
print(f"Running {len(all_kernels)} kernels...\n", file=sys.stderr)
478+
for kernel_name in all_kernels:
403479
print(f"\n{'=' * 60}", file=sys.stderr)
404480
print(f"Kernel: {kernel_name}", file=sys.stderr)
405481
print(f"{'=' * 60}\n", file=sys.stderr)

helion/_compiler/tile_dispatch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def _add_reduction_strategies(self, fn: DeviceFunction, config: Config) -> None:
9494
reduction_loop = env.config_spec.reduction_loops.config_get(
9595
config.reduction_loops, block_id, None
9696
)
97-
if reduction_loop is None:
97+
if reduction_loop is None or reduction_loop <= 1:
9898
strategy: TileStrategy = PersistentReductionStrategy(fn, block_id)
9999
else:
100100
strategy = LoopedReductionStrategy(fn, block_id, reduction_loop)

helion/autotuner/base_search.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from torch._inductor.runtime.triton_compat import OutOfResources
2323
from torch._inductor.runtime.triton_compat import PTXASError
24+
from triton.compiler.errors import CompilationError
2425
import torch.multiprocessing as mp
2526
from triton.testing import do_bench
2627

@@ -43,7 +44,12 @@
4344
from . import ConfigSpec
4445

4546
_expected_errors_regexp: re.Pattern[str] = re.compile(
46-
r"|".join(map(re.escape, ["[CUDA]: invalid argument"]))
47+
r"|".join(
48+
map(
49+
re.escape,
50+
["[CUDA]: invalid argument", "exceeds triton maximum tensor numel"],
51+
)
52+
)
4753
)
4854

4955

@@ -88,10 +94,13 @@ def benchmark(self, config: Config) -> float:
8894
Returns:
8995
The performance of the configuration in seconds.
9096
"""
91-
fn = self.kernel.compile_config(config, allow_print=False)
92-
if self.start_precompile_and_check_for_hangs(config, fn)():
93-
return self.benchmark_function(config, fn)
94-
return inf
97+
try:
98+
fn = self.kernel.compile_config(config, allow_print=False)
99+
if self.start_precompile_and_check_for_hangs(config, fn)():
100+
return self.benchmark_function(config, fn)
101+
return inf
102+
except Exception as e:
103+
return inf
95104

96105
def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
97106
"""
@@ -125,8 +134,10 @@ def benchmark_function(self, config: Config, fn: CompiledConfig) -> float:
125134
self.log.debug("Benchmarking failed: OutOfResources")
126135
except PTXASError:
127136
self.log.warning(f"PTXASError compiling config: {config}")
137+
except CompilationError:
138+
self.log.debug("Benchmarking failed: Triton CompilationError")
128139
except Exception as e:
129-
if not _expected_errors_regexp.search(str(e)):
140+
if not _expected_errors_regexp.search(str(e)) and not "exceeds triton maximum tensor numel" in str(e):
130141
raise exc.TritonError(f"{type(e).__qualname__}: {e}", config) from e
131142
self.log.debug(f"Benchmarking failed: {type(e).__name__}: {e}")
132143
return inf
@@ -149,6 +160,8 @@ def start_precompile_and_check_for_hangs(
149160
"""
150161
if not self.settings.autotune_precompile:
151162
return PrecompileFuture.skip(self, config, True)
163+
if fn is None:
164+
return PrecompileFuture.skip(self, config, False)
152165
ctx = mp.get_context("fork")
153166

154167
def extract_launcher(
@@ -188,7 +201,13 @@ def parallel_benchmark(self, configs: list[Config]) -> list[tuple[Config, float]
188201
Returns:
189202
A list of tuples containing configurations and their performance.
190203
"""
191-
fns = [self.kernel.compile_config(c, allow_print=False) for c in configs]
204+
fns = []
205+
for c in configs:
206+
try:
207+
compile_result = self.kernel.compile_config(c, allow_print=False)
208+
fns.append(compile_result)
209+
except Exception as e:
210+
fns.append(None)
192211
if self.settings.autotune_precompile:
193212
is_workings = PrecompileFuture.wait_for_all(
194213
[

helion/autotuner/config_spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,8 +411,8 @@ def _flat_config(
411411
default = min(high, 4096)
412412
value = fn(BlockSizeFragment(low, high, default))
413413
assert isinstance(value, int)
414-
if value >= self.size_hint:
415-
return None # max size becomes persistent reduction
414+
if value >= self.size_hint or value < low:
415+
return None # max size or invalid value becomes persistent reduction
416416
return value
417417

418418
def _normalize(self, name: str, value: object) -> int | None:

0 commit comments

Comments
 (0)