Skip to content

Commit 345fb5c

Browse files
authored
Remove make_precompiler generated wrapper (#314)
1 parent dc88062 commit 345fb5c

30 files changed

+877
-2761
lines changed

helion/_compiler/device_function.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,15 @@ def __init__(self, name: str, config: Config, codegen: GenerateAST) -> None:
181181
self.pid: ProgramIDs | None = None
182182
self.namespace: _Namespace = _Namespace()
183183
self.namespace._used_names.update(reserved_names())
184+
self.namespace._used_names.update(
185+
# used by triton run() method
186+
[
187+
"grid",
188+
"warmup",
189+
"num_warps",
190+
"num_stages",
191+
]
192+
)
184193
self._variable_renames: dict[str, list[str]] = {}
185194
self.dce_vars: list[str] = []
186195
self.block_size_var_cache: dict[tuple[int, ...], str] = {}
@@ -448,7 +457,7 @@ def codegen_function_call(self) -> ast.AST:
448457
assert pid is not None
449458
# TODO(jansel): we should run CSE this statement
450459
call_statement = statement_from_string(
451-
f"{self.name}[__call_grid_expr]({', '.join(args)})",
460+
f"_launcher({self.name}, __call_grid_expr, {', '.join(args)})",
452461
__call_grid_expr=pid.codegen_grid(),
453462
)
454463
assert isinstance(call_statement, ExtendedAST)

helion/_compiler/generate_ast.py

Lines changed: 1 addition & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from .. import exc
1212
from ..language._decorators import is_api_func
13-
from ..runtime.precompile_shim import make_precompiler
1413
from .ast_extension import ExtendedAST
1514
from .ast_extension import LoopType
1615
from .ast_extension import NodeVisitor
@@ -367,68 +366,6 @@ def has_mask(self) -> bool:
367366
)
368367

369368

370-
def codegen_precompile_def(
371-
host_def: ast.FunctionDef, device_function_name: str
372-
) -> ast.FunctionDef:
373-
"""
374-
Generate a precompile function definition for the given host function.
375-
The precompile function is the same as the normal function, but the call to the
376-
kernel is replaced with a call to make_precompiler.
377-
378-
Args:
379-
host_def: The host function definition to that is used to call the kernel.
380-
device_function_name: The name of the device function to be called.
381-
382-
Returns:
383-
A transformed function definition with the kernel call replaced.
384-
"""
385-
386-
def transform(node: ExtendedAST) -> ExtendedAST:
387-
nonlocal found_calls
388-
assert not node._is_kernel_call
389-
fields = node.fields()
390-
for key, value in [*fields.items()]:
391-
if isinstance(value, list):
392-
new_list = []
393-
for item in value:
394-
assert isinstance(item, ExtendedAST)
395-
if item._is_kernel_call:
396-
with item:
397-
found_calls += 1
398-
new_list.append(
399-
statement_from_string(
400-
f"from {make_precompiler.__module__} import make_precompiler"
401-
)
402-
)
403-
assert isinstance(item, ast.Expr)
404-
value = item.value
405-
assert isinstance(value, ExtendedAST)
406-
new_list.append(
407-
create(
408-
ast.Return,
409-
value=value.copy(
410-
func=expr_from_string(
411-
f"make_precompiler({device_function_name})"
412-
)
413-
),
414-
)
415-
)
416-
break
417-
new_list.append(transform(item))
418-
fields[key] = new_list
419-
elif isinstance(value, ExtendedAST):
420-
fields[key] = transform(value)
421-
return node.new(fields)
422-
423-
found_calls = 0
424-
assert isinstance(host_def, ExtendedAST)
425-
new_fn = transform(host_def)
426-
assert isinstance(new_fn, ast.FunctionDef)
427-
new_fn.name = f"_{host_def.name}_make_precompiler"
428-
assert found_calls == 1
429-
return new_fn
430-
431-
432369
def generate_ast(func: HostFunction, config: Config) -> ast.AST:
433370
with func:
434371
codegen = GenerateAST(func, config)
@@ -438,16 +375,13 @@ def generate_ast(func: HostFunction, config: Config) -> ast.AST:
438375
kernel_def = codegen.device_function.codegen_function_def()
439376
codegen.host_dead_code_elimination()
440377
host_def = func.codegen_function_def(codegen.host_statements)
441-
precompile_def = codegen_precompile_def(
442-
host_def, codegen.device_function.name
443-
)
378+
444379
result = ast.Module(
445380
[
446381
*func.codegen_imports(),
447382
*codegen.device_function.codegen_helper_functions(),
448383
*kernel_def,
449384
host_def,
450-
precompile_def,
451385
],
452386
[],
453387
)

helion/_compiler/host_function.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from .. import exc
1818
from . import ast_extension
19+
from .ast_extension import expr_from_string
1920
from .ast_extension import statement_from_string
2021
from .compile_environment import CompileEnvironment
2122
from .output_header import SOURCE_MODULE
@@ -212,10 +213,32 @@ def debug_str(self) -> str:
212213
return "\n\n".join(result)
213214

214215
def codegen_function_def(self, statements: list[ast.AST]) -> ast.FunctionDef:
216+
# Create a new arguments structure with _launcher kwarg-only parameter
217+
new_args = ast_extension.create(
218+
ast.arguments,
219+
posonlyargs=self.args.posonlyargs,
220+
args=self.args.args,
221+
vararg=self.args.vararg,
222+
kwonlyargs=[
223+
*self.args.kwonlyargs,
224+
ast_extension.create(
225+
ast.arg,
226+
arg="_launcher",
227+
annotation=None,
228+
),
229+
],
230+
kw_defaults=[
231+
*self.args.kw_defaults,
232+
expr_from_string("_default_launcher"),
233+
],
234+
kwarg=self.args.kwarg,
235+
defaults=self.args.defaults,
236+
)
237+
215238
return ast_extension.create(
216239
ast.FunctionDef,
217240
name=self.name,
218-
args=self.args,
241+
args=new_args,
219242
body=statements,
220243
decorator_list=[],
221244
type_comment=None,

helion/_compiler/output_header.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@
2121
"triton_helpers": "from torch._inductor.runtime import triton_helpers",
2222
"tl_math": "from torch._inductor.runtime.triton_helpers import math as tl_math",
2323
"libdevice": "from torch._inductor.runtime.triton_compat import libdevice",
24+
"_default_launcher": "from helion.runtime import default_launcher as _default_launcher",
2425
}
2526

2627
disallowed_names: dict[str, None] = dict.fromkeys(
2728
[
2829
SOURCE_MODULE,
29-
"make_precompiler",
30+
"_launcher",
31+
"_default_launcher",
3032
"_NUM_SM",
3133
]
3234
)

helion/autotuner/base_search.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,16 @@
2222

2323
from .. import exc
2424
from ..runtime.precompile_shim import already_compiled
25+
from ..runtime.precompile_shim import make_precompiler
2526
from .config_generation import ConfigGeneration
2627
from .config_generation import FlatConfig
2728
from .logger import LambdaLogger
2829

2930
if TYPE_CHECKING:
3031
from collections.abc import Sequence
3132

33+
import triton
34+
3235
from ..runtime.config import Config
3336
from ..runtime.kernel import BoundKernel
3437
from ..runtime.kernel import CompiledConfig
@@ -144,9 +147,24 @@ def start_precompile_and_check_for_hangs(
144147
return PrecompileFuture.skip(self, config, True)
145148
ctx = mp.get_context("fork")
146149

147-
precompiler = fn.make_precompiler(*self.args) # pyright: ignore[reportFunctionMemberAccess]
148-
if precompiler is already_compiled:
149-
return PrecompileFuture.skip(self, config, True)
150+
def extract_launcher(
151+
triton_kernel: triton.JITFunction,
152+
grid: tuple[int, ...],
153+
*args: object,
154+
**kwargs: object,
155+
):
156+
"""Custom launcher that extracts arguments instead of executing."""
157+
raise _ExtractedLaunchArgs(triton_kernel, grid, args, kwargs)
158+
159+
try:
160+
# Call main function with extraction launcher to extract arguments
161+
fn(*self.args, _launcher=extract_launcher)
162+
# Should not reach here
163+
raise RuntimeError("Expected _ExtractedLaunchArgs exception")
164+
except _ExtractedLaunchArgs as e:
165+
precompiler = make_precompiler(e.kernel)(*e.args, **e.kwargs)
166+
if precompiler is already_compiled:
167+
return PrecompileFuture.skip(self, config, True)
150168
process: mp.Process = ctx.Process(target=precompiler) # pyright: ignore[reportAssignmentType]
151169
process.start()
152170
return PrecompileFuture(
@@ -501,3 +519,14 @@ def _mark_complete(self) -> bool:
501519

502520
self.ok = False
503521
return False
522+
523+
524+
class _ExtractedLaunchArgs(Exception):
525+
"""Exception that carries kernel launch arguments for precompiler extraction."""
526+
527+
def __init__(self, triton_kernel, grid, args, kwargs):
528+
super().__init__()
529+
self.kernel = triton_kernel
530+
self.grid = grid
531+
self.args = args
532+
self.kwargs = kwargs

helion/runtime/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,16 @@ def get_num_sm(device: torch.device) -> int:
4242
"""
4343
assert device.type == "cuda", "TODO: implement for other devices"
4444
return torch.cuda.get_device_properties(device.index).multi_processor_count
45+
46+
47+
def default_launcher(
48+
triton_kernel: triton.JITFunction,
49+
grid: tuple[int, ...],
50+
*args: object,
51+
num_warps: int,
52+
num_stages: int,
53+
):
54+
"""Default launcher function that executes the kernel immediately."""
55+
return triton_kernel.run(
56+
*args, grid=grid, warmup=False, num_warps=num_warps, num_stages=num_stages
57+
)

helion/runtime/kernel.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,6 @@ def compile_config(
387387
print(triton_code, file=sys.stderr)
388388
module = PyCodeCache.load(triton_code)
389389
rv = getattr(module, self.kernel.name)
390-
rv.make_precompiler = getattr(module, f"_{self.kernel.name}_make_precompiler")
391390
self._compile_cache[config] = rv
392391
return rv
393392

0 commit comments

Comments
 (0)