Commit 11b49c6

Fix variable scoping in nested loops for multi-pass kernels (#324)
Fixes an issue where variables defined in an outer loop's scope were not accessible in subsequent inner loops within the same outer-loop iteration. This pattern is common in multi-pass algorithms such as layernorm.
1 parent dcfa500 commit 11b49c6
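
The pattern this change enables is a value computed in an outer tile's scope between two inner device loops and then read inside a later inner loop. Below is a minimal sketch in Helion's tile DSL, condensed from the new test_three_pass_kernel added in this commit; the kernel name mean_center and the "import helion.language as hl" alias are assumptions for illustration, not part of the commit.

import torch
import helion
import helion.language as hl

@helion.kernel()
def mean_center(x: torch.Tensor) -> torch.Tensor:
    B, M = x.size()
    out = torch.zeros_like(x)
    for tile_b in hl.tile(B):
        # Pass 1: accumulate a per-row sum inside the first inner loop.
        acc = hl.zeros([tile_b], dtype=torch.float32)
        for tile_m in hl.tile(M):
            acc = acc + x[tile_b, tile_m].to(torch.float32).sum(dim=1)
        # Defined in the outer loop's scope, between the two inner loops.
        mean = acc / M
        # Pass 2: before this fix, `mean` was not accessible here.
        for tile_m in hl.tile(M):
            out[tile_b, tile_m] = (x[tile_b, tile_m] - mean[:, None]).to(x.dtype)
    return out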

File tree

4 files changed, +231 -5 lines changed


helion/_compiler/device_ir.py

Lines changed: 2 additions & 1 deletion
@@ -553,7 +553,8 @@ def run_subgraph(*args: object) -> list[object]:
                 k: v
                 for k, v in subgraph_walker.scope.items()
                 if k in rw.writes
-                and (k not in self.scope or self.scope[k] is not v)
+                # Only propagate variables that existed before the loop and have been modified
+                and (k in self.scope and self.scope[k] is not v)
             }
         )
         return outputs.get_tensor_args()
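
In effect, the write-back condition flips from "new to the outer scope, or rebound" to "already in the outer scope, and rebound", as the added comment states. Below is a minimal standalone sketch of the two conditions over plain dicts; this is not Helion's actual scope machinery, and the function and variable names are illustrative only.

def propagate_before_fix(outer_scope: dict, loop_scope: dict, writes: set) -> dict:
    # Old condition: propagate a written variable if it is new to the outer
    # scope OR was rebound to a different object inside the loop.
    return {
        k: v
        for k, v in loop_scope.items()
        if k in writes and (k not in outer_scope or outer_scope[k] is not v)
    }

def propagate_after_fix(outer_scope: dict, loop_scope: dict, writes: set) -> dict:
    # New condition: only propagate variables that existed before the loop
    # and have been rebound inside it.
    return {
        k: v
        for k, v in loop_scope.items()
        if k in writes and (k in outer_scope and outer_scope[k] is not v)
    }

# Loop-local temporaries such as "tmp" are no longer written back to the outer scope:
outer = {"acc": object()}
inner = {"acc": object(), "tmp": object()}
assert set(propagate_before_fix(outer, inner, {"acc", "tmp"})) == {"acc", "tmp"}
assert set(propagate_after_fix(outer, inner, {"acc", "tmp"})) == {"acc"}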

test/test_examples.expected

Lines changed: 4 additions & 4 deletions
@@ -1321,9 +1321,9 @@ def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_s
         di_copy_1 = di
         mi_copy_1_0 = mi_copy_1
         di_copy_1_0 = di_copy_1
-        values = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        values_1 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
         subscript_1 = mi_copy_1_0[:, None]
-        v_7 = values - subscript_1
+        v_7 = values_1 - subscript_1
         v_8 = tl_math.exp(v_7)
         subscript_2 = di_copy_1_0[:, None]
         v_9 = v_8 / subscript_2
@@ -1382,9 +1382,9 @@ def _softmax_two_pass_kernel(x, out, out_size_0, out_size_1, x_size_0, x_size_1,
         di_copy_1 = di
         mi_copy_1_0 = mi_copy_1
         di_copy_1_0 = di_copy_1
-        values = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
+        values_1 = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
         subscript_1 = mi_copy_1_0[:, None]
-        v_7 = values - subscript_1
+        v_7 = values_1 - subscript_1
         v_8 = tl_math.exp(v_7)
         subscript_2 = di_copy_1_0[:, None]
         v_9 = v_8 / subscript_2

test/test_loops.expected

Lines changed: 123 additions & 0 deletions
@@ -855,6 +855,62 @@ def addToBoth(a, b, c, *, _launcher=_default_launcher):
     _launcher(_addToBoth_kernel, (triton.cdiv(a_n, _BLOCK_SIZE_0) * triton.cdiv(a_m, _BLOCK_SIZE_1) + triton.cdiv(b_n, _BLOCK_SIZE_2) * triton.cdiv(b_m, _BLOCK_SIZE_3) + triton.cdiv(c_n, _BLOCK_SIZE_4) * triton.cdiv(c_m, _BLOCK_SIZE_5),), x0, x1, x2, x0.stride(0), x0.stride(1), x1.stride(0), x1.stride(1), x2.stride(0), x2.stride(1), a_n, a_m, c0, b_n, b_m, c1, c_n, c_m, c2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=3)
     return (x0, x1, x2)
 
+--- assertExpectedJournal(TestLoops.test_nested_loop_accumulator)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _nested_loop_accumulator_kernel(x, out, out_stride_0, out_stride_1, out_stride_2, x_stride_0, x_stride_1, x_stride_2, N, M, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    indices_0 = offset_0 + tl.zeros([1], tl.int32)
+    acc = tl.full([1], 0.0, tl.float32)
+    for offset_1 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_1 < N
+        acc_copy = acc
+        acc = acc_copy
+        for offset_2 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_2):
+            indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+            mask_2 = indices_2 < M
+            acc_copy_0_copy = acc
+            acc_copy_0_copy_0 = acc_copy_0_copy
+            vals = tl.load(x + (indices_0[:, None, None] * x_stride_0 + indices_1[None, :, None] * x_stride_1 + indices_2[None, None, :] * x_stride_2), mask_1[None, :, None] & mask_2[None, None, :], other=0)
+            sum_1 = tl.sum(vals, 2)
+            sum_2 = tl.sum(sum_1, 1)
+            acc = acc_copy_0_copy_0 + sum_2
+    mul = M * N
+    v_1 = mul.to(tl.float32)
+    v_2 = acc / v_1
+    for offset_3 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        mask_3 = indices_3 < N
+        v_2_copy = v_2
+        v_2_copy_0 = v_2_copy
+        for offset_4 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_4):
+            indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
+            mask_4 = indices_4 < M
+            v_2_copy_0_copy = v_2_copy_0
+            v_2_copy_0_copy_0 = v_2_copy_0_copy
+            vals_1 = tl.load(x + (indices_0[:, None, None] * x_stride_0 + indices_3[None, :, None] * x_stride_1 + indices_4[None, None, :] * x_stride_2), mask_3[None, :, None] & mask_4[None, None, :], other=0)
+            subscript = v_2_copy_0_copy_0[:, None, None]
+            v_3 = vals_1 - subscript
+            tl.store(out + (indices_0[:, None, None] * out_stride_0 + indices_3[None, :, None] * out_stride_1 + indices_4[None, None, :] * out_stride_2), v_3, mask_3[None, :, None] & mask_4[None, None, :])
+
+def nested_loop_accumulator(x: torch.Tensor, *, _launcher=_default_launcher):
+    B, N, M = x.size()
+    out = torch.zeros_like(x)
+    _BLOCK_SIZE_1 = 2
+    _BLOCK_SIZE_2 = 4
+    _BLOCK_SIZE_3 = 2
+    _BLOCK_SIZE_4 = 4
+    _launcher(_nested_loop_accumulator_kernel, (B,), x, out, out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), N, M, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestLoops.test_pointwise_device_loop)
 from __future__ import annotations

@@ -977,3 +1033,70 @@ def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _BLOCK_SIZE_2 = 64
     _launcher(_matmul_kernel, (triton.cdiv(256, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
+
+--- assertExpectedJournal(TestLoops.test_three_pass_kernel)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _three_pass_kernel_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, B, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < B
+    sum_val = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
+    for offset_1 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_1 < M
+        sum_val_copy = sum_val
+        sum_val_copy_0 = sum_val_copy
+        load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        sum_1 = tl.sum(load, 1)
+        sum_val = sum_val_copy_0 + sum_1
+    sum_sq = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
+    for offset_2 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_2):
+        indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
+        mask_2 = indices_2 < M
+        sum_sq_copy = sum_sq
+        sum_sq_copy_0 = sum_sq_copy
+        vals = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
+        v_1 = vals * vals
+        sum_2 = tl.sum(v_1, 1)
+        sum_sq = sum_sq_copy_0 + sum_2
+    v_3 = M.to(tl.float32)
+    v_4 = sum_val / v_3
+    v_5 = M.to(tl.float32)
+    v_6 = sum_sq / v_5
+    v_7 = v_4 * v_4
+    v_8 = v_6 - v_7
+    v_9 = 1e-06
+    v_10 = v_8 + v_9
+    v_11 = libdevice.sqrt(v_10)
+    for offset_3 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        mask_3 = indices_3 < M
+        v_4_copy = v_4
+        v_11_copy = v_11
+        v_4_copy_0 = v_4_copy
+        v_11_copy_0 = v_11_copy
+        vals_1 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_0[:, None] & mask_3[None, :], other=0)
+        subscript = v_4_copy_0[:, None]
+        v_12 = vals_1 - subscript
+        subscript_1 = v_11_copy_0[:, None]
+        v_13 = v_12 / subscript_1
+        tl.store(out + (indices_0[:, None] * out_stride_0 + indices_3[None, :] * out_stride_1), v_13, mask_0[:, None] & mask_3[None, :])
+
+def three_pass_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
+    B, M = x.size()
+    out = torch.zeros_like(x)
+    _BLOCK_SIZE_0 = 2
+    _BLOCK_SIZE_1 = 8
+    _BLOCK_SIZE_2 = 8
+    _BLOCK_SIZE_3 = 8
+    _launcher(_three_pass_kernel_kernel, (triton.cdiv(B, _BLOCK_SIZE_0),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), B, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
+    return out

test/test_loops.py

Lines changed: 102 additions & 0 deletions
@@ -989,6 +989,108 @@ def kernel_with_dynamic_fill(
         expected = x + fill_value[0]
         torch.testing.assert_close(result, expected)
 
+    def test_nested_loop_accumulator(self):
+        """Test variable scoping with nested loops and accumulator pattern."""
+
+        @helion.kernel()
+        def nested_loop_accumulator(x: torch.Tensor) -> torch.Tensor:
+            B, N, M = x.size()
+            out = torch.zeros_like(x)
+
+            # Outer loop (like processing each batch in jagged)
+            for tile_b in hl.tile(B):
+                # Initialize accumulator for this batch
+                acc = hl.zeros([tile_b], dtype=torch.float32)
+
+                # First nested loop: accumulate values
+                for tile_n in hl.tile(N):
+                    for tile_m in hl.tile(M):
+                        vals = x[tile_b, tile_n, tile_m].to(torch.float32)
+                        # Accumulate sum
+                        acc = acc + vals.sum(dim=2).sum(dim=1)
+
+                # Compute average from accumulated sum
+                avg = acc / (N * M)
+
+                # Second nested loop: use the average
+                for tile_n in hl.tile(N):
+                    for tile_m in hl.tile(M):
+                        vals = x[tile_b, tile_n, tile_m].to(torch.float32)
+                        result = vals - avg[:, None, None]
+                        out[tile_b, tile_n, tile_m] = result.to(x.dtype)
+
+            return out
+
+        x = torch.randn(2, 4, 8, device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(
+            nested_loop_accumulator,
+            (x,),
+            block_sizes=[1, 2, 4, 2, 4],
+        )
+
+        expected = torch.zeros_like(x)
+        for b in range(x.size(0)):
+            batch_sum = x[b].sum()
+            batch_avg = batch_sum / (x.size(1) * x.size(2))
+            expected[b] = x[b] - batch_avg
+
+        torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)
+        self.assertExpectedJournal(code)
+
+    def test_three_pass_kernel(self):
+        """Test variable scoping with three-pass pattern like layer norm."""
+
+        @helion.kernel()
+        def three_pass_kernel(x: torch.Tensor) -> torch.Tensor:
+            B, M = x.size()
+            out = torch.zeros_like(x)
+
+            for tile_b in hl.tile(B):
+                # Pass 1: Compute sum
+                sum_val = hl.zeros([tile_b], dtype=torch.float32)
+                for tile_m in hl.tile(M):
+                    sum_val = sum_val + x[tile_b, tile_m].to(torch.float32).sum(dim=1)
+
+                # Pass 2: Compute sum of squares
+                sum_sq = hl.zeros([tile_b], dtype=torch.float32)
+                for tile_m in hl.tile(M):
+                    vals = x[tile_b, tile_m].to(torch.float32)
+                    sum_sq = sum_sq + (vals * vals).sum(dim=1)
+
+                # Compute mean and variance
+                mean = sum_val / M
+                var = sum_sq / M - mean * mean
+                std = torch.sqrt(var + 1e-6)
+
+                # Pass 3: Normalize using mean and std
+                for tile_m in hl.tile(M):
+                    vals = x[tile_b, tile_m].to(torch.float32)
+                    # Error likely here - mean and std might not be accessible
+                    normalized = (vals - mean[:, None]) / std[:, None]
+                    out[tile_b, tile_m] = normalized.to(x.dtype)
+
+            return out
+
+        x = torch.randn(4, 16, device=DEVICE, dtype=torch.float32)
+
+        code, result = code_and_output(
+            three_pass_kernel,
+            (x,),
+            block_sizes=[2, 8, 8, 8],
+        )
+
+        expected = torch.zeros_like(x)
+        for b in range(x.size(0)):
+            batch_data = x[b]
+            mean = batch_data.mean()
+            var = batch_data.var(unbiased=False)
+            std = torch.sqrt(var + 1e-6)
+            expected[b] = (batch_data - mean) / std
+
+        torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()
