@@ -1314,6 +1314,44 @@ def _softmax_two_pass_make_precompiler(x: torch.Tensor):
1314
1314
from helion.runtime.precompile_shim import make_precompiler
1315
1315
return make_precompiler(_softmax_two_pass_kernel)(x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
1316
1316
1317
+ --- assertExpectedJournal(TestExamples.test_sum)
1318
+ from __future__ import annotations
1319
+
1320
+ import torch
1321
+ import triton
1322
+ import triton.language as tl
1323
+
1324
+ @triton.jit
1325
+ def _sum_kernel_kernel(x, out, out_stride_0, x_stride_0, x_stride_1, n, _REDUCTION_BLOCK_1: tl.constexpr):
1326
+ pid_0 = tl.program_id(0)
1327
+ offset_0 = pid_0
1328
+ indices_0 = offset_0 + tl.zeros([1], tl.int32)
1329
+ sum_1_acc = tl.full([1, _REDUCTION_BLOCK_1], 0, tl.float32)
1330
+ for roffset_1 in tl.range(0, n, step=_REDUCTION_BLOCK_1):
1331
+ rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
1332
+ mask_1 = rindex_1 < n
1333
+ load = tl.load(x + (indices_0[:, None] * x_stride_0 + rindex_1[None, :] * x_stride_1), mask_1[None, :], other=0)
1334
+ v_0 = sum_1_acc + load
1335
+ sum_1_acc = v_0
1336
+ sum_1 = tl.sum(sum_1_acc, 1)
1337
+ tl.store(out + indices_0 * out_stride_0, sum_1, None)
1338
+
1339
+ def sum_kernel(x: torch.Tensor):
1340
+ """Sum 2D tensor along the last dimension."""
1341
+ m, n = x.shape
1342
+ out = torch.empty([m], dtype=x.dtype, device=x.device)
1343
+ _REDUCTION_BLOCK_1 = 32768
1344
+ _sum_kernel_kernel[m,](x, out, out.stride(0), x.stride(0), x.stride(1), n, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
1345
+ return out
1346
+
1347
+ def _sum_kernel_make_precompiler(x: torch.Tensor):
1348
+ """Sum 2D tensor along the last dimension."""
1349
+ m, n = x.shape
1350
+ out = torch.empty([m], dtype=x.dtype, device=x.device)
1351
+ _REDUCTION_BLOCK_1 = 32768
1352
+ from helion.runtime.precompile_shim import make_precompiler
1353
+ return make_precompiler(_sum_kernel_kernel)(x, out, out.stride(0), x.stride(0), x.stride(1), n, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
1354
+
1317
1355
--- assertExpectedJournal(TestExamples.test_template_via_closure0)
1318
1356
from __future__ import annotations
1319
1357
@@ -1490,3 +1528,4 @@ def _matmul_with_epilogue_make_precompiler(x: Tensor, y: Tensor, epilogue: Calla
1490
1528
_BLOCK_SIZE_2 = 16
1491
1529
from helion.runtime.precompile_shim import make_precompiler
1492
1530
return make_precompiler(_matmul_with_epilogue_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
1531
+
0 commit comments