Skip to content

Commit 573fc23

Browse files
authored
Add sum example and test (#256)
- Add sum kernel that reduces 2D tensors along the last dimension - Add unit test for sum kernel
1 parent bd0f27a commit 573fc23

File tree

3 files changed

+86
-0
lines changed

3 files changed

+86
-0
lines changed

examples/sum.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from __future__ import annotations
2+
3+
import torch
4+
5+
import helion
6+
from helion._testing import run_example
7+
import helion.language as hl
8+
9+
10+
@helion.kernel()
11+
def sum_kernel(x: torch.Tensor) -> torch.Tensor:
12+
"""Sum 2D tensor along the last dimension."""
13+
m, n = x.shape
14+
out = torch.empty([m], dtype=x.dtype, device=x.device)
15+
16+
for tile_m in hl.tile(m):
17+
out[tile_m] = x[tile_m, :].sum(-1)
18+
19+
return out
20+
21+
22+
def check(m: int, n: int) -> None:
23+
x = torch.randn([m, n], device="cuda", dtype=torch.float32)
24+
kernels = {"helion": sum_kernel}
25+
run_example(kernels, lambda x: x.sum(-1), (x,))
26+
27+
28+
def main() -> None:
29+
check(512, 256)
30+
check(1024, 1024)
31+
32+
33+
if __name__ == "__main__":
34+
main()

test/test_examples.expected

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,6 +1314,44 @@ def _softmax_two_pass_make_precompiler(x: torch.Tensor):
13141314
from helion.runtime.precompile_shim import make_precompiler
13151315
return make_precompiler(_softmax_two_pass_kernel)(x, out, out.size(0), out.size(1), x.size(0), x.size(1), out.stride(0), out.stride(1), x.stride(0), x.stride(1), m, n, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
13161316

1317+
--- assertExpectedJournal(TestExamples.test_sum)
1318+
from __future__ import annotations
1319+
1320+
import torch
1321+
import triton
1322+
import triton.language as tl
1323+
1324+
@triton.jit
1325+
def _sum_kernel_kernel(x, out, out_stride_0, x_stride_0, x_stride_1, n, _REDUCTION_BLOCK_1: tl.constexpr):
1326+
pid_0 = tl.program_id(0)
1327+
offset_0 = pid_0
1328+
indices_0 = offset_0 + tl.zeros([1], tl.int32)
1329+
sum_1_acc = tl.full([1, _REDUCTION_BLOCK_1], 0, tl.float32)
1330+
for roffset_1 in tl.range(0, n, step=_REDUCTION_BLOCK_1):
1331+
rindex_1 = roffset_1 + tl.arange(0, _REDUCTION_BLOCK_1).to(tl.int32)
1332+
mask_1 = rindex_1 < n
1333+
load = tl.load(x + (indices_0[:, None] * x_stride_0 + rindex_1[None, :] * x_stride_1), mask_1[None, :], other=0)
1334+
v_0 = sum_1_acc + load
1335+
sum_1_acc = v_0
1336+
sum_1 = tl.sum(sum_1_acc, 1)
1337+
tl.store(out + indices_0 * out_stride_0, sum_1, None)
1338+
1339+
def sum_kernel(x: torch.Tensor):
1340+
"""Sum 2D tensor along the last dimension."""
1341+
m, n = x.shape
1342+
out = torch.empty([m], dtype=x.dtype, device=x.device)
1343+
_REDUCTION_BLOCK_1 = 32768
1344+
_sum_kernel_kernel[m,](x, out, out.stride(0), x.stride(0), x.stride(1), n, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
1345+
return out
1346+
1347+
def _sum_kernel_make_precompiler(x: torch.Tensor):
1348+
"""Sum 2D tensor along the last dimension."""
1349+
m, n = x.shape
1350+
out = torch.empty([m], dtype=x.dtype, device=x.device)
1351+
_REDUCTION_BLOCK_1 = 32768
1352+
from helion.runtime.precompile_shim import make_precompiler
1353+
return make_precompiler(_sum_kernel_kernel)(x, out, out.stride(0), x.stride(0), x.stride(1), n, _REDUCTION_BLOCK_1, num_warps=4, num_stages=3)
1354+
13171355
--- assertExpectedJournal(TestExamples.test_template_via_closure0)
13181356
from __future__ import annotations
13191357

@@ -1490,3 +1528,4 @@ def _matmul_with_epilogue_make_precompiler(x: Tensor, y: Tensor, epilogue: Calla
14901528
_BLOCK_SIZE_2 = 16
14911529
from helion.runtime.precompile_shim import make_precompiler
14921530
return make_precompiler(_matmul_with_epilogue_kernel)(x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=2, num_stages=4)
1531+

test/test_examples.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,19 @@ def test_matmul_split_k(self):
420420
)
421421
)
422422

423+
def test_sum(self):
424+
args = (torch.randn([512, 512], device=DEVICE, dtype=torch.float32),)
425+
self.assertExpectedJournal(
426+
check_example(
427+
"sum",
428+
args,
429+
torch.sum(args[0], dim=-1),
430+
fn_name="sum_kernel",
431+
block_sizes=[1],
432+
reduction_loops=[32768],
433+
)
434+
)
435+
423436

424437
if __name__ == "__main__":
425438
unittest.main()

0 commit comments

Comments
 (0)