From c0f1e08aac9e331f4996f92ab0a6ade0cd75cf9f Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 14 Jul 2025 22:51:14 -0700 Subject: [PATCH 1/2] Add HELION_DEV_LOW_VRAM env var for low GPU memory machines Some dev machine (e.g. gpu laptop) has low VRAM which causes some tritonbench inputs to OOM. This PR adds HELION_DEV_LOW_VRAM env var and uses smaller inputs if the env var is set. User can choose to opt into this mode by setting the env var, instead of passively having smaller inputs due to low VRAM. stack-info: PR: https://github.com/pytorch-labs/helion/pull/325, branch: yf225/stack/31 --- examples/jagged_mean.py | 7 ++++--- helion/utils.py | 35 ----------------------------------- 2 files changed, 4 insertions(+), 38 deletions(-) delete mode 100644 helion/utils.py diff --git a/examples/jagged_mean.py b/examples/jagged_mean.py index 540865b1..cbc6e99d 100644 --- a/examples/jagged_mean.py +++ b/examples/jagged_mean.py @@ -1,14 +1,15 @@ from __future__ import annotations +import os + import torch import helion from helion._testing import run_example import helion.language as hl -from helion.utils import get_gpu_memory_info -# TritonBench configuration - adjust based on available GPU memory -if get_gpu_memory_info()[0] < 16.0: +# TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable +if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1": # Low memory configuration TRITONBENCH_ARGS = {"B": 32, "M": 8, "seqlen": 64} diff --git a/helion/utils.py b/helion/utils.py deleted file mode 100644 index 0e6f9177..00000000 --- a/helion/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations - -import torch - - -def get_gpu_memory_info(device_id: int | None = None) -> tuple[float, float]: - """ - Get total and available GPU memory in GB. - - Args: - device_id: GPU device ID. If None, uses current device. - - Returns: - Tuple of (total_memory_gb, available_memory_gb) - """ - if not torch.cuda.is_available(): - return (0.0, 0.0) - - if device_id is None: - device_id = torch.cuda.current_device() - - # Get total memory - total_memory = torch.cuda.get_device_properties(device_id).total_memory - - # Get reserved memory (memory allocated by the caching allocator) - reserved_memory = torch.cuda.memory_reserved(device_id) - - # Available memory is approximately total - reserved - available_memory = total_memory - reserved_memory - - # Convert to GB - total_gb = total_memory / (1024**3) - available_gb = available_memory / (1024**3) - - return (total_gb, available_gb) From 6063283c60a38d22c269f4baf824416fdcb39a36 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Mon, 14 Jul 2025 23:06:01 -0700 Subject: [PATCH 2/2] Fix variable scoping in nested loops for multi-pass kernels (#324) Fixes issue where variables defined in outer loop scopes were not accessible in subsequent inner loops within the same outer loop iteration. This pattern is common in multi-pass algorithms like layernorm. stack-info: PR: https://github.com/pytorch-labs/helion/pull/326, branch: yf225/stack/32 --- helion/_compiler/device_ir.py | 3 +- test/test_examples.expected | 8 +-- test/test_loops.expected | 123 ++++++++++++++++++++++++++++++++++ test/test_loops.py | 102 ++++++++++++++++++++++++++++ 4 files changed, 231 insertions(+), 5 deletions(-) diff --git a/helion/_compiler/device_ir.py b/helion/_compiler/device_ir.py index 6eea6476..cbb56240 100644 --- a/helion/_compiler/device_ir.py +++ b/helion/_compiler/device_ir.py @@ -553,7 +553,8 @@ def run_subgraph(*args: object) -> list[object]: k: v for k, v in subgraph_walker.scope.items() if k in rw.writes - and (k not in self.scope or self.scope[k] is not v) + # Only propagate variables that existed before the loop and have been modified + and (k in self.scope and self.scope[k] is not v) } ) return outputs.get_tensor_args() diff --git a/test/test_examples.expected b/test/test_examples.expected index deef0e16..32dc18c1 100644 --- a/test/test_examples.expected +++ b/test/test_examples.expected @@ -1321,9 +1321,9 @@ def _softmax_two_pass_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_s di_copy_1 = di mi_copy_1_0 = mi_copy_1 di_copy_1_0 = di_copy_1 - values = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + values_1 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) subscript_1 = mi_copy_1_0[:, None] - v_7 = values - subscript_1 + v_7 = values_1 - subscript_1 v_8 = tl_math.exp(v_7) subscript_2 = di_copy_1_0[:, None] v_9 = v_8 / subscript_2 @@ -1382,9 +1382,9 @@ def _softmax_two_pass_kernel(x, out, out_size_0, out_size_1, x_size_0, x_size_1, di_copy_1 = di mi_copy_1_0 = mi_copy_1 di_copy_1_0 = di_copy_1 - values = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero') + values_1 = tl.load(tl.make_block_ptr(x, [x_size_0, x_size_1], [x_stride_0, x_stride_1], [offset_0, offset_2], [_BLOCK_SIZE_0, _BLOCK_SIZE_1], [1, 0]), boundary_check=[0, 1], padding_option='zero') subscript_1 = mi_copy_1_0[:, None] - v_7 = values - subscript_1 + v_7 = values_1 - subscript_1 v_8 = tl_math.exp(v_7) subscript_2 = di_copy_1_0[:, None] v_9 = v_8 / subscript_2 diff --git a/test/test_loops.expected b/test/test_loops.expected index 93e931b9..a12b6bb3 100644 --- a/test/test_loops.expected +++ b/test/test_loops.expected @@ -855,6 +855,62 @@ def addToBoth(a, b, c, *, _launcher=_default_launcher): _launcher(_addToBoth_kernel, (triton.cdiv(a_n, _BLOCK_SIZE_0) * triton.cdiv(a_m, _BLOCK_SIZE_1) + triton.cdiv(b_n, _BLOCK_SIZE_2) * triton.cdiv(b_m, _BLOCK_SIZE_3) + triton.cdiv(c_n, _BLOCK_SIZE_4) * triton.cdiv(c_m, _BLOCK_SIZE_5),), x0, x1, x2, x0.stride(0), x0.stride(1), x1.stride(0), x1.stride(1), x2.stride(0), x2.stride(1), a_n, a_m, c0, b_n, b_m, c1, c_n, c_m, c2, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, _BLOCK_SIZE_5, num_warps=4, num_stages=3) return (x0, x1, x2) +--- assertExpectedJournal(TestLoops.test_nested_loop_accumulator) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _nested_loop_accumulator_kernel(x, out, out_stride_0, out_stride_1, out_stride_2, x_stride_0, x_stride_1, x_stride_2, N, M, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + indices_0 = offset_0 + tl.zeros([1], tl.int32) + acc = tl.full([1], 0.0, tl.float32) + for offset_1 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_1): + indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) + mask_1 = indices_1 < N + acc_copy = acc + acc = acc_copy + for offset_2 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < M + acc_copy_0_copy = acc + acc_copy_0_copy_0 = acc_copy_0_copy + vals = tl.load(x + (indices_0[:, None, None] * x_stride_0 + indices_1[None, :, None] * x_stride_1 + indices_2[None, None, :] * x_stride_2), mask_1[None, :, None] & mask_2[None, None, :], other=0) + sum_1 = tl.sum(vals, 2) + sum_2 = tl.sum(sum_1, 1) + acc = acc_copy_0_copy_0 + sum_2 + mul = M * N + v_1 = mul.to(tl.float32) + v_2 = acc / v_1 + for offset_3 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_3): + indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) + mask_3 = indices_3 < N + v_2_copy = v_2 + v_2_copy_0 = v_2_copy + for offset_4 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_4): + indices_4 = offset_4 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32) + mask_4 = indices_4 < M + v_2_copy_0_copy = v_2_copy_0 + v_2_copy_0_copy_0 = v_2_copy_0_copy + vals_1 = tl.load(x + (indices_0[:, None, None] * x_stride_0 + indices_3[None, :, None] * x_stride_1 + indices_4[None, None, :] * x_stride_2), mask_3[None, :, None] & mask_4[None, None, :], other=0) + subscript = v_2_copy_0_copy_0[:, None, None] + v_3 = vals_1 - subscript + tl.store(out + (indices_0[:, None, None] * out_stride_0 + indices_3[None, :, None] * out_stride_1 + indices_4[None, None, :] * out_stride_2), v_3, mask_3[None, :, None] & mask_4[None, None, :]) + +def nested_loop_accumulator(x: torch.Tensor, *, _launcher=_default_launcher): + B, N, M = x.size() + out = torch.zeros_like(x) + _BLOCK_SIZE_1 = 2 + _BLOCK_SIZE_2 = 4 + _BLOCK_SIZE_3 = 2 + _BLOCK_SIZE_4 = 4 + _launcher(_nested_loop_accumulator_kernel, (B,), x, out, out.stride(0), out.stride(1), out.stride(2), x.stride(0), x.stride(1), x.stride(2), N, M, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=3) + return out + --- assertExpectedJournal(TestLoops.test_pointwise_device_loop) from __future__ import annotations @@ -977,3 +1033,70 @@ def matmul(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher): _BLOCK_SIZE_2 = 64 _launcher(_matmul_kernel, (triton.cdiv(256, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3) return out + +--- assertExpectedJournal(TestLoops.test_three_pass_kernel) +from __future__ import annotations + +import torch +import triton +import triton.language as tl +from torch._inductor.runtime.triton_compat import libdevice +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _three_pass_kernel_kernel(x, out, out_stride_0, out_stride_1, x_stride_0, x_stride_1, B, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 * _BLOCK_SIZE_0 + indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32) + mask_0 = indices_0 < B + sum_val = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32) + for offset_1 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_1): + indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) + mask_1 = indices_1 < M + sum_val_copy = sum_val + sum_val_copy_0 = sum_val_copy + load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0) + sum_1 = tl.sum(load, 1) + sum_val = sum_val_copy_0 + sum_1 + sum_sq = tl.full([_BLOCK_SIZE_0], 0.0, tl.float32) + for offset_2 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_2): + indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32) + mask_2 = indices_2 < M + sum_sq_copy = sum_sq + sum_sq_copy_0 = sum_sq_copy + vals = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_2[None, :] * x_stride_1), mask_0[:, None] & mask_2[None, :], other=0) + v_1 = vals * vals + sum_2 = tl.sum(v_1, 1) + sum_sq = sum_sq_copy_0 + sum_2 + v_3 = M.to(tl.float32) + v_4 = sum_val / v_3 + v_5 = M.to(tl.float32) + v_6 = sum_sq / v_5 + v_7 = v_4 * v_4 + v_8 = v_6 - v_7 + v_9 = 1e-06 + v_10 = v_8 + v_9 + v_11 = libdevice.sqrt(v_10) + for offset_3 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_3): + indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32) + mask_3 = indices_3 < M + v_4_copy = v_4 + v_11_copy = v_11 + v_4_copy_0 = v_4_copy + v_11_copy_0 = v_11_copy + vals_1 = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_3[None, :] * x_stride_1), mask_0[:, None] & mask_3[None, :], other=0) + subscript = v_4_copy_0[:, None] + v_12 = vals_1 - subscript + subscript_1 = v_11_copy_0[:, None] + v_13 = v_12 / subscript_1 + tl.store(out + (indices_0[:, None] * out_stride_0 + indices_3[None, :] * out_stride_1), v_13, mask_0[:, None] & mask_3[None, :]) + +def three_pass_kernel(x: torch.Tensor, *, _launcher=_default_launcher): + B, M = x.size() + out = torch.zeros_like(x) + _BLOCK_SIZE_0 = 2 + _BLOCK_SIZE_1 = 8 + _BLOCK_SIZE_2 = 8 + _BLOCK_SIZE_3 = 8 + _launcher(_three_pass_kernel_kernel, (triton.cdiv(B, _BLOCK_SIZE_0),), x, out, out.stride(0), out.stride(1), x.stride(0), x.stride(1), B, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3) + return out diff --git a/test/test_loops.py b/test/test_loops.py index 98123ac3..792af2c6 100644 --- a/test/test_loops.py +++ b/test/test_loops.py @@ -989,6 +989,108 @@ def kernel_with_dynamic_fill( expected = x + fill_value[0] torch.testing.assert_close(result, expected) + def test_nested_loop_accumulator(self): + """Test variable scoping with nested loops and accumulator pattern.""" + + @helion.kernel() + def nested_loop_accumulator(x: torch.Tensor) -> torch.Tensor: + B, N, M = x.size() + out = torch.zeros_like(x) + + # Outer loop (like processing each batch in jagged) + for tile_b in hl.tile(B): + # Initialize accumulator for this batch + acc = hl.zeros([tile_b], dtype=torch.float32) + + # First nested loop: accumulate values + for tile_n in hl.tile(N): + for tile_m in hl.tile(M): + vals = x[tile_b, tile_n, tile_m].to(torch.float32) + # Accumulate sum + acc = acc + vals.sum(dim=2).sum(dim=1) + + # Compute average from accumulated sum + avg = acc / (N * M) + + # Second nested loop: use the average + for tile_n in hl.tile(N): + for tile_m in hl.tile(M): + vals = x[tile_b, tile_n, tile_m].to(torch.float32) + result = vals - avg[:, None, None] + out[tile_b, tile_n, tile_m] = result.to(x.dtype) + + return out + + x = torch.randn(2, 4, 8, device=DEVICE, dtype=torch.float32) + + code, result = code_and_output( + nested_loop_accumulator, + (x,), + block_sizes=[1, 2, 4, 2, 4], + ) + + expected = torch.zeros_like(x) + for b in range(x.size(0)): + batch_sum = x[b].sum() + batch_avg = batch_sum / (x.size(1) * x.size(2)) + expected[b] = x[b] - batch_avg + + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + self.assertExpectedJournal(code) + + def test_three_pass_kernel(self): + """Test variable scoping with three-pass pattern like layer norm.""" + + @helion.kernel() + def three_pass_kernel(x: torch.Tensor) -> torch.Tensor: + B, M = x.size() + out = torch.zeros_like(x) + + for tile_b in hl.tile(B): + # Pass 1: Compute sum + sum_val = hl.zeros([tile_b], dtype=torch.float32) + for tile_m in hl.tile(M): + sum_val = sum_val + x[tile_b, tile_m].to(torch.float32).sum(dim=1) + + # Pass 2: Compute sum of squares + sum_sq = hl.zeros([tile_b], dtype=torch.float32) + for tile_m in hl.tile(M): + vals = x[tile_b, tile_m].to(torch.float32) + sum_sq = sum_sq + (vals * vals).sum(dim=1) + + # Compute mean and variance + mean = sum_val / M + var = sum_sq / M - mean * mean + std = torch.sqrt(var + 1e-6) + + # Pass 3: Normalize using mean and std + for tile_m in hl.tile(M): + vals = x[tile_b, tile_m].to(torch.float32) + # Error likely here - mean and std might not be accessible + normalized = (vals - mean[:, None]) / std[:, None] + out[tile_b, tile_m] = normalized.to(x.dtype) + + return out + + x = torch.randn(4, 16, device=DEVICE, dtype=torch.float32) + + code, result = code_and_output( + three_pass_kernel, + (x,), + block_sizes=[2, 8, 8, 8], + ) + + expected = torch.zeros_like(x) + for b in range(x.size(0)): + batch_data = x[b] + mean = batch_data.mean() + var = batch_data.var(unbiased=False) + std = torch.sqrt(var + 1e-6) + expected[b] = (batch_data - mean) / std + + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + self.assertExpectedJournal(code) + if __name__ == "__main__": unittest.main()