|
| 1 | +# Code based on https://github.com/pytorch-labs/helion/issues/237 |
| 2 | +from __future__ import annotations |
| 3 | + |
| 4 | +import torch |
| 5 | +import triton |
| 6 | +import triton.language as tl |
| 7 | + |
| 8 | +import helion |
| 9 | +from helion._testing import DEVICE |
| 10 | +from helion._testing import run_example |
| 11 | +import helion.language as hl |
| 12 | + |
| 13 | + |
def combine_fn_helion(
    left_values: torch.Tensor,
    left_indices: torch.Tensor,
    right_values: torch.Tensor,
    right_indices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Combine two (value, index) scan operands for a segmented sum.

    Element-wise, where the left and right indices agree the values are
    added together; where they differ the right value is taken on its own,
    which restarts the running sum. The right indices are always returned
    unchanged as the index half of the result.

    Returns:
        A ``(combined_values, right_indices)`` tuple.
    """
    same_segment = left_indices == right_indices
    running_sum = left_values + right_values
    merged_values = torch.where(same_segment, running_sum, right_values)
    return merged_values, right_indices
| 24 | + |
| 25 | + |
@helion.kernel()
def segmented_reduction_helion(
    indices: torch.Tensor, input_data: torch.Tensor, num_nodes: int
) -> torch.Tensor:
    """Sum rows of ``input_data`` into ``output[indices[row]]``.

    Each tile runs a segmented inclusive scan along the element dimension
    and then atomically adds only the value found at the end of every run
    of equal indices (or at the end of the tile) into the output.

    Args:
        indices: One destination row per input row. The scan only merges
            neighbouring rows, so equal indices are expected to be
            contiguous (``main`` passes sorted indices).
        input_data: 2D tensor of shape ``(num_elements, num_features)``.
        num_nodes: Number of rows in the returned tensor.

    Returns:
        A ``(num_nodes, num_features)`` tensor with the dtype and device of
        ``input_data``.
    """
    num_elements, num_features = input_data.shape
    output = torch.zeros(
        (num_nodes, num_features), dtype=input_data.dtype, device=input_data.device
    )
    for tile_e, tile_f in hl.tile([num_elements, num_features]):
        vals = input_data[tile_e, tile_f]
        idxs = indices[tile_e]
        # Index of the following row; the final row has no successor, so its
        # lane is masked out of the load.
        idxs_next = hl.load(
            indices, [tile_e.index + 1], extra_mask=tile_e.index < num_elements - 1
        )
        # The indices are cast to float and broadcast to the shape of vals so
        # both scan operands line up.
        # NOTE(review): float32 is exact for integers only up to 2**24 --
        # confirm node ids stay below that.
        tuple_in = (vals, idxs.float().unsqueeze(1).expand_as(vals))
        out_vals, _ = hl.associative_scan(combine_fn_helion, tuple_in, dim=0)
        # A running sum is flushed when:
        #   * the next row belongs to a different index,
        #   * the row is the last one in this tile (the scan does not carry
        #     across tiles), or
        #   * the row is the very last input row. This term is required
        #     because idxs_next is masked out for that row, so the first test
        #     cannot be relied on (e.g. it would miss the flush if the filler
        #     value happened to equal the final index).
        mask = (
            (idxs != idxs_next)
            | (tile_e.index % tile_e.block_size == tile_e.block_size - 1)
            | (tile_e.index == num_elements - 1)
        )
        segment_vals = torch.where(mask.unsqueeze(1), out_vals, 0.0)
        hl.atomic_add(output, [idxs, tile_f], segment_vals)
    return output
| 48 | + |
| 49 | + |
@triton.jit
def combine_fn_triton(left_values, left_indices, right_values, right_indices):
    """Combine two (value, index) scan operands for a segmented sum.

    Values are added where the indices match; otherwise the right value is
    taken alone. The right indices are passed through as the combined index.
    """
    merged_values = tl.where(
        left_indices == right_indices, left_values + right_values, right_values
    )
    return merged_values, right_indices
| 56 | + |
| 57 | + |
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE": bs},
        )
        for bs in [8, 16, 32, 64, 128]
    ],
    key=["C"],
    # The kernel accumulates into out_ptr with atomics, so the autotuner is
    # asked to restore it after trying each config.
    restore_value=["out_ptr"],
)
@triton.jit
def _segmented_reduction_triton(
    index,  # the input index tensor
    in_ptr,  # the input tensor
    out_ptr,  # the output value tensor
    E: tl.constexpr,  # Number of elements in the input tensor (1d)
    C: tl.constexpr,  # Number of features in the input tensor (2d)
    BLOCK_SIZE: tl.constexpr,  # Block size for the scan
):
    """Segmented sum of one feature column over one block of elements.

    Each program handles ``BLOCK_SIZE`` consecutive elements of a single
    feature: it scans them with ``combine_fn_triton`` and atomically adds the
    running sum found at the end of every run of equal indices into
    ``out_ptr``. Input and output are addressed as ``row * C + feature``.
    """
    # Triton version adapted from
    # https://github.com/fishmingyu/GeoT/blob/main/geot/triton/seg_reduction.py
    pid = tl.program_id(axis=0)
    offset_pid = pid // C
    feature_id = pid % C
    offsets = offset_pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < E

    # Load input data
    vals = tl.load(in_ptr + offsets * C + feature_id, mask=mask)
    idxs = tl.load(index + offsets, mask=mask)
    # Index of the following element; masked out for the final element, which
    # has no successor.
    idxs_next = tl.load(index + offsets + 1, offsets < E - 1)

    # Perform an inclusive scan using tl.associative_scan
    result_values, _ = tl.associative_scan(
        (
            vals,
            idxs,
        ),
        axis=0,
        combine_fn=combine_fn_triton,
    )
    # A running sum is written out at the end of a run of equal indices:
    #   * the next element has a different index,
    #   * offsets % BLOCK_SIZE == BLOCK_SIZE - 1, i.e. the last lane of the
    #     block (the scan does not carry across blocks), or
    #   * the element is the very last one. idxs_next is masked out there and
    #     no `other` is given, so the first test cannot be trusted for it.
    is_block_end = offsets % BLOCK_SIZE == BLOCK_SIZE - 1
    is_last_element = offsets == E - 1
    segment_end = (idxs != idxs_next) | is_block_end | is_last_element
    tl.atomic_add(out_ptr + idxs * C + feature_id, result_values, mask & segment_end)
| 102 | + |
| 103 | + |
def segmented_reduction_triton(indices, input_data, num_nodes):
    """Launch the Triton segmented-reduction kernel and return its output.

    Allocates a zeroed ``(num_nodes, num_features)`` tensor matching the
    dtype and device of ``input_data``, then launches one program per
    (element block, feature) pair to accumulate into it.
    """
    num_elements, num_features = input_data.shape
    result = torch.zeros(
        (num_nodes, num_features),
        dtype=input_data.dtype,
        device=input_data.device,
    )

    def launch_grid(meta):
        # One program per block of elements, repeated for every feature.
        element_blocks = triton.cdiv(num_elements, meta["BLOCK_SIZE"])
        return (element_blocks * num_features,)

    _segmented_reduction_triton[launch_grid](
        indices, input_data, result, num_elements, num_features
    )
    return result
| 115 | + |
| 116 | + |
def segmented_reduction_pytorch(indices, input_data, num_nodes):
    """Reference implementation built on ``Tensor.scatter_add_``.

    Row ``i`` of ``input_data`` is added into row ``indices[i]`` of a zeroed
    ``(num_nodes, num_features)`` tensor that shares the dtype and device of
    ``input_data``.
    """
    feature_count = input_data.shape[1]
    accumulator = input_data.new_zeros((num_nodes, feature_count))
    # scatter_add_ wants an index tensor shaped like the source, so every
    # row's target index is repeated across the feature dimension.
    row_targets = indices[:, None].expand(-1, feature_count)
    accumulator.scatter_add_(0, row_targets, input_data)
    return accumulator
| 127 | + |
| 128 | + |
def main():
    """Compare the Helion kernel against the Triton and PyTorch baselines.

    Builds 2000 random rows of 128 float32 features, assigns each row a
    sorted random index in ``[0, 100)``, and hands all three implementations
    to ``run_example``.
    """
    num_nodes = 100
    num_edges = 2000
    num_features = 128
    dtype = torch.float32

    # Create sorted indices for segmented reduction
    unsorted_indices = torch.randint(0, num_nodes, (num_edges,), device=DEVICE)
    indices = torch.sort(unsorted_indices).values
    input_data = torch.randn(num_edges, num_features, device=DEVICE, dtype=dtype)

    baselines = {
        "triton": segmented_reduction_triton,
        "pytorch": segmented_reduction_pytorch,
    }
    kernel_args = (indices, input_data, num_nodes)
    run_example(segmented_reduction_helion, baselines, kernel_args)


if __name__ == "__main__":
    main()
0 commit comments