Skip to content

Commit 25b3ab9

Browse files
authored
Add examples/segment_reduction.py (#300)
1 parent 2ddab75 commit 25b3ab9

File tree

3 files changed

+254
-0
lines changed

3 files changed

+254
-0
lines changed

examples/segment_reduction.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Code based on https://github.com/pytorch-labs/helion/issues/237
2+
from __future__ import annotations
3+
4+
import torch
5+
import triton
6+
import triton.language as tl
7+
8+
import helion
9+
from helion._testing import DEVICE
10+
from helion._testing import run_example
11+
import helion.language as hl
12+
13+
14+
def combine_fn_helion(
15+
left_values: torch.Tensor,
16+
left_indices: torch.Tensor,
17+
right_values: torch.Tensor,
18+
right_indices: torch.Tensor,
19+
) -> tuple[torch.Tensor, torch.Tensor]:
20+
combined_values = torch.where(
21+
left_indices == right_indices, left_values + right_values, right_values
22+
)
23+
return combined_values, right_indices
24+
25+
26+
@helion.kernel()
27+
def segmented_reduction_helion(
28+
indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int
29+
) -> torch.Tensor:
30+
num_elements, num_features = input_data.shape
31+
output = torch.zeros(
32+
(num_nodes, num_features), dtype=input_data.dtype, device=input_data.device
33+
)
34+
for tile_e, tile_f in hl.tile([num_elements, num_features]):
35+
vals = input_data[tile_e, tile_f]
36+
idxs = indices[tile_e]
37+
idxs_next = hl.load(
38+
indices, [tile_e.index + 1], extra_mask=tile_e.index < num_elements - 1
39+
)
40+
tuple_in = (vals, idxs.float().unsqueeze(1).expand_as(vals))
41+
out_vals, _ = hl.associative_scan(combine_fn_helion, tuple_in, dim=0)
42+
mask = (idxs != idxs_next) | (
43+
tile_e.index % tile_e.block_size == tile_e.block_size - 1
44+
)
45+
segment_vals = torch.where(mask.unsqueeze(1), out_vals, 0.0)
46+
hl.atomic_add(output, [idxs, tile_f], segment_vals)
47+
return output
48+
49+
50+
@triton.jit
51+
def combine_fn_triton(left_values, left_indices, right_values, right_indices):
52+
same_segment = left_indices == right_indices
53+
combined_values = tl.where(same_segment, left_values + right_values, right_values)
54+
combined_indices = right_indices
55+
return combined_values, combined_indices
56+
57+
58+
@triton.autotune(
59+
configs=[
60+
triton.Config(
61+
{"BLOCK_SIZE": bs},
62+
)
63+
for bs in [8, 16, 32, 64, 128]
64+
],
65+
key=["C"],
66+
restore_value=["out_ptr"],
67+
)
68+
@triton.jit
69+
def _segmented_reduction_triton(
70+
index, # the input index tensor
71+
in_ptr, # the input tensor
72+
out_ptr, # the output value tensor
73+
E: tl.constexpr, # Number of elements in the input tensor (1d)
74+
C: tl.constexpr, # Number of features in the input tensor (2d)
75+
BLOCK_SIZE: tl.constexpr, # Block size for the scan
76+
):
77+
# Triton version adapted from
78+
# https://github.com/fishmingyu/GeoT/blob/main/geot/triton/seg_reduction.py
79+
pid = tl.program_id(axis=0)
80+
offset_pid = pid // C
81+
feature_id = pid % C
82+
offsets = offset_pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
83+
mask = offsets < E
84+
85+
# Load input data
86+
vals = tl.load(in_ptr + offsets * C + feature_id, mask=mask)
87+
idxs = tl.load(index + offsets, mask=mask)
88+
idxs_next = tl.load(index + offsets + 1, offsets < E - 1)
89+
90+
# Perform an inclusive scan using tl.associative_scan
91+
result_values, _ = tl.associative_scan(
92+
(
93+
vals,
94+
idxs,
95+
),
96+
axis=0,
97+
combine_fn=combine_fn_triton,
98+
)
99+
# if offset % BLOCK_SIZE == -1, it means the last element of the segment
100+
segment_start = (idxs != idxs_next) | (offsets % BLOCK_SIZE == BLOCK_SIZE - 1)
101+
tl.atomic_add(out_ptr + idxs * C + feature_id, result_values, mask & segment_start)
102+
103+
104+
def segmented_reduction_triton(indices, input_data, num_nodes):
105+
E, C = input_data.shape
106+
output = torch.zeros(
107+
(num_nodes, C), dtype=input_data.dtype, device=input_data.device
108+
)
109+
110+
def grid(META):
111+
return (triton.cdiv(E, META["BLOCK_SIZE"]) * C,)
112+
113+
_segmented_reduction_triton[grid](indices, input_data, output, E, C)
114+
return output
115+
116+
117+
def segmented_reduction_pytorch(indices, input_data, num_nodes):
118+
# Run PyTorch reference (scatter_add equivalent)
119+
num_features = input_data.size(1)
120+
pytorch_output = torch.zeros(
121+
num_nodes, num_features, device=input_data.device, dtype=input_data.dtype
122+
)
123+
pytorch_output.scatter_add_(
124+
0, indices.unsqueeze(1).expand(-1, num_features), input_data
125+
)
126+
return pytorch_output
127+
128+
129+
def main():
130+
num_nodes = 100
131+
num_edges = 2000
132+
num_features = 128
133+
134+
dtype = torch.float32
135+
136+
# Create sorted indices for segmented reduction
137+
indices = torch.randint(0, num_nodes, (num_edges,), device=DEVICE).sort()[0]
138+
input_data = torch.randn(num_edges, num_features, device=DEVICE, dtype=dtype)
139+
140+
run_example(
141+
segmented_reduction_helion,
142+
{
143+
"triton": segmented_reduction_triton,
144+
"pytorch": segmented_reduction_pytorch,
145+
},
146+
(indices, input_data, num_nodes),
147+
)
148+
149+
150+
if __name__ == "__main__":
151+
main()

test/test_examples.expected

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,6 +1373,84 @@ def _rms_norm_make_precompiler(x: torch.Tensor, weight: torch.Tensor, eps: float
13731373
from helion.runtime.precompile_shim import make_precompiler
13741374
return make_precompiler(_rms_norm_kernel)(x, weight, out, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
13751375

1376+
--- assertExpectedJournal(TestExamples.test_segment_reduction)
1377+
from __future__ import annotations
1378+
1379+
import torch
1380+
import triton
1381+
import triton.language as tl
1382+
from torch._inductor.runtime.triton_compat import libdevice
1383+
1384+
import helion._testing.segment_reduction as _source_module
1385+
1386+
@triton.jit
1387+
def helper_function_0(param_0, param_1, param_2, param_3):
1388+
v_0 = param_1 == param_3
1389+
v_1 = param_0 + param_2
1390+
v_2 = tl.where(v_0, v_1, param_2)
1391+
return (v_2, param_3)
1392+
1393+
@triton.jit
1394+
def _segmented_reduction_helion_kernel(input_data, indices, output, indices_stride_0, input_data_stride_0, input_data_stride_1, output_stride_0, output_stride_1, num_elements, num_features, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
1395+
num_blocks_0 = tl.cdiv(num_elements, _BLOCK_SIZE_0)
1396+
pid_0 = tl.program_id(0) % num_blocks_0
1397+
pid_1 = tl.program_id(0) // num_blocks_0
1398+
offset_0 = pid_0 * _BLOCK_SIZE_0
1399+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
1400+
mask_0 = indices_0 < num_elements
1401+
offset_1 = pid_1 * _BLOCK_SIZE_1
1402+
indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
1403+
mask_1 = indices_1 < num_features
1404+
vals = tl.load(input_data + (indices_0[:, None] * input_data_stride_0 + indices_1[None, :] * input_data_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
1405+
idxs = tl.load(indices + indices_0 * indices_stride_0, mask_0, other=0)
1406+
v_0 = tl.full([], 1, tl.int32)
1407+
v_1 = indices_0 + v_0
1408+
sub = -1 + num_elements
1409+
v_2 = sub.to(tl.int32)
1410+
v_3 = indices_0 < v_2
1411+
idxs_next = tl.load(indices + v_1 * indices_stride_0, mask_0 & v_3, other=0)
1412+
v_4 = idxs.to(tl.float32)
1413+
unsqueeze = v_4[:, None]
1414+
expand = tl.broadcast_to(unsqueeze, [_BLOCK_SIZE_0, _BLOCK_SIZE_1])
1415+
out_vals = tl.associative_scan((vals, expand), 0, helper_function_0)[0]
1416+
v_5 = idxs != idxs_next
1417+
_BLOCK_SIZE_0_ = _BLOCK_SIZE_0
1418+
v_6 = _BLOCK_SIZE_0_.to(tl.int32)
1419+
v_7 = indices_0 % v_6
1420+
v_8 = tl.full([], 0, tl.int32)
1421+
v_9 = v_7 != v_8
1422+
v_10 = libdevice.signbit(v_7) != 0 if v_7.dtype is tl.float32 else v_7 < 0
1423+
v_11 = libdevice.signbit(v_6) != 0 if v_6.dtype is tl.float32 else v_6 < 0
1424+
v_12 = v_10 != v_11
1425+
v_13 = v_9 & v_12
1426+
v_14 = v_7 + v_6
1427+
v_15 = tl.where(v_13, v_14, v_7)
1428+
sub_1 = -1 + _BLOCK_SIZE_0
1429+
v_16 = sub_1.to(tl.int32)
1430+
v_17 = v_15 == v_16
1431+
v_18 = v_5 | v_17
1432+
unsqueeze_1 = v_18[:, None]
1433+
v_19 = 0.0
1434+
v_20 = v_19[None, None]
1435+
v_21 = tl.where(unsqueeze_1, out_vals, v_20)
1436+
tl.atomic_add(output + (idxs[:, None] * output_stride_0 + indices_1[None, :] * output_stride_1), v_21, mask=mask_0[:, None] & mask_1[None, :], sem='relaxed')
1437+
1438+
def segmented_reduction_helion(indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int):
1439+
num_elements, num_features = input_data.shape
1440+
output = torch.zeros((num_nodes, num_features), dtype=input_data.dtype, device=input_data.device)
1441+
_BLOCK_SIZE_0 = 32
1442+
_BLOCK_SIZE_1 = 32
1443+
_segmented_reduction_helion_kernel[triton.cdiv(num_elements, _BLOCK_SIZE_0) * triton.cdiv(num_features, _BLOCK_SIZE_1),](input_data, indices, output, indices.stride(0), input_data.stride(0), input_data.stride(1), output.stride(0), output.stride(1), num_elements, num_features, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
1444+
return output
1445+
1446+
def _segmented_reduction_helion_make_precompiler(indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int):
1447+
num_elements, num_features = input_data.shape
1448+
output = torch.zeros((num_nodes, num_features), dtype=input_data.dtype, device=input_data.device)
1449+
_BLOCK_SIZE_0 = 32
1450+
_BLOCK_SIZE_1 = 32
1451+
from helion.runtime.precompile_shim import make_precompiler
1452+
return make_precompiler(_segmented_reduction_helion_kernel)(input_data, indices, output, indices.stride(0), input_data.stride(0), input_data.stride(1), output.stride(0), output.stride(1), num_elements, num_features, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
1453+
13761454
--- assertExpectedJournal(TestExamples.test_softmax)
13771455
from __future__ import annotations
13781456

test/test_examples.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,31 @@ def test_jagged_mean(self):
498498
)
499499
)
500500

501+
def test_segment_reduction(self):
502+
num_nodes = 100
503+
num_edges = 1000
504+
num_features = 32
505+
dtype = torch.float32
506+
507+
# Create sorted indices for segmented reduction
508+
indices = torch.randint(0, num_nodes, (num_edges,), device=DEVICE).sort()[0]
509+
input_data = torch.randn(num_edges, num_features, device=DEVICE, dtype=dtype)
510+
511+
args = (indices, input_data, num_nodes)
512+
513+
# Import and use the reference implementation
514+
mod = import_path(EXAMPLES_DIR / "segment_reduction.py")
515+
expected = mod.segmented_reduction_pytorch(*args)
516+
517+
self.assertExpectedJournal(
518+
check_example(
519+
"segment_reduction",
520+
args,
521+
expected,
522+
fn_name="segmented_reduction_helion",
523+
)
524+
)
525+
501526

502527
if __name__ == "__main__":
503528
unittest.main()

0 commit comments

Comments
 (0)