@@ -855,6 +855,62 @@ def addToBoth(a, b, c, *, _launcher=_default_launcher):
     _launcher(_addToBoth_kernel, (triton.cdiv(a_n, _BLOCK_SIZE_0) * triton.cdiv(a_m, _BLOCK_SIZE_1) + triton.cdiv(b_n, _BLOCK_SIZE_2) * triton.cdiv(b_m, _BLOCK_SIZE_3) + triton.cdiv(c_n, _BLOCK_SIZE_4) * triton.cdiv(c_m, _BLOCK_SIZE_5),), x0, x1, x2, x0.stride(0), x0.stride(1), x1.stride(0), x1.stride(1), x2.stride(0), x2.stride(1), a_n, a_m, c0, b_n, b_m, c1, c_n, c_m, c2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=3)
     return (x0, x1, x2)

+--- assertExpectedJournal(TestLoops.test_nested_loop_accumulator)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _nested_loop_accumulator_kernel(x, out, out_stride_0, out_stride_1, out_stride_2, x_stride_0, x_stride_1, x_stride_2, N, M, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    indices_0 = offset_0 + tl.zeros([1], tl.int32)
+    acc = tl.full([1], 0.0, tl.float32)
+    for offset_1 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_1 < N
+        acc_copy = acc
+        acc = acc_copy
+        for offset_2 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_2):
+            indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+            mask_2 = indices_2 < M
+            acc_copy_0_copy = acc
+            acc_copy_0_copy_0 = acc_copy_0_copy
+            vals = tl.load(x + (indices_0[:, None, None] * x_stride_0 + indices_1[None, :, None] * x_stride_1 + indices_2[None, None, :] * x_stride_2), mask_1[None, :, None] & mask_2[None, None, :], other=0)
+            sum_1 = tl.sum(vals, 2)
+            sum_2 = tl.sum(sum_1, 1)
+            acc = acc_copy_0_copy_0 + sum_2
+    mul = M * N
+    v_1 = mul.to(tl.float32)
+    v_2 = acc / v_1
+    for offset_3 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        mask_3 = indices_3 < N
+        v_2_copy = v_2
+        v_2_copy_0 = v_2_copy
+        for offset_4 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_4):
+            indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
+            mask_4 = indices_4 < M
+            v_2_copy_0_copy = v_2_copy_0
+            v_2_copy_0_copy_0 = v_2_copy_0_copy
+            vals_1 = tl.load(x + (indices_0[:, None, None] * x_stride_0 + indices_3[None, :, None] * x_stride_1 + indices_4[None, None, :] * x_stride_2), mask_3[None, :, None] & mask_4[None, None, :], other=0)
+            subscript = v_2_copy_0_copy_0[:, None, None]
+            v_3 = vals_1 - subscript
+            tl.store(out + (indices_0[:, None, None] * out_stride_0 + indices_3[None, :, None] * out_stride_1 + indices_4[None, None, :] * out_stride_2), v_3, mask_3[None, :, None] & mask_4[None, None, :])
+
+def nested_loop_accumulator(x: torch.Tensor, *, _launcher=_default_launcher):
+    B, N, M = x.size()
+    out = torch.zeros_like(x)
+    _BLOCK_SIZE_1 = 2
+    _BLOCK_SIZE_2 = 4
+    _BLOCK_SIZE_3 = 2
+    _BLOCK_SIZE_4 = 4
+    _launcher(_nested_loop_accumulator_kernel, (B,), x, out, out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), N, M, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=3)
+    return out
+
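A minimal PyTorch reference for the `nested_loop_accumulator` entry above (an illustrative sketch inferred from the generated code, not part of the journal): each program handles one batch element, the first pair of loops accumulates the sum over the (N, M) slice, the mean is `acc / (M * N)`, and the second pair of loops subtracts it.

    import torch

    def nested_loop_accumulator_ref(x: torch.Tensor) -> torch.Tensor:
        # Per-batch mean over the (N, M) slice, then mean subtraction,
        # mirroring the accumulate-then-normalize loop nests above.
        return x - x.mean(dim=(1, 2), keepdim=True)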
 --- assertExpectedJournal(TestLoops.test_pointwise_device_loop)
 from __future__ import annotations

@@ -977,3 +1033,70 @@ def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _BLOCK_SIZE_2 = 64
     _launcher(_matmul_kernel, (triton.cdiv(256, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
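A hypothetical invocation of the `matmul` wrapper above (the 256x256 shapes are an assumption inferred from the hard-coded `triton.cdiv(256, _BLOCK_SIZE_0)` grid; tolerances are loose because the kernel may accumulate in tf32):

    import torch

    x = torch.randn(256, 256, device="cuda")
    y = torch.randn(256, 256, device="cuda")
    out = matmul(x, y)  # 1D grid over row tiles; the other tiles are presumably looped in-kernel
    torch.testing.assert_close(out, x @ y, rtol=1e-2, atol=1e-2)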
+
+--- assertExpectedJournal(TestLoops.test_three_pass_kernel)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _three_pass_kernel_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, B, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < B
+    sum_val = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
+    for offset_1 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_1 < M
+        sum_val_copy = sum_val
+        sum_val_copy_0 = sum_val_copy
+        load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        sum_1 = tl.sum(load, 1)
+        sum_val = sum_val_copy_0 + sum_1
+    sum_sq = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
+    for offset_2 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_2 < M
+        sum_sq_copy = sum_sq
+        sum_sq_copy_0 = sum_sq_copy
+        vals = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        v_1 = vals * vals
+        sum_2 = tl.sum(v_1, 1)
+        sum_sq = sum_sq_copy_0 + sum_2
+    v_3 = M.to(tl.float32)
+    v_4 = sum_val / v_3
+    v_5 = M.to(tl.float32)
+    v_6 = sum_sq / v_5
+    v_7 = v_4 * v_4
+    v_8 = v_6 - v_7
+    v_9 = 1e-06
+    v_10 = v_8 + v_9
+    v_11 = libdevice.sqrt(v_10)
+    for offset_3 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        mask_3 = indices_3 < M
+        v_4_copy = v_4
+        v_11_copy = v_11
+        v_4_copy_0 = v_4_copy
+        v_11_copy_0 = v_11_copy
+        vals_1 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_0[:, None] & mask_3[None, :], other=0)
+        subscript = v_4_copy_0[:, None]
+        v_12 = vals_1 - subscript
+        subscript_1 = v_11_copy_0[:, None]
+        v_13 = v_12 / subscript_1
+        tl.store(out + (indices_0[:, None] * out_stride_0 + indices_3[None, :] * out_stride_1), v_13, mask_0[:, None] & mask_3[None, :])
+
+def three_pass_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
+    B, M = x.size()
+    out = torch.zeros_like(x)
+    _BLOCK_SIZE_0 = 2
+    _BLOCK_SIZE_1 = 8
+    _BLOCK_SIZE_2 = 8
+    _BLOCK_SIZE_3 = 8
+    _launcher(_three_pass_kernel_kernel, (triton.cdiv(B, _BLOCK_SIZE_0),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), B, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    return out
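The three passes above implement a per-row mean/variance normalization: pass one accumulates row sums, pass two row sums of squares, and pass three writes `(x - mean) / sqrt(var + 1e-6)`. A plausible PyTorch equivalent (an illustrative sketch, not part of the journal):

    import torch

    def three_pass_kernel_ref(x: torch.Tensor) -> torch.Tensor:
        # Variance via E[x^2] - E[x]^2, stabilized with the kernel's 1e-6 epsilon.
        mean = x.mean(dim=1, keepdim=True)
        var = (x * x).mean(dim=1, keepdim=True) - mean * mean
        return (x - mean) / torch.sqrt(var + 1e-06)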