Skip to content

Commit 0f0d8cb

Browse files
committed
Improve naming for generated helper functions
1 parent dcfa500 commit 0f0d8cb

10 files changed

+190
-152
lines changed

helion/_compiler/device_function.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,8 @@ def register_helper_function(
504504
self, helper_graph_info: HelperFunctionGraphInfo
505505
) -> None:
506506
"""Register a helper function to be generated at global scope."""
507-
self.helper_manager.register_helper_function(helper_graph_info)
507+
name = self.namespace.create_name(self.name, None)
508+
self.helper_manager.register_helper_function(helper_graph_info, name)
508509

509510
def codegen_helper_functions(self) -> list[ast.stmt]:
510511
"""Generate helper function definitions at global scope."""

helion/_compiler/device_ir.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,9 +937,14 @@ class HelperFunctionGraphInfo(NodeArgsGraphInfo):
937937
"""Graph info for helper functions in higher-order operations like associative_scan."""
938938

939939
_param_names: list[str] = dataclasses.field(default_factory=list)
940+
original_function_name: str | None = dataclasses.field(default=None)
940941

941942
@property
942943
def name(self) -> str:
944+
# This property should only be used during registration, not for final names
945+
# Final names are generated in codegen using the namespace below
946+
if self.original_function_name:
947+
return f"{self.original_function_name}_{self.graph_id}"
943948
return f"helper_function_{self.graph_id}"
944949

945950
def find_input_nodes(self) -> list[torch.fx.Node]:

helion/_compiler/helper_function.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ def extract_helper_function(helper_fn: object) -> types.FunctionType:
5858
return helper_fn.fn if isinstance(helper_fn, Kernel) else helper_fn # pyright: ignore[reportReturnType]
5959

6060

61+
def extract_helper_function_name(helper_fn: object) -> str:
62+
"""Extract the function name from a Kernel object or regular function."""
63+
return extract_helper_function(helper_fn).__name__
64+
65+
6166
CombineFunctionBasic = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
6267
CombineFunctionTuple = Callable[..., tuple[torch.Tensor, ...]]
6368
CombineFunction = CombineFunctionBasic | CombineFunctionTuple
@@ -153,17 +158,22 @@ class HelperFunctionManager:
153158

154159
def __init__(self) -> None:
155160
self.helper_functions: dict[str, HelperFunctionGraphInfo] = {}
161+
self._final_names: dict[str, str] = {}
156162

157163
def register_helper_function(
158-
self, helper_graph_info: HelperFunctionGraphInfo
164+
self, helper_graph_info: HelperFunctionGraphInfo, final_name: str
159165
) -> None:
160166
"""Register a helper function to be generated at global scope."""
161167
self.helper_functions[helper_graph_info.name] = helper_graph_info
168+
self._final_names[helper_graph_info.name] = final_name
162169

163170
def codegen_helper_functions(self) -> list[ast.stmt]:
164171
"""Generate helper function definitions at global scope."""
165172
helper_defs = []
166173
for helper_graph_info in self.helper_functions.values():
174+
# Get the final name that was already determined during registration
175+
final_name = self._final_names[helper_graph_info.name]
176+
167177
# Determine the number of parameters from the graph
168178
input_nodes = helper_graph_info.find_input_nodes()
169179

@@ -184,7 +194,7 @@ def codegen_helper_functions(self) -> list[ast.stmt]:
184194
# Generate the function structure with @triton.jit decorator
185195
func_def = create(
186196
ast.FunctionDef,
187-
name=helper_graph_info.name,
197+
name=final_name,
188198
args=create_arguments(args),
189199
body=func_body,
190200
decorator_list=[expr_from_string("triton.jit")],
@@ -195,6 +205,10 @@ def codegen_helper_functions(self) -> list[ast.stmt]:
195205

196206
return helper_defs
197207

208+
def get_final_name(self, helper_graph_info: HelperFunctionGraphInfo) -> str:
209+
"""Get the final generated name for a helper function."""
210+
return self._final_names.get(helper_graph_info.name, helper_graph_info.name)
211+
198212
def _codegen_helper_function_body(
199213
self, helper_graph_info: HelperFunctionGraphInfo
200214
) -> list[ast.stmt]:

helion/language/reduce_ops.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def _(
137137
from .._compiler.device_ir import args_to_proxies
138138
from .._compiler.device_ir import select_decomp_table
139139
from .._compiler.helper_function import create_combine_function_wrapper
140+
from .._compiler.helper_function import extract_helper_function_name
140141

141142
is_tuple_input = isinstance(input_tensor, (tuple, list))
142143
if is_tuple_input:
@@ -147,6 +148,8 @@ def _(
147148
assert isinstance(input_tensor, torch.Tensor), "reduce input must be a tensor"
148149

149150
assert callable(combine_fn), "combine_fn must be callable"
151+
# Extract the function name before wrapping
152+
original_function_name = extract_helper_function_name(combine_fn)
150153
combine_fn = create_combine_function_wrapper(
151154
combine_fn, is_tuple_input=is_tuple_input, target_format="tuple"
152155
)
@@ -182,6 +185,7 @@ def _(
182185
combine_graph,
183186
HelperFunctionGraphInfo,
184187
node_args=[],
188+
original_function_name=original_function_name,
185189
)
186190

187191
# Validate other parameter for mask_node_inputs
@@ -334,12 +338,17 @@ def _(state: CodegenState) -> ast.AST | list[ast.AST]:
334338

335339

336340
def _register_helper_function(state: CodegenState, combine_graph_id: int) -> str:
337-
"""Register the helper function and return its name."""
341+
"""Register the helper function and return its final name."""
342+
from .._compiler.device_ir import HelperFunctionGraphInfo
338343
from .._compiler.host_function import HostFunction
339344

340345
helper_graph_info = HostFunction.current().device_ir.graphs[combine_graph_id]
341-
state.codegen.device_function.register_helper_function(helper_graph_info) # pyright: ignore[reportArgumentType]
342-
return helper_graph_info.name
346+
assert isinstance(helper_graph_info, HelperFunctionGraphInfo)
347+
state.codegen.device_function.register_helper_function(helper_graph_info)
348+
# Get the final name from the helper manager (which uses the namespace)
349+
return state.codegen.device_function.helper_manager.get_final_name(
350+
helper_graph_info
351+
)
343352

344353

345354
def _create_reduce_expression(

helion/language/scan_ops.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def _(
117117
from .._compiler.device_ir import args_to_proxies
118118
from .._compiler.device_ir import select_decomp_table
119119
from .._compiler.helper_function import create_combine_function_wrapper
120+
from .._compiler.helper_function import extract_helper_function_name
120121

121122
is_tuple_input = isinstance(input_tensor, (tuple, list))
122123
if is_tuple_input:
@@ -130,6 +131,8 @@ def _(
130131
assert isinstance(dim, int), "associative_scan dim must be an integer"
131132

132133
assert callable(combine_fn), "combine_fn must be callable"
134+
# Extract the function name before wrapping
135+
original_function_name = extract_helper_function_name(combine_fn)
133136
combine_fn = create_combine_function_wrapper(
134137
combine_fn, is_tuple_input=is_tuple_input, target_format="unpacked"
135138
)
@@ -151,6 +154,7 @@ def _(
151154
combine_graph,
152155
HelperFunctionGraphInfo,
153156
node_args=[],
157+
original_function_name=original_function_name,
154158
)
155159

156160
# Create the associative_scan tracing operation
@@ -349,12 +353,17 @@ def _get_input_tensor_ast(state: CodegenState, is_tuple_input: bool) -> ast.AST:
349353

350354

351355
def _register_helper_function(state: CodegenState, combine_graph_id: int) -> str:
352-
"""Register the helper function and return its name."""
356+
"""Register the helper function and return its final name."""
357+
from .._compiler.device_ir import HelperFunctionGraphInfo
353358
from .._compiler.host_function import HostFunction
354359

355360
helper_graph_info = HostFunction.current().device_ir.graphs[combine_graph_id]
356-
state.codegen.device_function.register_helper_function(helper_graph_info) # pyright: ignore[reportArgumentType]
357-
return helper_graph_info.name
361+
assert isinstance(helper_graph_info, HelperFunctionGraphInfo)
362+
state.codegen.device_function.register_helper_function(helper_graph_info)
363+
# Get the final name from the helper manager (which uses the namespace)
364+
return state.codegen.device_function.helper_manager.get_final_name(
365+
helper_graph_info
366+
)
358367

359368

360369
def _create_scan_expression(

0 commit comments

Comments
 (0)