From 4052d1287522baa541fdf0e646135ab63693580e Mon Sep 17 00:00:00 2001 From: joydddd Date: Thu, 10 Jul 2025 15:39:09 -0700 Subject: [PATCH] tensor[tile] when tile size is 1 returns a 1D tensor, instead of a scalar stack-info: PR: https://github.com/pytorch-labs/helion/pull/269, branch: joydddd/stack/14 --- helion/_compiler/indexing_strategy.py | 5 +- helion/_compiler/tile_strategy.py | 18 +++- test/test_associative_scan.expected | 144 ++++++++++++++++---------- test/test_loops.expected | 50 ++++----- test/test_reduce.expected | 21 ++-- 5 files changed, 139 insertions(+), 99 deletions(-) diff --git a/helion/_compiler/indexing_strategy.py b/helion/_compiler/indexing_strategy.py index d3d44b03..ad6df9cf 100644 --- a/helion/_compiler/indexing_strategy.py +++ b/helion/_compiler/indexing_strategy.py @@ -378,9 +378,8 @@ def create( assert len(index_values) == fake_value.ndim index_expr = [] for i, idx in enumerate(index_values): - if fake_value.size(i) != 1: - stride = state.device_function.tensor_stride(fake_value, i).name - index_expr.append(f"{idx} * {stride}") + stride = state.device_function.tensor_stride(fake_value, i).name + index_expr.append(f"{idx} * {stride}") if not index_expr: shape_str = tile_strategy.shape_str(output_size) index_expr.append(f"tl.zeros({shape_str}, {dtype})") diff --git a/helion/_compiler/tile_strategy.py b/helion/_compiler/tile_strategy.py index f7e28092..d93af8e8 100644 --- a/helion/_compiler/tile_strategy.py +++ b/helion/_compiler/tile_strategy.py @@ -670,11 +670,19 @@ def codegen_device_loop(self, state: CodegenState) -> DeviceLoopState: type_comment=None, ) assert for_node.body is body - extra_body = [ - statement_from_string( - f"{index_var} = {offset_var} + tl.arange(0, ({block_size_var})).to({dtype})" - ), - ] + extra_body = [] + if block_size == 1: + extra_body.append( + statement_from_string( + f"{index_var} = {offset_var} + tl.zeros([1], {dtype})" + ), + ) + else: + extra_body.append( + statement_from_string( + f"{index_var} = {offset_var} + tl.arange(0, ({block_size_var})).to({dtype})" + ), + ) mask_statement = self._setup_mask( # pyright: ignore[reportAttributeAccessIssue] state, block_idx, block_size, index_var, end ) diff --git a/test/test_associative_scan.expected b/test/test_associative_scan.expected index 44919a44..a9eb3f9a 100644 --- a/test/test_associative_scan.expected +++ b/test/test_associative_scan.expected @@ -107,24 +107,27 @@ def helper_function_0(param_0, param_1): return v_0 @triton.jit -def _test_codegen_kernel_kernel(x, result, x_size_1, result_stride_1, x_stride_1, _RDIM_SIZE_1: tl.constexpr): +def _test_codegen_kernel_kernel(x, result, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_0 = offset_0 + tl.zeros([1], tl.int32) indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 - row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) + row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) - tl.store(result + indices_1[None, :] * result_stride_1, _associative_scan, mask_1[None, :]) + tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_1[None, :]) def test_codegen_kernel(x: torch.Tensor): result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) - _test_codegen_kernel_kernel[1,](x, result, x.size(1), result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + _test_codegen_kernel_kernel[1,](x, result, x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) return result def _test_codegen_kernel_make_precompiler(x: torch.Tensor): result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_test_codegen_kernel_kernel)(x, result, x.size(1), result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + return make_precompiler(_test_codegen_kernel_kernel)(x, result, x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) --- assertExpectedJournal(TestAssociativeScan.test_associative_scan_cumulative_argmax) from __future__ import annotations @@ -352,24 +355,27 @@ def helper_function_0(param_0, param_1): return v_0 @triton.jit -def _test_size_kernel_kernel(x, result, x_size_1, result_stride_1, x_stride_1, _RDIM_SIZE_1: tl.constexpr): +def _test_size_kernel_kernel(x, result, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_0 = offset_0 + tl.zeros([1], tl.int32) indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 - row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) + row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) - tl.store(result + indices_1[None, :] * result_stride_1, _associative_scan, mask_1[None, :]) + tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_1[None, :]) def test_size_kernel(x: torch.Tensor): result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) - _test_size_kernel_kernel[1,](x, result, x.size(1), result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + _test_size_kernel_kernel[1,](x, result, x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) return result def _test_size_kernel_make_precompiler(x: torch.Tensor): result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_test_size_kernel_kernel)(x, result, x.size(1), result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + return make_precompiler(_test_size_kernel_kernel)(x, result, x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) --- assertExpectedJournal(TestAssociativeScan.test_associative_scan_different_sizes) from __future__ import annotations @@ -466,26 +472,26 @@ def helper_function_0(param_0, param_1): return v_0 @triton.jit -def _test_size_kernel_kernel(x, result, x_size_0, result_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr): +def _test_size_kernel_kernel(x, result, x_size_0, result_stride_0, result_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr): pid_0 = tl.program_id(0) offset_0 = pid_0 * _BLOCK_SIZE_0 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) mask_0 = indices_0 < x_size_0 - row_data = tl.load(x + indices_0[:, None] * x_stride_0, mask_0[:, None], other=0) + row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + tl.zeros([1], tl.int32)[None, :] * x_stride_1), mask_0[:, None], other=0) _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) - tl.store(result + indices_0[:, None] * result_stride_0, _associative_scan, mask_0[:, None]) + tl.store(result + (indices_0[:, None] * result_stride_0 + tl.zeros([1], tl.int32)[None, :] * result_stride_1), _associative_scan, mask_0[:, None]) def test_size_kernel(x: torch.Tensor): result = torch.empty_like(x) _BLOCK_SIZE_0 = 2 - _test_size_kernel_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + _test_size_kernel_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, result, x.size(0), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, num_warps=4, num_stages=3) return result def _test_size_kernel_make_precompiler(x: torch.Tensor): result = torch.empty_like(x) _BLOCK_SIZE_0 = 2 from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_test_size_kernel_kernel)(x, result, x.size(0), result.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return make_precompiler(_test_size_kernel_kernel)(x, result, x.size(0), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, num_warps=4, num_stages=3) --- assertExpectedJournal(TestAssociativeScan.test_associative_scan_different_sizes) from __future__ import annotations @@ -576,20 +582,23 @@ def helper_function_0(param_0, param_1): return v_0 @triton.jit -def _test_single_element_kernel(x, result): - row_data = tl.load(x + tl.zeros([1, 1], tl.int32), None) +def _test_single_element_kernel(x, result, result_stride_0, result_stride_1, x_stride_0, x_stride_1): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_0 = offset_0 + tl.zeros([1], tl.int32) + row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + tl.zeros([1], tl.int32)[None, :] * x_stride_1), None) _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) - tl.store(result + tl.zeros([1, 1], tl.int32), _associative_scan, None) + tl.store(result + (indices_0[:, None] * result_stride_0 + tl.zeros([1], tl.int32)[None, :] * result_stride_1), _associative_scan, None) def test_single_element(x: torch.Tensor): result = torch.empty_like(x) - _test_single_element_kernel[1,](x, result, num_warps=4, num_stages=3) + _test_single_element_kernel[1,](x, result, result.stride(0), result.stride(1), x.stride(0), x.stride(1), num_warps=4, num_stages=3) return result def _test_single_element_make_precompiler(x: torch.Tensor): result = torch.empty_like(x) from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_test_single_element_kernel)(x, result, num_warps=4, num_stages=3) + return make_precompiler(_test_single_element_kernel)(x, result, result.stride(0), result.stride(1), x.stride(0), x.stride(1), num_warps=4, num_stages=3) --- assertExpectedJournal(TestAssociativeScan.test_associative_scan_edge_cases) from __future__ import annotations @@ -606,24 +615,27 @@ def helper_function_0(param_0, param_1): return v_0 @triton.jit -def _test_single_element_kernel(x, result, x_size_1, result_stride_1, x_stride_1, _RDIM_SIZE_1: tl.constexpr): +def _test_single_element_kernel(x, result, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_0 = offset_0 + tl.zeros([1], tl.int32) indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 - row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) + row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) - tl.store(result + indices_1[None, :] * result_stride_1, _associative_scan, mask_1[None, :]) + tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_1[None, :]) def test_single_element(x: torch.Tensor): result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) - _test_single_element_kernel[1,](x, result, x.size(1), result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + _test_single_element_kernel[1,](x, result, x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) return result def _test_single_element_make_precompiler(x: torch.Tensor): result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_test_single_element_kernel)(x, result, x.size(1), result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + return make_precompiler(_test_single_element_kernel)(x, result, x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) --- assertExpectedJournal(TestAssociativeScan.test_associative_scan_in_helper_function) from __future__ import annotations @@ -680,24 +692,27 @@ def helper_function_0(param_0, param_1): return v_0 @triton.jit -def _test_jit_kernel_kernel(x, result, x_size_1, result_stride_1, x_stride_1, _RDIM_SIZE_1: tl.constexpr): +def _test_jit_kernel_kernel(x, result, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_0 = offset_0 + tl.zeros([1], tl.int32) indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 - row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) + row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) - tl.store(result + indices_1[None, :] * result_stride_1, _associative_scan, mask_1[None, :]) + tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_1[None, :]) def test_jit_kernel(x: torch.Tensor): result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) - _test_jit_kernel_kernel[1,](x, result, x.size(1), result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + _test_jit_kernel_kernel[1,](x, result, x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) return result def _test_jit_kernel_make_precompiler(x: torch.Tensor): result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_test_jit_kernel_kernel)(x, result, x.size(1), result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + return make_precompiler(_test_jit_kernel_kernel)(x, result, x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) --- assertExpectedJournal(TestAssociativeScan.test_associative_scan_large_scale) from __future__ import annotations @@ -839,20 +854,23 @@ def helper_function_1(param_0, param_1): return v_0 @triton.jit -def _test_multi_kernel_kernel(x, sum_result, max_result, x_size_1, max_result_stride_1, sum_result_stride_1, x_stride_1, _RDIM_SIZE_1: tl.constexpr): +def _test_multi_kernel_kernel(x, sum_result, max_result, x_size_1, max_result_stride_0, max_result_stride_1, sum_result_stride_0, sum_result_stride_1, x_stride_0, x_stride_1, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_0 = offset_0 + tl.zeros([1], tl.int32) indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 - row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) + row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) - tl.store(sum_result + indices_1[None, :] * sum_result_stride_1, _associative_scan, mask_1[None, :]) + tl.store(sum_result + (indices_0[:, None] * sum_result_stride_0 + indices_1[None, :] * sum_result_stride_1), _associative_scan, mask_1[None, :]) _associative_scan_1 = tl.associative_scan(_associative_scan, 1, helper_function_1) - tl.store(max_result + indices_1[None, :] * max_result_stride_1, _associative_scan_1, mask_1[None, :]) + tl.store(max_result + (indices_0[:, None] * max_result_stride_0 + indices_1[None, :] * max_result_stride_1), _associative_scan_1, mask_1[None, :]) def test_multi_kernel(x: torch.Tensor): sum_result = torch.empty_like(x) max_result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) - _test_multi_kernel_kernel[1,](x, sum_result, max_result, x.size(1), max_result.stride(1), sum_result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + _test_multi_kernel_kernel[1,](x, sum_result, max_result, x.size(1), max_result.stride(0), max_result.stride(1), sum_result.stride(0), sum_result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) return sum_result def _test_multi_kernel_make_precompiler(x: torch.Tensor): @@ -860,7 +878,7 @@ def _test_multi_kernel_make_precompiler(x: torch.Tensor): max_result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_test_multi_kernel_kernel)(x, sum_result, max_result, x.size(1), max_result.stride(1), sum_result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + return make_precompiler(_test_multi_kernel_kernel)(x, sum_result, max_result, x.size(1), max_result.stride(0), max_result.stride(1), sum_result.stride(0), sum_result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) --- assertExpectedJournal(TestAssociativeScan.test_associative_scan_multiplication) from __future__ import annotations @@ -917,24 +935,27 @@ def helper_function_0(param_0, param_1): return v_0 @triton.jit -def _test_reverse_kernel_kernel(x, result, x_size_1, result_stride_1, x_stride_1, _RDIM_SIZE_1: tl.constexpr): +def _test_reverse_kernel_kernel(x, result, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_0 = offset_0 + tl.zeros([1], tl.int32) indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 - row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) + row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) _associative_scan = tl.associative_scan(row_data, 1, helper_function_0, reverse=True) - tl.store(result + indices_1[None, :] * result_stride_1, _associative_scan, mask_1[None, :]) + tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_1[None, :]) def test_reverse_kernel(x: torch.Tensor): result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) - _test_reverse_kernel_kernel[1,](x, result, x.size(1), result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + _test_reverse_kernel_kernel[1,](x, result, x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) return result def _test_reverse_kernel_make_precompiler(x: torch.Tensor): result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_test_reverse_kernel_kernel)(x, result, x.size(1), result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + return make_precompiler(_test_reverse_kernel_kernel)(x, result, x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) --- assertExpectedJournal(TestAssociativeScan.test_associative_scan_segmented_reduction) from __future__ import annotations @@ -1367,24 +1388,27 @@ def helper_function_0(param_0, param_1): return v_0 @triton.jit -def _test_cumprod_reverse_kernel_kernel(x, result, x_size_1, result_stride_1, x_stride_1, _RDIM_SIZE_1: tl.constexpr): +def _test_cumprod_reverse_kernel_kernel(x, result, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_0 = offset_0 + tl.zeros([1], tl.int32) indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 - row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) + row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) _associative_scan = tl.associative_scan(row_data, 1, helper_function_0, reverse=True) - tl.store(result + indices_1[None, :] * result_stride_1, _associative_scan, mask_1[None, :]) + tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_1[None, :]) def test_cumprod_reverse_kernel(x: torch.Tensor): result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) - _test_cumprod_reverse_kernel_kernel[1,](x, result, x.size(1), result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + _test_cumprod_reverse_kernel_kernel[1,](x, result, x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) return result def _test_cumprod_reverse_kernel_make_precompiler(x: torch.Tensor): result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_test_cumprod_reverse_kernel_kernel)(x, result, x.size(1), result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + return make_precompiler(_test_cumprod_reverse_kernel_kernel)(x, result, x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) --- assertExpectedJournal(TestAssociativeScan.test_cumsum_basic) from __future__ import annotations @@ -1442,20 +1466,23 @@ def helper_function_1(param_0, param_1): return v_0 @triton.jit -def _test_mixed_kernel_kernel(x, sum_result, prod_result, x_size_1, prod_result_stride_1, sum_result_stride_1, x_stride_1, _RDIM_SIZE_1: tl.constexpr): +def _test_mixed_kernel_kernel(x, sum_result, prod_result, x_size_1, prod_result_stride_0, prod_result_stride_1, sum_result_stride_0, sum_result_stride_1, x_stride_0, x_stride_1, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_0 = offset_0 + tl.zeros([1], tl.int32) indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 - row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) + row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) _associative_scan = tl.associative_scan(row_data, 1, helper_function_0) - tl.store(sum_result + indices_1[None, :] * sum_result_stride_1, _associative_scan, mask_1[None, :]) + tl.store(sum_result + (indices_0[:, None] * sum_result_stride_0 + indices_1[None, :] * sum_result_stride_1), _associative_scan, mask_1[None, :]) _associative_scan_1 = tl.associative_scan(_associative_scan, 1, helper_function_1) - tl.store(prod_result + indices_1[None, :] * prod_result_stride_1, _associative_scan_1, mask_1[None, :]) + tl.store(prod_result + (indices_0[:, None] * prod_result_stride_0 + indices_1[None, :] * prod_result_stride_1), _associative_scan_1, mask_1[None, :]) def test_mixed_kernel(x: torch.Tensor): sum_result = torch.empty_like(x) prod_result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) - _test_mixed_kernel_kernel[1,](x, sum_result, prod_result, x.size(1), prod_result.stride(1), sum_result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + _test_mixed_kernel_kernel[1,](x, sum_result, prod_result, x.size(1), prod_result.stride(0), prod_result.stride(1), sum_result.stride(0), sum_result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) return sum_result def _test_mixed_kernel_make_precompiler(x: torch.Tensor): @@ -1463,7 +1490,7 @@ def _test_mixed_kernel_make_precompiler(x: torch.Tensor): prod_result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_test_mixed_kernel_kernel)(x, sum_result, prod_result, x.size(1), prod_result.stride(1), sum_result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + return make_precompiler(_test_mixed_kernel_kernel)(x, sum_result, prod_result, x.size(1), prod_result.stride(0), prod_result.stride(1), sum_result.stride(0), sum_result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) --- assertExpectedJournal(TestAssociativeScan.test_cumsum_different_dtypes) from __future__ import annotations @@ -1630,21 +1657,24 @@ def helper_function_0(param_0, param_1): return v_0 @triton.jit -def _test_cumsum_reverse_kernel_kernel(x, result, x_size_1, result_stride_1, x_stride_1, _RDIM_SIZE_1: tl.constexpr): +def _test_cumsum_reverse_kernel_kernel(x, result, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_0 = offset_0 + tl.zeros([1], tl.int32) indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 - row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) + row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) _associative_scan = tl.associative_scan(row_data, 1, helper_function_0, reverse=True) - tl.store(result + indices_1[None, :] * result_stride_1, _associative_scan, mask_1[None, :]) + tl.store(result + (indices_0[:, None] * result_stride_0 + indices_1[None, :] * result_stride_1), _associative_scan, mask_1[None, :]) def test_cumsum_reverse_kernel(x: torch.Tensor): result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) - _test_cumsum_reverse_kernel_kernel[1,](x, result, x.size(1), result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + _test_cumsum_reverse_kernel_kernel[1,](x, result, x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) return result def _test_cumsum_reverse_kernel_make_precompiler(x: torch.Tensor): result = torch.empty_like(x) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_test_cumsum_reverse_kernel_kernel)(x, result, x.size(1), result.stride(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + return make_precompiler(_test_cumsum_reverse_kernel_kernel)(x, result, x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) diff --git a/test/test_loops.expected b/test/test_loops.expected index 7b4ff877..e122d73b 100644 --- a/test/test_loops.expected +++ b/test/test_loops.expected @@ -66,7 +66,7 @@ def _device_loop_3d_kernel(x, out, out_stride_0, out_stride_1, out_stride_2, out indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_1 < b for offset_3 in tl.range(0, d.to(tl.int32), step=1): - indices_3 = offset_3 + tl.arange(0, 1).to(tl.int32) + indices_3 = offset_3 + tl.zeros([1], tl.int32) load = tl.load(x + (indices_0[:, None, None, None] * x_stride_0 + indices_1[None, :, None, None] * x_stride_1 + indices_2[None, None, :, None] * x_stride_2 + indices_3[None, None, None, :] * x_stride_3), mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None], other=0) v_0 = tl_math.sin(load) tl.store(out + (indices_0[:, None, None, None] * out_stride_0 + indices_1[None, :, None, None] * out_stride_1 + indices_2[None, None, :, None] * out_stride_2 + indices_3[None, None, None, :] * out_stride_3), v_0, mask_0[:, None, None, None] & mask_1[None, :, None, None] & mask_2[None, None, :, None]) @@ -197,7 +197,7 @@ def _chebyshev_kernel_kernel(x, w, out, out_stride_0, out_stride_1, w_stride_0, v_3 = 2.0 v_4 = in_x * v_3 for offset_2 in tl.range(2, 5, step=1): - indices_2 = offset_2 + tl.arange(0, 1).to(tl.int32) + indices_2 = offset_2 + tl.zeros([1], tl.int32) v_4_copy = v_4 in_x_0_copy = in_x_0 T0_copy = T0 @@ -245,13 +245,13 @@ import triton import triton.language as tl @triton.jit -def _fn_kernel(x, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr): +def _fn_kernel(x, end, out, x_size_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr): pid_0 = tl.program_id(0) offset_1 = pid_0 * _BLOCK_SIZE_1 indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) mask_1 = indices_1 < x_size_0 acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32) - load = tl.load(end + tl.zeros([], tl.int32), None) + load = tl.load(end + 0 * end_stride_0, None) for offset_0 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_0): indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) mask_0 = indices_0 < load @@ -267,7 +267,7 @@ def fn(x: torch.Tensor, end: torch.Tensor): bs = 32 _BLOCK_SIZE_1 = 32 _BLOCK_SIZE_0 = 32 - _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3) + _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, end, out, x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3) return out def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor): @@ -276,7 +276,7 @@ def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor): _BLOCK_SIZE_1 = 32 _BLOCK_SIZE_0 = 32 from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_fn_kernel)(x, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return make_precompiler(_fn_kernel)(x, end, out, x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3) --- assertExpectedJournal(TestLoops.test_data_dependent_bounds2) from __future__ import annotations @@ -286,13 +286,13 @@ import triton import triton.language as tl @triton.jit -def _fn_kernel(x, end, out, out_size_0, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): +def _fn_kernel(x, end, out, out_size_0, x_size_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): pid_0 = tl.program_id(0) offset_0 = pid_0 * _BLOCK_SIZE_0 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) mask_0 = indices_0 < x_size_0 acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32) - load = tl.load(end + tl.zeros([], tl.int32), None) + load = tl.load(end + 0 * end_stride_0, None) for offset_1 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_1): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_1 < load @@ -307,7 +307,7 @@ def fn(x: torch.Tensor, end: torch.Tensor): out = x.new_empty([x.size(0)]) _BLOCK_SIZE_0 = 32 _BLOCK_SIZE_1 = 32 - _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end, out, out.size(0), x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3) + _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end, out, out.size(0), x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3) return out def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor): @@ -315,7 +315,7 @@ def _fn_make_precompiler(x: torch.Tensor, end: torch.Tensor): _BLOCK_SIZE_0 = 32 _BLOCK_SIZE_1 = 32 from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_fn_kernel)(x, end, out, out.size(0), x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3) + return make_precompiler(_fn_kernel)(x, end, out, out.size(0), x.size(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3) --- assertExpectedJournal(TestLoops.test_data_dependent_bounds3) from __future__ import annotations @@ -325,14 +325,14 @@ import triton import triton.language as tl @triton.jit -def _fn_kernel(x, end0, end1, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, x_stride_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): +def _fn_kernel(x, end0, end1, out, x_size_0, end0_stride_0, end1_stride_0, out_stride_0, x_stride_0, x_stride_1, x_stride_2, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): pid_0 = tl.program_id(0) offset_0 = pid_0 * _BLOCK_SIZE_0 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) mask_0 = indices_0 < x_size_0 acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float64) - load = tl.load(end0 + tl.zeros([], tl.int32), None) - load_1 = tl.load(end1 + tl.zeros([], tl.int32), None) + load = tl.load(end0 + 0 * end0_stride_0, None) + load_1 = tl.load(end1 + 0 * end1_stride_0, None) for offset_1 in tl.range(0, load.to(tl.int32), step=_BLOCK_SIZE_1): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_1 < load @@ -352,7 +352,7 @@ def fn(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor): _BLOCK_SIZE_0 = 32 _BLOCK_SIZE_2 = 32 _BLOCK_SIZE_1 = 32 - _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end0, end1, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3) + _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, end0, end1, out, x.size(0), end0.stride(0), end1.stride(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3) return out def _fn_make_precompiler(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor): @@ -361,7 +361,7 @@ def _fn_make_precompiler(x: torch.Tensor, end0: torch.Tensor, end1: torch.Tensor _BLOCK_SIZE_2 = 32 _BLOCK_SIZE_1 = 32 from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_fn_kernel)(x, end0, end1, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3) + return make_precompiler(_fn_kernel)(x, end0, end1, out, x.size(0), end0.stride(0), end1.stride(0), out.stride(0), x.stride(0), x.stride(1), x.stride(2), _BLOCK_SIZE_0, _BLOCK_SIZE_2, _BLOCK_SIZE_1, num_warps=4, num_stages=3) --- assertExpectedJournal(TestLoops.test_data_dependent_bounds4) from __future__ import annotations @@ -371,14 +371,14 @@ import triton import triton.language as tl @triton.jit -def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr): +def _fn_kernel(x, begin, end, out, x_size_0, begin_stride_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr): pid_0 = tl.program_id(0) offset_1 = pid_0 * _BLOCK_SIZE_1 indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32) mask_1 = indices_1 < x_size_0 acc = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32) - load = tl.load(begin + tl.zeros([], tl.int32), None) - load_1 = tl.load(end + tl.zeros([], tl.int32), None) + load = tl.load(begin + 0 * begin_stride_0, None) + load_1 = tl.load(end + 0 * end_stride_0, None) for offset_0 in tl.range(load.to(tl.int32), load_1.to(tl.int32), step=_BLOCK_SIZE_0): indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32) mask_0 = indices_0 < load_1 @@ -394,7 +394,7 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor): bs = 32 _BLOCK_SIZE_1 = 32 _BLOCK_SIZE_0 = 32 - _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3) + _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_1),](x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3) return out def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor): @@ -403,7 +403,7 @@ def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor _BLOCK_SIZE_1 = 32 _BLOCK_SIZE_0 = 32 from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3) + return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3) --- assertExpectedJournal(TestLoops.test_data_dependent_bounds5) from __future__ import annotations @@ -413,14 +413,14 @@ import triton import triton.language as tl @triton.jit -def _fn_kernel(x, begin, end, out, x_size_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): +def _fn_kernel(x, begin, end, out, x_size_0, begin_stride_0, end_stride_0, out_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr): pid_0 = tl.program_id(0) offset_0 = pid_0 * _BLOCK_SIZE_0 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) mask_0 = indices_0 < x_size_0 acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32) - load = tl.load(begin + tl.zeros([], tl.int32), None) - load_1 = tl.load(end + tl.zeros([], tl.int32), None) + load = tl.load(begin + 0 * begin_stride_0, None) + load_1 = tl.load(end + 0 * end_stride_0, None) for offset_1 in tl.range(load.to(tl.int32), load_1.to(tl.int32), step=_BLOCK_SIZE_1): indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) mask_1 = indices_1 < load_1 @@ -435,7 +435,7 @@ def fn(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor): out = x.new_empty([x.size(0)]) _BLOCK_SIZE_0 = 32 _BLOCK_SIZE_1 = 32 - _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3) + _fn_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3) return out def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor): @@ -443,7 +443,7 @@ def _fn_make_precompiler(x: torch.Tensor, begin: torch.Tensor, end: torch.Tensor _BLOCK_SIZE_0 = 32 _BLOCK_SIZE_1 = 32 from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3) + return make_precompiler(_fn_kernel)(x, begin, end, out, x.size(0), begin.stride(0), end.stride(0), out.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3) --- assertExpectedJournal(TestLoops.test_l2_grouping_with_register_block_size) from __future__ import annotations diff --git a/test/test_reduce.expected b/test/test_reduce.expected index e36ac00f..3920238d 100644 --- a/test/test_reduce.expected +++ b/test/test_reduce.expected @@ -147,24 +147,27 @@ def helper_function_0(param_0, param_1): return v_0 @triton.jit -def _test_reduce_codegen_kernel_kernel(x, result, x_size_1, x_stride_1, _RDIM_SIZE_1: tl.constexpr): +def _test_reduce_codegen_kernel_kernel(x, result, x_size_1, result_stride_0, x_stride_0, x_stride_1, _RDIM_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_0 = offset_0 + tl.zeros([1], tl.int32) indices_1 = tl.arange(0, _RDIM_SIZE_1).to(tl.int32) mask_1 = indices_1 < x_size_1 - row_data = tl.load(x + indices_1[None, :] * x_stride_1, mask_1[None, :], other=0) + row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_1[None, :], other=0) _reduce = tl.reduce(row_data, 1, helper_function_0) - tl.store(result + tl.zeros([1], tl.int32), _reduce, None) + tl.store(result + indices_0 * result_stride_0, _reduce, None) def test_reduce_codegen_kernel(x: torch.Tensor): result = torch.empty([x.size(0)], dtype=x.dtype, device=x.device) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) - _test_reduce_codegen_kernel_kernel[1,](x, result, x.size(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + _test_reduce_codegen_kernel_kernel[1,](x, result, x.size(1), result.stride(0), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) return result def _test_reduce_codegen_kernel_make_precompiler(x: torch.Tensor): result = torch.empty([x.size(0)], dtype=x.dtype, device=x.device) _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_test_reduce_codegen_kernel_kernel)(x, result, x.size(1), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) + return make_precompiler(_test_reduce_codegen_kernel_kernel)(x, result, x.size(1), result.stride(0), x.stride(0), x.stride(1), _RDIM_SIZE_1, num_warps=4, num_stages=3) --- assertExpectedJournal(TestReduce.test_reduce_different_dtypes) from __future__ import annotations @@ -566,7 +569,7 @@ def helper_function_0(param_0, param_1): return v_0 @triton.jit -def _test_reduce_keep_dims_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr): +def _test_reduce_keep_dims_kernel_kernel(x, result, x_size_0, x_size_1, result_stride_0, result_stride_1, x_stride_0, x_stride_1, _BLOCK_SIZE_0: tl.constexpr, _RDIM_SIZE_1: tl.constexpr): pid_0 = tl.program_id(0) offset_0 = pid_0 * _BLOCK_SIZE_0 indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) @@ -575,13 +578,13 @@ def _test_reduce_keep_dims_kernel_kernel(x, result, x_size_0, x_size_1, result_s mask_1 = indices_1 < x_size_1 row_data = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) _reduce = tl.reduce(row_data, 1, helper_function_0, keep_dims=True) - tl.store(result + indices_0[:, None] * result_stride_0, _reduce, mask_0[:, None]) + tl.store(result + (indices_0[:, None] * result_stride_0 + tl.zeros([1], tl.int32)[None, :] * result_stride_1), _reduce, mask_0[:, None]) def test_reduce_keep_dims_kernel(x: torch.Tensor): result = torch.empty([x.size(0), 1], dtype=x.dtype, device=x.device) _BLOCK_SIZE_0 = 2 _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) - _test_reduce_keep_dims_kernel_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, result, x.size(0), x.size(1), result.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3) + _test_reduce_keep_dims_kernel_kernel[triton.cdiv(x.size(0), _BLOCK_SIZE_0),](x, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3) return result def _test_reduce_keep_dims_kernel_make_precompiler(x: torch.Tensor): @@ -589,4 +592,4 @@ def _test_reduce_keep_dims_kernel_make_precompiler(x: torch.Tensor): _BLOCK_SIZE_0 = 2 _RDIM_SIZE_1 = triton.next_power_of_2(x.size(1)) from helion.runtime.precompile_shim import make_precompiler - return make_precompiler(_test_reduce_keep_dims_kernel_kernel)(x, result, x.size(0), x.size(1), result.stride(0), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3) + return make_precompiler(_test_reduce_keep_dims_kernel_kernel)(x, result, x.size(0), x.size(1), result.stride(0), result.stride(1), x.stride(0), x.stride(1), _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)