Skip to content

Fix variable scoping in nested loops for multi-pass kernels (#324) #326

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions examples/jagged_mean.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

import os

import torch

import helion
from helion._testing import run_example
import helion.language as hl
from helion.utils import get_gpu_memory_info

# TritonBench configuration - adjust based on available GPU memory
if get_gpu_memory_info()[0] < 16.0:
# TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1":
# Low memory configuration
TRITONBENCH_ARGS = {"B": 32, "M": 8, "seqlen": 64}

Expand Down
3 changes: 2 additions & 1 deletion helion/_compiler/device_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,8 @@ def run_subgraph(*args: object) -> list[object]:
k: v
for k, v in subgraph_walker.scope.items()
if k in rw.writes
and (k not in self.scope or self.scope[k] is not v)
# Only propagate variables that existed before the loop and have been modified
and (k in self.scope and self.scope[k] is not v)
}
)
return outputs.get_tensor_args()
Expand Down
35 changes: 0 additions & 35 deletions helion/utils.py

This file was deleted.

8 changes: 4 additions & 4 deletions test/test_examples.expected
Original file line number Diff line number Diff line change
Expand Up @@ -1321,9 +1321,9 @@ def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_s
di_copy_1 = di
mi_copy_1_0 = mi_copy_1
di_copy_1_0 = di_copy_1
values = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
values_1 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
subscript_1 = mi_copy_1_0[:, None]
v_7 = values - subscript_1
v_7 = values_1 - subscript_1
v_8 = tl_math.exp(v_7)
subscript_2 = di_copy_1_0[:, None]
v_9 = v_8 / subscript_2
Expand Down Expand Up @@ -1382,9 +1382,9 @@ def _softmax_two_pass_kernel(x, out, out_size_0, out_size_1, x_size_0, x_size_1,
di_copy_1 = di
mi_copy_1_0 = mi_copy_1
di_copy_1_0 = di_copy_1
values = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
values_1 = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero')
subscript_1 = mi_copy_1_0[:, None]
v_7 = values - subscript_1
v_7 = values_1 - subscript_1
v_8 = tl_math.exp(v_7)
subscript_2 = di_copy_1_0[:, None]
v_9 = v_8 / subscript_2
Expand Down
123 changes: 123 additions & 0 deletions test/test_loops.expected
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,62 @@ def addToBoth(a, b, c, *, _launcher=_default_launcher):
_launcher(_addToBoth_kernel, (triton.cdiv(a_n, _BLOCK_SIZE_0) * triton.cdiv(a_m, _BLOCK_SIZE_1) + triton.cdiv(b_n, _BLOCK_SIZE_2) * triton.cdiv(b_m, _BLOCK_SIZE_3) + triton.cdiv(c_n, _BLOCK_SIZE_4) * triton.cdiv(c_m, _BLOCK_SIZE_5),), x0, x1, x2, x0.stride(0), x0.stride(1), x1.stride(0), x1.stride(1), x2.stride(0), x2.stride(1), a_n, a_m, c0, b_n, b_m, c1, c_n, c_m, c2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=3)
return (x0, x1, x2)

--- assertExpectedJournal(TestLoops.test_nested_loop_accumulator)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _nested_loop_accumulator_kernel(x, out, out_stride_0, out_stride_1, out_stride_2, x_stride_0, x_stride_1, x_stride_2, N, M, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0
indices_0 = offset_0 + tl.zeros([1], tl.int32)
acc = tl.full([1], 0.0, tl.float32)
for offset_1 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_1):
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < N
acc_copy = acc
acc = acc_copy
for offset_2 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_2):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
mask_2 = indices_2 < M
acc_copy_0_copy = acc
acc_copy_0_copy_0 = acc_copy_0_copy
vals = tl.load(x + (indices_0[:, None, None] * x_stride_0 + indices_1[None, :, None] * x_stride_1 + indices_2[None, None, :] * x_stride_2), mask_1[None, :, None] & mask_2[None, None, :], other=0)
sum_1 = tl.sum(vals, 2)
sum_2 = tl.sum(sum_1, 1)
acc = acc_copy_0_copy_0 + sum_2
mul = M * N
v_1 = mul.to(tl.float32)
v_2 = acc / v_1
for offset_3 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_3):
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
mask_3 = indices_3 < N
v_2_copy = v_2
v_2_copy_0 = v_2_copy
for offset_4 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_4):
indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
mask_4 = indices_4 < M
v_2_copy_0_copy = v_2_copy_0
v_2_copy_0_copy_0 = v_2_copy_0_copy
vals_1 = tl.load(x + (indices_0[:, None, None] * x_stride_0 + indices_3[None, :, None] * x_stride_1 + indices_4[None, None, :] * x_stride_2), mask_3[None, :, None] & mask_4[None, None, :], other=0)
subscript = v_2_copy_0_copy_0[:, None, None]
v_3 = vals_1 - subscript
tl.store(out + (indices_0[:, None, None] * out_stride_0 + indices_3[None, :, None] * out_stride_1 + indices_4[None, None, :] * out_stride_2), v_3, mask_3[None, :, None] & mask_4[None, None, :])

def nested_loop_accumulator(x: torch.Tensor, *, _launcher=_default_launcher):
B, N, M = x.size()
out = torch.zeros_like(x)
_BLOCK_SIZE_1 = 2
_BLOCK_SIZE_2 = 4
_BLOCK_SIZE_3 = 2
_BLOCK_SIZE_4 = 4
_launcher(_nested_loop_accumulator_kernel, (B,), x, out, out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), N, M, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestLoops.test_pointwise_device_loop)
from __future__ import annotations

Expand Down Expand Up @@ -977,3 +1033,70 @@ def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
_BLOCK_SIZE_2 = 64
_launcher(_matmul_kernel, (triton.cdiv(256, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestLoops.test_three_pass_kernel)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _three_pass_kernel_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, B, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < B
sum_val = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
for offset_1 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_1):
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < M
sum_val_copy = sum_val
sum_val_copy_0 = sum_val_copy
load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
sum_1 = tl.sum(load, 1)
sum_val = sum_val_copy_0 + sum_1
sum_sq = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32)
for offset_2 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_2):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
mask_2 = indices_2 < M
sum_sq_copy = sum_sq
sum_sq_copy_0 = sum_sq_copy
vals = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0)
v_1 = vals * vals
sum_2 = tl.sum(v_1, 1)
sum_sq = sum_sq_copy_0 + sum_2
v_3 = M.to(tl.float32)
v_4 = sum_val / v_3
v_5 = M.to(tl.float32)
v_6 = sum_sq / v_5
v_7 = v_4 * v_4
v_8 = v_6 - v_7
v_9 = 1e-06
v_10 = v_8 + v_9
v_11 = libdevice.sqrt(v_10)
for offset_3 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_3):
indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
mask_3 = indices_3 < M
v_4_copy = v_4
v_11_copy = v_11
v_4_copy_0 = v_4_copy
v_11_copy_0 = v_11_copy
vals_1 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_0[:, None] & mask_3[None, :], other=0)
subscript = v_4_copy_0[:, None]
v_12 = vals_1 - subscript
subscript_1 = v_11_copy_0[:, None]
v_13 = v_12 / subscript_1
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_3[None, :] * out_stride_1), v_13, mask_0[:, None] & mask_3[None, :])

def three_pass_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
B, M = x.size()
out = torch.zeros_like(x)
_BLOCK_SIZE_0 = 2
_BLOCK_SIZE_1 = 8
_BLOCK_SIZE_2 = 8
_BLOCK_SIZE_3 = 8
_launcher(_three_pass_kernel_kernel, (triton.cdiv(B, _BLOCK_SIZE_0),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), B, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
return out
102 changes: 102 additions & 0 deletions test/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,108 @@ def kernel_with_dynamic_fill(
expected = x + fill_value[0]
torch.testing.assert_close(result, expected)

def test_nested_loop_accumulator(self):
"""Test variable scoping with nested loops and accumulator pattern."""

@helion.kernel()
def nested_loop_accumulator(x: torch.Tensor) -> torch.Tensor:
B, N, M = x.size()
out = torch.zeros_like(x)

# Outer loop (like processing each batch in jagged)
for tile_b in hl.tile(B):
# Initialize accumulator for this batch
acc = hl.zeros([tile_b], dtype=torch.float32)

# First nested loop: accumulate values
for tile_n in hl.tile(N):
for tile_m in hl.tile(M):
vals = x[tile_b, tile_n, tile_m].to(torch.float32)
# Accumulate sum
acc = acc + vals.sum(dim=2).sum(dim=1)

# Compute average from accumulated sum
avg = acc / (N * M)

# Second nested loop: use the average
for tile_n in hl.tile(N):
for tile_m in hl.tile(M):
vals = x[tile_b, tile_n, tile_m].to(torch.float32)
result = vals - avg[:, None, None]
out[tile_b, tile_n, tile_m] = result.to(x.dtype)

return out

x = torch.randn(2, 4, 8, device=DEVICE, dtype=torch.float32)

code, result = code_and_output(
nested_loop_accumulator,
(x,),
block_sizes=[1, 2, 4, 2, 4],
)

expected = torch.zeros_like(x)
for b in range(x.size(0)):
batch_sum = x[b].sum()
batch_avg = batch_sum / (x.size(1) * x.size(2))
expected[b] = x[b] - batch_avg

torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)
self.assertExpectedJournal(code)

def test_three_pass_kernel(self):
"""Test variable scoping with three-pass pattern like layer norm."""

@helion.kernel()
def three_pass_kernel(x: torch.Tensor) -> torch.Tensor:
B, M = x.size()
out = torch.zeros_like(x)

for tile_b in hl.tile(B):
# Pass 1: Compute sum
sum_val = hl.zeros([tile_b], dtype=torch.float32)
for tile_m in hl.tile(M):
sum_val = sum_val + x[tile_b, tile_m].to(torch.float32).sum(dim=1)

# Pass 2: Compute sum of squares
sum_sq = hl.zeros([tile_b], dtype=torch.float32)
for tile_m in hl.tile(M):
vals = x[tile_b, tile_m].to(torch.float32)
sum_sq = sum_sq + (vals * vals).sum(dim=1)

# Compute mean and variance
mean = sum_val / M
var = sum_sq / M - mean * mean
std = torch.sqrt(var + 1e-6)

# Pass 3: Normalize using mean and std
for tile_m in hl.tile(M):
vals = x[tile_b, tile_m].to(torch.float32)
# Error likely here - mean and std might not be accessible
normalized = (vals - mean[:, None]) / std[:, None]
out[tile_b, tile_m] = normalized.to(x.dtype)

return out

x = torch.randn(4, 16, device=DEVICE, dtype=torch.float32)

code, result = code_and_output(
three_pass_kernel,
(x,),
block_sizes=[2, 8, 8, 8],
)

expected = torch.zeros_like(x)
for b in range(x.size(0)):
batch_data = x[b]
mean = batch_data.mean()
var = batch_data.var(unbiased=False)
std = torch.sqrt(var + 1e-6)
expected[b] = (batch_data - mean) / std

torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5)
self.assertExpectedJournal(code)


if __name__ == "__main__":
unittest.main()
Loading