From f1795e83c966e70acd6b6b88cdcc305eb55581fd Mon Sep 17 00:00:00 2001 From: Jason Ansel Date: Mon, 14 Jul 2025 20:47:20 -0700 Subject: [PATCH] Improve naming for generated helper functions stack-info: PR: https://github.com/pytorch-labs/helion/pull/323, branch: jansel/stack/113 --- helion/_compiler/device_function.py | 3 +- helion/_compiler/device_ir.py | 5 + helion/_compiler/helper_function.py | 18 ++- helion/language/reduce_ops.py | 15 ++- helion/language/scan_ops.py | 15 ++- test/test_associative_scan.expected | 180 ++++++++++++++-------------- test/test_associative_scan.py | 32 ++--- test/test_examples.expected | 4 +- test/test_reduce.expected | 60 +++++----- test/test_reduce.py | 10 +- 10 files changed, 190 insertions(+), 152 deletions(-) diff --git a/helion/_compiler/device_function.py b/helion/_compiler/device_function.py index f1ccd237..3e51bd2e 100644 --- a/helion/_compiler/device_function.py +++ b/helion/_compiler/device_function.py @@ -504,7 +504,8 @@ def register_helper_function( self, helper_graph_info: HelperFunctionGraphInfo ) -> None: """Register a helper function to be generated at global scope.""" - self.helper_manager.register_helper_function(helper_graph_info) + name = self.namespace.create_name(helper_graph_info.name, None) + self.helper_manager.register_helper_function(helper_graph_info, name) def codegen_helper_functions(self) -> list[ast.stmt]: """Generate helper function definitions at global scope.""" diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index 53777b1e..9f66dd32 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -942,9 +942,14 @@ class HelperFunctionGraphInfo(NodeArgsGraphInfo): """Graph info for helper functions in higher-order operations like associative_scan.""" _param_names: list[str] = dataclasses.field(default_factory=list) + original_function_name: str | None = dataclasses.field(default=None) @property def name(self) -> str: + # This property should only be used during registration, not for final names + # Final names are generated in codegen using the namespace below + if self.original_function_name: + return f"{self.original_function_name}_{self.graph_id}" return f"helper_function_{self.graph_id}" def find_input_nodes(self) -> list[torch.fx.Node]: diff --git a/helion/_compiler/helper_function.py b/helion/_compiler/helper_function.py index 49330125..13db6ae1 100644 --- a/helion/_compiler/helper_function.py +++ b/helion/_compiler/helper_function.py @@ -58,6 +58,11 @@ def extract_helper_function(helper_fn: object) -> types.FunctionType: return helper_fn.fn if isinstance(helper_fn, Kernel) else helper_fn # pyright: ignore[reportReturnType] +def extract_helper_function_name(helper_fn: object) -> str: + """Extract the function name from a Kernel object or regular function.""" + return extract_helper_function(helper_fn).__name__ + + CombineFunctionBasic = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] CombineFunctionTuple = Callable[..., tuple[torch.Tensor, ...]] CombineFunction = CombineFunctionBasic | CombineFunctionTuple @@ -153,17 +158,22 @@ class HelperFunctionManager: def __init__(self) -> None: self.helper_functions: dict[str, HelperFunctionGraphInfo] = {} + self._final_names: dict[str, str] = {} def register_helper_function( - self, helper_graph_info: HelperFunctionGraphInfo + self, helper_graph_info: HelperFunctionGraphInfo, final_name: str ) -> None: """Register a helper function to be generated at global scope.""" self.helper_functions[helper_graph_info.name] = helper_graph_info + self._final_names[helper_graph_info.name] = final_name def codegen_helper_functions(self) -> list[ast.stmt]: """Generate helper function definitions at global scope.""" helper_defs = [] for helper_graph_info in self.helper_functions.values(): + # Get the final name that was already determined during registration + final_name = self._final_names[helper_graph_info.name] + # Determine the number of parameters from the graph input_nodes = helper_graph_info.find_input_nodes() @@ -184,7 +194,7 @@ def codegen_helper_functions(self) -> list[ast.stmt]: # Generate the function structure with @triton.jit decorator func_def = create( ast.FunctionDef, - name=helper_graph_info.name, + name=final_name, args=create_arguments(args), body=func_body, decorator_list=[expr_from_string("triton.jit")], @@ -195,6 +205,10 @@ def codegen_helper_functions(self) -> list[ast.stmt]: return helper_defs + def get_final_name(self, helper_graph_info: HelperFunctionGraphInfo) -> str: + """Get the final generated name for a helper function.""" + return self._final_names.get(helper_graph_info.name, helper_graph_info.name) + def _codegen_helper_function_body( self, helper_graph_info: HelperFunctionGraphInfo ) -> list[ast.stmt]: diff --git a/helion/language/reduce_ops.py b/helion/language/reduce_ops.py index 5d15c81c..b811330b 100644 --- a/helion/language/reduce_ops.py +++ b/helion/language/reduce_ops.py @@ -137,6 +137,7 @@ def _( from .._compiler.device_ir import args_to_proxies from .._compiler.device_ir import select_decomp_table from .._compiler.helper_function import create_combine_function_wrapper + from .._compiler.helper_function import extract_helper_function_name is_tuple_input = isinstance(input_tensor, (tuple, list)) if is_tuple_input: @@ -147,6 +148,8 @@ def _( assert isinstance(input_tensor, torch.Tensor), "reduce input must be a tensor" assert callable(combine_fn), "combine_fn must be callable" + # Extract the function name before wrapping + original_function_name = extract_helper_function_name(combine_fn) combine_fn = create_combine_function_wrapper( combine_fn, is_tuple_input=is_tuple_input, target_format="tuple" ) @@ -182,6 +185,7 @@ def _( combine_graph, HelperFunctionGraphInfo, node_args=[], + original_function_name=original_function_name, ) # Validate other parameter for mask_node_inputs @@ -334,12 +338,17 @@ def _(state: CodegenState) -> ast.AST | list[ast.AST]: def _register_helper_function(state: CodegenState, combine_graph_id: int) -> str: - """Register the helper function and return its name.""" + """Register the helper function and return its final name.""" + from .._compiler.device_ir import HelperFunctionGraphInfo from .._compiler.host_function import HostFunction helper_graph_info = HostFunction.current().device_ir.graphs[combine_graph_id] - state.codegen.device_function.register_helper_function(helper_graph_info) # pyright: ignore[reportArgumentType] - return helper_graph_info.name + assert isinstance(helper_graph_info, HelperFunctionGraphInfo) + state.codegen.device_function.register_helper_function(helper_graph_info) + # Get the final name from the helper manager (which uses the namespace) + return state.codegen.device_function.helper_manager.get_final_name( + helper_graph_info + ) def _create_reduce_expression( diff --git a/helion/language/scan_ops.py b/helion/language/scan_ops.py index a9502c55..44ea10cb 100644 --- a/helion/language/scan_ops.py +++ b/helion/language/scan_ops.py @@ -117,6 +117,7 @@ def _( from .._compiler.device_ir import args_to_proxies from .._compiler.device_ir import select_decomp_table from .._compiler.helper_function import create_combine_function_wrapper + from .._compiler.helper_function import extract_helper_function_name is_tuple_input = isinstance(input_tensor, (tuple, list)) if is_tuple_input: @@ -130,6 +131,8 @@ def _( assert isinstance(dim, int), "associative_scan dim must be an integer" assert callable(combine_fn), "combine_fn must be callable" + # Extract the function name before wrapping + original_function_name = extract_helper_function_name(combine_fn) combine_fn = create_combine_function_wrapper( combine_fn, is_tuple_input=is_tuple_input, target_format="unpacked" ) @@ -151,6 +154,7 @@ def _( combine_graph, HelperFunctionGraphInfo, node_args=[], + original_function_name=original_function_name, ) # Create the associative_scan tracing operation @@ -349,12 +353,17 @@ def _get_input_tensor_ast(state: CodegenState, is_tuple_input: bool) -> ast.AST: def _register_helper_function(state: CodegenState, combine_graph_id: int) -> str: - """Register the helper function and return its name.""" + """Register the helper function and return its final name.""" + from .._compiler.device_ir import HelperFunctionGraphInfo from .._compiler.host_function import HostFunction helper_graph_info = HostFunction.current().device_ir.graphs[combine_graph_id] - state.codegen.device_function.register_helper_function(helper_graph_info) # pyright: ignore[reportArgumentType] - return helper_graph_info.name + assert isinstance(helper_graph_info, HelperFunctionGraphInfo) + state.codegen.device_function.register_helper_function(helper_graph_info) + # Get the final name from the helper manager (which uses the namespace) + return state.codegen.device_function.helper_manager.get_final_name( + helper_graph_info + ) def _create_scan_expression( diff --git a/test/test_associative_scan.expected b/test/test_associative_scan.expected index d942a7f5..6a29045a 100644 --- a/test/test_associative_scan.expected +++ b/test/test_associative_scan.expected @@ -12,7 +12,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1, param_2, param_3): +def argmax_combine_tuple_fn_0(param_0, param_1, param_2, param_3): v_0 = param_2 > param_0 v_1 = tl.where(v_0, param_2, param_0) v_2 = tl.where(v_0, param_3, param_1) @@ -31,8 +31,8 @@ def _cumulative_argmax_tuple_kernel_kernel(input_data, positions, max_values, ma v_0 = load_1.to(tl.float32) unsqueeze = v_0[None, :] indices = tl.broadcast_to(unsqueeze, [_BLOCK_SIZE_0, _RDIM_SIZE_1]) - out_vals = tl.associative_scan((vals, indices), 1, helper_function_0)[0] - out_indices = tl.associative_scan((vals, indices), 1, helper_function_0)[1] + out_vals = tl.associative_scan((vals, indices), 1, argmax_combine_tuple_fn_0)[0] + out_indices = tl.associative_scan((vals, indices), 1, argmax_combine_tuple_fn_0)[1] tl.store(max_values + (indices_0[:, None] * max_values_stride_0 + indices_1[None, :] * max_values_stride_1), out_vals, mask_0[:, None] & mask_1[None, :]) v_1 = out_indices.to(tl.int32) tl.store(max_indices + (indices_0[:, None] * max_indices_stride_0 + indices_1[None, :] * max_indices_stride_1), v_1, mask_0[:, None] & mask_1[None, :]) @@ -56,7 +56,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -69,7 +69,7 @@ def _test_scan_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, res indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_scan_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -90,7 +90,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -99,7 +99,7 @@ def _test_codegen_kernel_kernel(x, result, x_size_1, result_stride_1, x_stride_1 indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + indices_1[None, :] * result_stride_1, _associative_scan, mask_1[None, :]) def test_codegen_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -119,7 +119,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1, param_2, param_3): +def argmax_combine_fn_0(param_0, param_1, param_2, param_3): v_0 = param_2 > param_0 v_1 = tl.where(v_0, param_2, param_0) v_2 = tl.where(v_0, param_3, param_1) @@ -138,8 +138,8 @@ def _cumulative_argmax_kernel_kernel(input_data, positions, max_values, max_indi v_0 = load_1.to(tl.float32) unsqueeze = v_0[None, :] indices = tl.broadcast_to(unsqueeze, [_BLOCK_SIZE_0, _RDIM_SIZE_1]) - out_vals = tl.associative_scan((vals, indices), 1, helper_function_0)[0] - out_indices = tl.associative_scan((vals, indices), 1, helper_function_0)[1] + out_vals = tl.associative_scan((vals, indices), 1, argmax_combine_fn_0)[0] + out_indices = tl.associative_scan((vals, indices), 1, argmax_combine_fn_0)[1] tl.store(max_values + (indices_0[:, None] * max_values_stride_0 + indices_1[None, :] * max_values_stride_1), out_vals, mask_0[:, None] & mask_1[None, :]) v_1 = out_indices.to(tl.int32) tl.store(max_indices + (indices_0[:, None] * max_indices_stride_0 + indices_1[None, :] * max_indices_stride_1), v_1, mask_0[:, None] & mask_1[None, :]) @@ -163,7 +163,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -176,7 +176,7 @@ def _test_dtype_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, re indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -197,7 +197,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -210,7 +210,7 @@ def _test_dtype_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, re indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -231,7 +231,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -244,7 +244,7 @@ def _test_dtype_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, re indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -265,7 +265,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -278,7 +278,7 @@ def _test_dtype_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, re indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -299,7 +299,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -308,7 +308,7 @@ def _test_size_kernel_kernel(x, result, x_size_1, result_stride_1, x_stride_1, _ indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + indices_1[None, :] * result_stride_1, _associative_scan, mask_1[None, :]) def test_size_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -328,7 +328,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -341,7 +341,7 @@ def _test_size_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, res indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_size_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -362,7 +362,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -375,7 +375,7 @@ def _test_size_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, res indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_size_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -396,7 +396,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -407,7 +407,7 @@ def _test_size_kernel_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) mask_0 = indices_0 < x_size_0 row_data = tl.load(x + indices_0[:, None] * x_stride_0, mask_0[:, None], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + indices_0[:, None] * result_stride_0, _associative_scan, mask_0[:, None]) def test_size_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -427,7 +427,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -439,7 +439,7 @@ def _test_size_kernel_kernel(x, result, x_size_1, result_stride_0, result_stride indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_1[None, :]) def test_size_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -459,7 +459,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -471,7 +471,7 @@ def _test_size_kernel_kernel(x, result, x_size_1, result_stride_0, result_stride indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_1[None, :]) def test_size_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -491,14 +491,14 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @triton.jit def _test_single_element_kernel(x, result): row_data = tl.load(x + tl.zeros([1, 1], tl.int32), None) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + tl.zeros([1, 1], tl.int32), _associative_scan, None) def test_single_element(x: torch.Tensor, *, _launcher=_default_launcher): @@ -517,7 +517,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -526,7 +526,7 @@ def _test_single_element_kernel(x, result, x_size_1, result_stride_1, x_stride_1 indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + indices_1[None, :] * result_stride_1, _associative_scan, mask_1[None, :]) def test_single_element(x: torch.Tensor, *, _launcher=_default_launcher): @@ -546,7 +546,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -559,7 +559,7 @@ def _test_helper_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, r indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(load, 0, helper_function_0) + _associative_scan = tl.associative_scan(load, 0, add_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_helper_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -580,7 +580,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def jit_add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -589,7 +589,7 @@ def _test_jit_kernel_kernel(x, result, x_size_1, result_stride_1, x_stride_1, _R indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, jit_add_combine_fn_0) tl.store(result + indices_1[None, :] * result_stride_1, _associative_scan, mask_1[None, :]) def test_jit_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -609,7 +609,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -621,7 +621,7 @@ def _test_large_kernel_kernel(x, result, x_size_1, result_stride_0, result_strid indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_1[None, :]) def test_large_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -642,7 +642,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def max_combine_fn_0(param_0, param_1): v_0 = triton_helpers.maximum(param_0, param_1) return v_0 @@ -655,7 +655,7 @@ def _test_max_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, resu indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, max_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_max_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -677,7 +677,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def min_combine_fn_0(param_0, param_1): v_0 = triton_helpers.minimum(param_0, param_1) return v_0 @@ -690,7 +690,7 @@ def _test_min_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, resu indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, min_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_min_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -712,12 +712,12 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @triton.jit -def helper_function_1(param_0, param_1): +def max_combine_fn_1(param_0, param_1): v_0 = triton_helpers.maximum(param_0, param_1) return v_0 @@ -726,9 +726,9 @@ def _test_multi_kernel_kernel(x, sum_result, max_result, x_size_1, max_result_st indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(sum_result + indices_1[None, :] * sum_result_stride_1, _associative_scan, mask_1[None, :]) - _associative_scan_1 = tl.associative_scan(_associative_scan, 1, helper_function_1) + _associative_scan_1 = tl.associative_scan(_associative_scan, 1, max_combine_fn_1) tl.store(max_result + indices_1[None, :] * max_result_stride_1, _associative_scan_1, mask_1[None, :]) def test_multi_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -749,7 +749,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def mul_combine_fn_0(param_0, param_1): v_0 = param_0 * param_1 return v_0 @@ -762,7 +762,7 @@ def _test_mul_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, resu indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, mul_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_mul_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -783,7 +783,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -792,7 +792,7 @@ def _test_reverse_kernel_kernel(x, result, x_size_1, result_stride_1, x_stride_1 indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0, reverse=True) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0, reverse=True) tl.store(result + indices_1[None, :] * result_stride_1, _associative_scan, mask_1[None, :]) def test_reverse_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -812,7 +812,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1, param_2, param_3): +def segmented_combine_fn_0(param_0, param_1, param_2, param_3): v_0 = param_1 == param_3 v_1 = param_0 + param_2 v_2 = tl.where(v_0, v_1, param_2) @@ -834,7 +834,7 @@ def _segmented_scan_kernel_kernel(input_data, indices, output, indices_stride_0, v_0 = load_1.to(tl.float32) unsqueeze = v_0[:, None] idxs = tl.broadcast_to(unsqueeze, [_BLOCK_SIZE_0, _BLOCK_SIZE_1]) - out_vals = tl.associative_scan((vals, idxs), 0, helper_function_0)[0] + out_vals = tl.associative_scan((vals, idxs), 0, segmented_combine_fn_0)[0] tl.store(output + (indices_0[:, None] * output_stride_0 + indices_1[None, :] * output_stride_1), out_vals, mask_0[:, None] & mask_1[None, :]) def segmented_scan_kernel(indices: torch.Tensor, input_data: torch.Tensor, *, _launcher=_default_launcher): @@ -856,7 +856,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -869,7 +869,7 @@ def _test_torch_hops_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_ indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_torch_hops_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -890,7 +890,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1, param_2, param_3): +def helion_combine_fn_0(param_0, param_1, param_2, param_3): v_0 = param_1 == param_3 v_1 = param_0 + param_2 v_2 = tl.where(v_0, v_1, param_2) @@ -911,7 +911,7 @@ def _test_segmented_kernel_kernel(input_data, indices, output, indices_stride_0, load_1 = tl.load(indices + indices_0 * indices_stride_0, mask_0, other=0) unsqueeze = load_1[:, None] idxs = tl.broadcast_to(unsqueeze, [_BLOCK_SIZE_0, _BLOCK_SIZE_1]) - out_vals = tl.associative_scan((vals, idxs), 0, helper_function_0)[0] + out_vals = tl.associative_scan((vals, idxs), 0, helion_combine_fn_0)[0] tl.store(output + (indices_0[:, None] * output_stride_0 + indices_1[None, :] * output_stride_1), out_vals, mask_0[:, None] & mask_1[None, :]) def test_segmented_kernel(indices: torch.Tensor, input_data: torch.Tensor, *, _launcher=_default_launcher): @@ -933,7 +933,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1, param_2, param_3): +def helion_combine_tuple_fn_0(param_0, param_1, param_2, param_3): v_0 = param_1 == param_3 v_1 = param_0 + param_2 v_2 = tl.where(v_0, v_1, param_2) @@ -954,7 +954,7 @@ def _test_segmented_tuple_kernel_kernel(input_data, indices, output, indices_str load_1 = tl.load(indices + indices_0 * indices_stride_0, mask_0, other=0) unsqueeze = load_1[:, None] idxs = tl.broadcast_to(unsqueeze, [_BLOCK_SIZE_0, _BLOCK_SIZE_1]) - out_vals = tl.associative_scan((vals, idxs), 0, helper_function_0)[0] + out_vals = tl.associative_scan((vals, idxs), 0, helion_combine_tuple_fn_0)[0] tl.store(output + (indices_0[:, None] * output_stride_0 + indices_1[None, :] * output_stride_1), out_vals, mask_0[:, None] & mask_1[None, :]) def test_segmented_tuple_kernel(indices: torch.Tensor, input_data: torch.Tensor, *, _launcher=_default_launcher): @@ -976,7 +976,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_associative_scan as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -988,7 +988,7 @@ def _test_type_kernel_kernel(x, result, x_size_1, result_stride_0, result_stride indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_combine_fn_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_1[None, :]) def test_type_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -1006,7 +1006,7 @@ import triton.language as tl from helion.runtime import default_launcher as _default_launcher @triton.jit -def helper_function_0(param_0, param_1): +def mul_0(param_0, param_1): v_0 = param_0 * param_1 return v_0 @@ -1019,7 +1019,7 @@ def _test_cumprod_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, mul_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_cumprod_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -1038,7 +1038,7 @@ import triton.language as tl from helion.runtime import default_launcher as _default_launcher @triton.jit -def helper_function_0(param_0, param_1): +def mul_0(param_0, param_1): v_0 = param_0 * param_1 return v_0 @@ -1051,7 +1051,7 @@ def _test_cumprod_dtype_kernel_kernel(x, result, x_size_0, x_size_1, result_stri indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, mul_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_cumprod_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -1070,7 +1070,7 @@ import triton.language as tl from helion.runtime import default_launcher as _default_launcher @triton.jit -def helper_function_0(param_0, param_1): +def mul_0(param_0, param_1): v_0 = param_0 * param_1 return v_0 @@ -1083,7 +1083,7 @@ def _test_cumprod_dtype_kernel_kernel(x, result, x_size_0, x_size_1, result_stri indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, mul_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_cumprod_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -1102,7 +1102,7 @@ import triton.language as tl from helion.runtime import default_launcher as _default_launcher @triton.jit -def helper_function_0(param_0, param_1): +def mul_0(param_0, param_1): v_0 = param_0 * param_1 return v_0 @@ -1115,7 +1115,7 @@ def _test_cumprod_dtype_kernel_kernel(x, result, x_size_0, x_size_1, result_stri indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, mul_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_cumprod_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -1134,7 +1134,7 @@ import triton.language as tl from helion.runtime import default_launcher as _default_launcher @triton.jit -def helper_function_0(param_0, param_1): +def mul_0(param_0, param_1): v_0 = param_0 * param_1 return v_0 @@ -1147,7 +1147,7 @@ def _test_cumprod_dtype_kernel_kernel(x, result, x_size_0, x_size_1, result_stri indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, mul_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_cumprod_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -1166,7 +1166,7 @@ import triton.language as tl from helion.runtime import default_launcher as _default_launcher @triton.jit -def helper_function_0(param_0, param_1): +def mul_0(param_0, param_1): v_0 = param_0 * param_1 return v_0 @@ -1175,7 +1175,7 @@ def _test_cumprod_reverse_kernel_kernel(x, result, x_size_1, result_stride_1, x_ indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0, reverse=True) + _associative_scan = tl.associative_scan(row_data, 1, mul_0, reverse=True) tl.store(result + indices_1[None, :] * result_stride_1, _associative_scan, mask_1[None, :]) def test_cumprod_reverse_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -1193,7 +1193,7 @@ import triton.language as tl from helion.runtime import default_launcher as _default_launcher @triton.jit -def helper_function_0(param_0, param_1): +def add_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -1206,7 +1206,7 @@ def _test_cumsum_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, r indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_cumsum_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -1225,12 +1225,12 @@ import triton.language as tl from helion.runtime import default_launcher as _default_launcher @triton.jit -def helper_function_0(param_0, param_1): +def add_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @triton.jit -def helper_function_1(param_0, param_1): +def mul_1(param_0, param_1): v_0 = param_0 * param_1 return v_0 @@ -1239,9 +1239,9 @@ def _test_mixed_kernel_kernel(x, sum_result, prod_result, x_size_1, prod_result_ indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_0) tl.store(sum_result + indices_1[None, :] * sum_result_stride_1, _associative_scan, mask_1[None, :]) - _associative_scan_1 = tl.associative_scan(_associative_scan, 1, helper_function_1) + _associative_scan_1 = tl.associative_scan(_associative_scan, 1, mul_1) tl.store(prod_result + indices_1[None, :] * prod_result_stride_1, _associative_scan_1, mask_1[None, :]) def test_mixed_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -1260,7 +1260,7 @@ import triton.language as tl from helion.runtime import default_launcher as _default_launcher @triton.jit -def helper_function_0(param_0, param_1): +def add_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -1273,7 +1273,7 @@ def _test_cumsum_dtype_kernel_kernel(x, result, x_size_0, x_size_1, result_strid indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_cumsum_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -1292,7 +1292,7 @@ import triton.language as tl from helion.runtime import default_launcher as _default_launcher @triton.jit -def helper_function_0(param_0, param_1): +def add_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -1305,7 +1305,7 @@ def _test_cumsum_dtype_kernel_kernel(x, result, x_size_0, x_size_1, result_strid indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_cumsum_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -1324,7 +1324,7 @@ import triton.language as tl from helion.runtime import default_launcher as _default_launcher @triton.jit -def helper_function_0(param_0, param_1): +def add_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -1337,7 +1337,7 @@ def _test_cumsum_dtype_kernel_kernel(x, result, x_size_0, x_size_1, result_strid indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_cumsum_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -1356,7 +1356,7 @@ import triton.language as tl from helion.runtime import default_launcher as _default_launcher @triton.jit -def helper_function_0(param_0, param_1): +def add_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -1369,7 +1369,7 @@ def _test_cumsum_dtype_kernel_kernel(x, result, x_size_0, x_size_1, result_strid indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) + _associative_scan = tl.associative_scan(row_data, 1, add_0) tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_0[:, None] & mask_1[None, :]) def test_cumsum_dtype_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -1388,7 +1388,7 @@ import triton.language as tl from helion.runtime import default_launcher as _default_launcher @triton.jit -def helper_function_0(param_0, param_1): +def add_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -1397,7 +1397,7 @@ def _test_cumsum_reverse_kernel_kernel(x, result, x_size_1, result_stride_1, x_s indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) - _associative_scan = tl.associative_scan(row_data, 1, helper_function_0, reverse=True) + _associative_scan = tl.associative_scan(row_data, 1, add_0, reverse=True) tl.store(result + indices_1[None, :] * result_stride_1, _associative_scan, mask_1[None, :]) def test_cumsum_reverse_kernel(x: torch.Tensor, *, _launcher=_default_launcher): diff --git a/test/test_associative_scan.py b/test/test_associative_scan.py index ac1cccf1..f947ab22 100644 --- a/test/test_associative_scan.py +++ b/test/test_associative_scan.py @@ -126,7 +126,7 @@ def test_scan_kernel(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, expected) # Verify the generated code contains the correct helper function - self.assertIn("def helper_function_", code) + self.assertIn("def add_combine_fn_", code) self.assertIn("param_0 + param_1", code) self.assertIn("tl.associative_scan", code) @@ -248,8 +248,8 @@ def test_multi_kernel(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, expected_sum) # Verify multiple helper functions are generated - self.assertIn("helper_function_0", code) - self.assertIn("helper_function_1", code) + self.assertIn("add_combine_fn_", code) + self.assertIn("max_combine_fn_", code) self.assertIn("param_0 + param_1", code) # Check for maximum operation (either format) self.assertTrue("tl.maximum" in code or "triton_helpers.maximum" in code) @@ -454,7 +454,7 @@ def test_torch_hops_kernel(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) # Verify the generated code contains the proper combine function and associative scan - self.assertIn("def helper_function_", code) + self.assertIn("def add_combine_fn_", code) self.assertIn("tl.associative_scan", code) self.assertIn("param_0 + param_1", code) @@ -475,7 +475,7 @@ def test_codegen_kernel(x: torch.Tensor) -> torch.Tensor: # Check essential code structure self.assertIn("@triton.jit", code) - self.assertIn("def helper_function_", code) + self.assertIn("def add_combine_fn_", code) self.assertIn("tl.associative_scan", code) self.assertIn("return", code) @@ -503,7 +503,7 @@ def test_jit_kernel(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) # Verify the generated code contains the proper combine function and associative scan - self.assertIn("def helper_function_", code) + self.assertIn("def jit_add_combine_fn_", code) self.assertIn("tl.associative_scan", code) self.assertIn("param_0 + param_1", code) # Verify @helion.jit decorator doesn't appear in generated code @@ -555,7 +555,7 @@ def test_segmented_kernel( torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) # Verify the generated code structure - self.assertIn("def helper_function_", code) + self.assertIn("def helion_combine_fn_", code) self.assertIn("tl.associative_scan", code) def test_associative_scan_segmented_reduction(self): @@ -609,7 +609,7 @@ def segmented_scan_kernel( torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4) # Verify the generated code structure - self.assertIn("def helper_function_", code) + self.assertIn("def segmented_combine_fn_", code) self.assertIn("tl.associative_scan", code) def test_associative_scan_cumulative_argmax(self): @@ -678,7 +678,7 @@ def cumulative_argmax_kernel( torch.testing.assert_close(result_indices, expected_indices) # Verify the generated code structure - self.assertIn("def helper_function_", code) + self.assertIn("def argmax_combine_fn_", code) self.assertIn("tl.associative_scan", code) def test_associative_scan_in_helper_function(self): @@ -709,7 +709,7 @@ def test_helper_kernel(x: torch.Tensor) -> torch.Tensor: self.assertFalse(torch.equal(result, x)) # Verify the generated code contains the helper function and associative scan - self.assertIn("def helper_function_", code) + self.assertIn("def add_combine_fn_", code) self.assertIn("tl.associative_scan", code) self.assertIn("param_0 + param_1", code) @@ -736,7 +736,7 @@ def test_cumsum_kernel(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, expected) # Verify the generated code contains cumsum implementation - self.assertIn("def helper_function_", code) + self.assertIn("def add_", code) self.assertIn("param_0 + param_1", code) self.assertIn("tl.associative_scan", code) @@ -817,7 +817,7 @@ def test_cumprod_kernel(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, expected) # Verify the generated code contains cumprod implementation - self.assertIn("def helper_function_", code) + self.assertIn("def mul_", code) self.assertIn("param_0 * param_1", code) self.assertIn("tl.associative_scan", code) @@ -903,8 +903,8 @@ def test_mixed_kernel(x: torch.Tensor) -> torch.Tensor: torch.testing.assert_close(result, expected_sum) # Verify both helper functions are generated - self.assertIn("helper_function_0", code) - self.assertIn("helper_function_1", code) + self.assertIn("add_", code) + self.assertIn("mul_", code) self.assertIn("param_0 + param_1", code) self.assertIn("param_0 * param_1", code) @@ -956,7 +956,7 @@ def test_segmented_tuple_kernel( torch.testing.assert_close(result, expected) # Verify the generated code structure - self.assertIn("def helper_function_", code) + self.assertIn("def helion_combine_tuple_fn_", code) self.assertIn("tl.associative_scan", code) def test_associative_scan_argmax_tuple_format(self): @@ -1025,7 +1025,7 @@ def cumulative_argmax_tuple_kernel( torch.testing.assert_close(result_indices, expected_indices) # Verify the generated code structure - self.assertIn("def helper_function_", code) + self.assertIn("def argmax_combine_tuple_fn_", code) self.assertIn("tl.associative_scan", code) diff --git a/test/test_examples.expected b/test/test_examples.expected index cc597ddd..7ee954ee 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -1143,7 +1143,7 @@ from helion.runtime import default_launcher as _default_launcher import helion._testing.segment_reduction as _source_module @triton.jit -def helper_function_0(param_0, param_1, param_2, param_3): +def combine_fn_helion_0(param_0, param_1, param_2, param_3): v_0 = param_1 == param_3 v_1 = param_0 + param_2 v_2 = tl.where(v_0, v_1, param_2) @@ -1171,7 +1171,7 @@ def _segmented_reduction_helion_kernel(input_data, indices, output, indices_stri v_4 = idxs.to(tl.float32) unsqueeze = v_4[:, None] expand = tl.broadcast_to(unsqueeze, [_BLOCK_SIZE_0, _BLOCK_SIZE_1]) - out_vals = tl.associative_scan((vals, expand), 0, helper_function_0)[0] + out_vals = tl.associative_scan((vals, expand), 0, combine_fn_helion_0)[0] v_5 = idxs != idxs_next _BLOCK_SIZE_0_ = _BLOCK_SIZE_0 v_6 = _BLOCK_SIZE_0_.to(tl.int32) diff --git a/test/test_reduce.expected b/test/test_reduce.expected index 6152e533..a3477a07 100644 --- a/test/test_reduce.expected +++ b/test/test_reduce.expected @@ -12,7 +12,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_reduce as _source_module @triton.jit -def helper_function_0(param_0, param_1, param_2, param_3): +def argmax_combine_fn_0(param_0, param_1, param_2, param_3): v_0 = param_2 > param_0 v_1 = tl.where(v_0, param_2, param_0) v_2 = tl.where(v_0, param_3, param_1) @@ -29,7 +29,7 @@ def _test_argmax_negative_kernel_kernel(indices, values, result, indices_size_1, row_values = tl.load(values + (indices_0[:, None] * values_stride_0 + indices_1[None, :] * values_stride_1), mask_0[:, None] & mask_1[None, :], other=0) row_indices = tl.load(indices + (indices_0[:, None] * indices_stride_0 + indices_1[None, :] * indices_stride_1), mask_0[:, None] & mask_1[None, :], other=0) _mask_to = tl.where(mask_0[:, None] & mask_1[None, :], row_values, float('-inf')) - max_index = tl.reduce((_mask_to, row_indices), 1, helper_function_0)[1] + max_index = tl.reduce((_mask_to, row_indices), 1, argmax_combine_fn_0)[1] tl.store(result + indices_0 * result_stride_0, max_index, mask_0) def test_argmax_negative_kernel(values: torch.Tensor, indices: torch.Tensor, *, _launcher=_default_launcher): @@ -51,7 +51,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_reduce as _source_module @triton.jit -def helper_function_0(param_0, param_1, param_2, param_3): +def argmax_combine_unpacked_fn_0(param_0, param_1, param_2, param_3): v_0 = param_2 > param_0 v_1 = tl.where(v_0, param_2, param_0) v_2 = tl.where(v_0, param_3, param_1) @@ -67,7 +67,7 @@ def _test_argmax_unpacked_kernel_kernel(indices, values, result, indices_size_1, mask_1 = indices_1 < indices_size_1 row_values = tl.load(values + (indices_0[:, None] * values_stride_0 + indices_1[None, :] * values_stride_1), mask_0[:, None] & mask_1[None, :], other=0) row_indices = tl.load(indices + (indices_0[:, None] * indices_stride_0 + indices_1[None, :] * indices_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - max_index = tl.reduce((row_values, row_indices), 1, helper_function_0)[1] + max_index = tl.reduce((row_values, row_indices), 1, argmax_combine_unpacked_fn_0)[1] tl.store(result + indices_0 * result_stride_0, max_index, mask_0) def test_argmax_unpacked_kernel(values: torch.Tensor, indices: torch.Tensor, *, _launcher=_default_launcher): @@ -89,7 +89,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_reduce as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -102,7 +102,7 @@ def _test_reduce_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, x indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _reduce = tl.reduce(row_data, 1, helper_function_0) + _reduce = tl.reduce(row_data, 1, add_combine_fn_0) tl.store(result + indices_0 * result_stride_0, _reduce, mask_0) def test_reduce_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -123,7 +123,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_reduce as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -132,7 +132,7 @@ def _test_reduce_codegen_kernel_kernel(x, result, x_size_1, x_stride_1, _RDIM_SI indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) - _reduce = tl.reduce(row_data, 1, helper_function_0) + _reduce = tl.reduce(row_data, 1, add_combine_fn_0) tl.store(result + tl.zeros([1], tl.int32), _reduce, None) def test_reduce_codegen_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -152,7 +152,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_reduce as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -165,7 +165,7 @@ def _test_reduce_int_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_ indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _reduce = tl.reduce(row_data, 1, helper_function_0) + _reduce = tl.reduce(row_data, 1, add_combine_fn_0) tl.store(result + indices_0 * result_stride_0, _reduce, mask_0) def test_reduce_int_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -186,7 +186,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_reduce as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def jit_add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -199,7 +199,7 @@ def _test_reduce_jit_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_ indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _reduce = tl.reduce(row_data, 1, helper_function_0) + _reduce = tl.reduce(row_data, 1, jit_add_combine_fn_0) tl.store(result + indices_0 * result_stride_0, _reduce, mask_0) def test_reduce_jit_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -221,7 +221,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_reduce as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def max_combine_fn_0(param_0, param_1): v_0 = triton_helpers.maximum(param_0, param_1) return v_0 @@ -234,7 +234,7 @@ def _test_reduce_max_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_ indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _reduce = tl.reduce(row_data, 1, helper_function_0) + _reduce = tl.reduce(row_data, 1, max_combine_fn_0) tl.store(result + indices_0 * result_stride_0, _reduce, mask_0) def test_reduce_max_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -256,7 +256,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_reduce as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def min_combine_fn_0(param_0, param_1): v_0 = triton_helpers.minimum(param_0, param_1) return v_0 @@ -269,7 +269,7 @@ def _test_reduce_min_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_ indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _reduce = tl.reduce(row_data, 1, helper_function_0) + _reduce = tl.reduce(row_data, 1, min_combine_fn_0) tl.store(result + indices_0 * result_stride_0, _reduce, mask_0) def test_reduce_min_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -290,7 +290,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_reduce as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def mul_combine_fn_0(param_0, param_1): v_0 = param_0 * param_1 return v_0 @@ -304,7 +304,7 @@ def _test_reduce_product_kernel_kernel(x, result, x_size_0, x_size_1, result_str mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) _mask_to = tl.where(mask_0[:, None] & mask_1[None, :], row_data, 1.0) - _reduce = tl.reduce(_mask_to, 1, helper_function_0) + _reduce = tl.reduce(_mask_to, 1, mul_combine_fn_0) tl.store(result + indices_0 * result_stride_0, _reduce, mask_0) def test_reduce_product_kernel(x: torch.Tensor, *, _launcher=_default_launcher): @@ -325,7 +325,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_reduce as _source_module @triton.jit -def helper_function_0(param_0, param_1, param_2, param_3): +def tuple_add_combine_fn_0(param_0, param_1, param_2, param_3): v_0 = param_0 + param_2 v_1 = param_1 + param_3 return (v_0, v_1) @@ -340,8 +340,8 @@ def _test_reduce_tuple_kernel_kernel(x, y, result_x, result_y, x_size_0, x_size_ mask_1 = indices_1 < x_size_1 row_x = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) row_y = tl.load(y + (indices_0[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - getitem = tl.reduce((row_x, row_y), 1, helper_function_0)[0] - getitem_1 = tl.reduce((row_x, row_y), 1, helper_function_0)[1] + getitem = tl.reduce((row_x, row_y), 1, tuple_add_combine_fn_0)[0] + getitem_1 = tl.reduce((row_x, row_y), 1, tuple_add_combine_fn_0)[1] tl.store(result_x + indices_0 * result_x_stride_0, getitem, mask_0) tl.store(result_y + indices_0 * result_y_stride_0, getitem_1, mask_0) @@ -364,7 +364,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_reduce as _source_module @triton.jit -def helper_function_0(param_0, param_1, param_2, param_3): +def tuple_add_combine_unpacked_fn_0(param_0, param_1, param_2, param_3): v_0 = param_0 + param_2 v_1 = param_1 + param_3 return (v_0, v_1) @@ -379,8 +379,8 @@ def _test_reduce_tuple_unpacked_kernel_kernel(x, y, result_x, result_y, x_size_0 mask_1 = indices_1 < x_size_1 row_x = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) row_y = tl.load(y + (indices_0[:, None] * y_stride_0 + indices_1[None, :] * y_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - getitem = tl.reduce((row_x, row_y), 1, helper_function_0)[0] - getitem_1 = tl.reduce((row_x, row_y), 1, helper_function_0)[1] + getitem = tl.reduce((row_x, row_y), 1, tuple_add_combine_unpacked_fn_0)[0] + getitem_1 = tl.reduce((row_x, row_y), 1, tuple_add_combine_unpacked_fn_0)[1] tl.store(result_x + indices_0 * result_x_stride_0, getitem, mask_0) tl.store(result_y + indices_0 * result_y_stride_0, getitem_1, mask_0) @@ -403,7 +403,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_reduce as _source_module @triton.jit -def helper_function_0(param_0, param_1, param_2, param_3): +def argmax_combine_fn_0(param_0, param_1, param_2, param_3): v_0 = param_2 > param_0 v_1 = tl.where(v_0, param_2, param_0) v_2 = tl.where(v_0, param_3, param_1) @@ -419,7 +419,7 @@ def _test_tuple_oneline_kernel_kernel(indices, values, result, indices_size_1, i mask_1 = indices_1 < indices_size_1 row_values = tl.load(values + (indices_0[:, None] * values_stride_0 + indices_1[None, :] * values_stride_1), mask_0[:, None] & mask_1[None, :], other=0) row_indices = tl.load(indices + (indices_0[:, None] * indices_stride_0 + indices_1[None, :] * indices_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - max_index = tl.reduce((row_values, row_indices), 1, helper_function_0)[1] + max_index = tl.reduce((row_values, row_indices), 1, argmax_combine_fn_0)[1] tl.store(result + indices_0 * result_stride_0, max_index, mask_0) def test_tuple_oneline_kernel(values: torch.Tensor, indices: torch.Tensor, *, _launcher=_default_launcher): @@ -441,7 +441,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_reduce as _source_module @triton.jit -def helper_function_0(param_0, param_1, param_2, param_3): +def argmax_combine_fn_0(param_0, param_1, param_2, param_3): v_0 = param_2 > param_0 v_1 = tl.where(v_0, param_2, param_0) v_2 = tl.where(v_0, param_3, param_1) @@ -457,7 +457,7 @@ def _test_tuple_twoline_kernel_kernel(indices, values, result, indices_size_1, i mask_1 = indices_1 < indices_size_1 row_values = tl.load(values + (indices_0[:, None] * values_stride_0 + indices_1[None, :] * values_stride_1), mask_0[:, None] & mask_1[None, :], other=0) row_indices = tl.load(indices + (indices_0[:, None] * indices_stride_0 + indices_1[None, :] * indices_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - max_index = tl.reduce((row_values, row_indices), 1, helper_function_0)[1] + max_index = tl.reduce((row_values, row_indices), 1, argmax_combine_fn_0)[1] tl.store(result + indices_0 * result_stride_0, max_index, mask_0) def test_tuple_twoline_kernel(values: torch.Tensor, indices: torch.Tensor, *, _launcher=_default_launcher): @@ -479,7 +479,7 @@ from helion.runtime import default_launcher as _default_launcher import test.test_reduce as _source_module @triton.jit -def helper_function_0(param_0, param_1): +def add_combine_fn_0(param_0, param_1): v_0 = param_0 + param_1 return v_0 @@ -492,7 +492,7 @@ def _test_reduce_keep_dims_kernel_kernel(x, result, x_size_0, x_size_1, result_s indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) - _reduce = tl.reduce(row_data, 1, helper_function_0, keep_dims=True) + _reduce = tl.reduce(row_data, 1, add_combine_fn_0, keep_dims=True) tl.store(result + indices_0[:, None] * result_stride_0, _reduce, mask_0[:, None]) def test_reduce_keep_dims_kernel(x: torch.Tensor, *, _launcher=_default_launcher): diff --git a/test/test_reduce.py b/test/test_reduce.py index e6b2278a..83d8c1b9 100644 --- a/test/test_reduce.py +++ b/test/test_reduce.py @@ -106,7 +106,7 @@ def test_reduce_kernel(x: torch.Tensor) -> torch.Tensor: # Check that the generated code contains triton reduce calls self.assertIn("tl.reduce", code) - self.assertIn("helper_function_", code) + self.assertIn("add_combine_fn_", code) def test_reduce_max(self): """Test reduce with maximum operation.""" @@ -369,7 +369,7 @@ def test_tuple_oneline_kernel( # Check that the generated code contains the expected elements self.assertIn("tl.reduce", code) - self.assertIn("helper_function_", code) + self.assertIn("argmax_combine_fn_", code) def test_reduce_tuple_unpacking_twoline(self): """Test tuple unpacking in two lines: result = hl.reduce(...); a, b = result""" @@ -432,7 +432,7 @@ def test_tuple_twoline_kernel( # Check that the generated code contains the expected elements self.assertIn("tl.reduce", code) - self.assertIn("helper_function_", code) + self.assertIn("argmax_combine_fn_", code) def test_reduce_argmax_negative_values(self): """Test argmax with all negative values using other=(-inf, 0).""" @@ -497,7 +497,7 @@ def test_argmax_negative_kernel( # Check that the generated code contains the expected elements self.assertIn("tl.reduce", code) - self.assertIn("helper_function_", code) + self.assertIn("argmax_combine_fn_", code) def test_reduce_code_generation(self): """Test that reduce generates correct Triton code.""" @@ -519,7 +519,7 @@ def test_reduce_codegen_kernel(x: torch.Tensor) -> torch.Tensor: # Check that the generated code contains the expected elements self.assertIn("tl.reduce", code) - self.assertIn("helper_function_", code) + self.assertIn("add_combine_fn_", code) self.assertIn("@triton.jit", code) # Verify correctness