Add hl.signal #233

Merged
merged 1 commit on Jul 10, 2025
1 change: 1 addition & 0 deletions helion/language/__init__.py
@@ -15,6 +15,7 @@
from .scan_ops import associative_scan as associative_scan
from .scan_ops import cumprod as cumprod
from .scan_ops import cumsum as cumsum
from .signal_wait import signal as signal
from .signal_wait import wait as wait
from .tile_ops import tile_begin as tile_begin
from .tile_ops import tile_block_size as tile_block_size
143 changes: 143 additions & 0 deletions helion/language/signal_wait.py
@@ -6,13 +6,16 @@
from torch.fx import has_side_effect

from .. import exc
from .._compiler.indexing_strategy import SubscriptIndexing
from . import _decorators

if TYPE_CHECKING:
import ast

from .._compiler.inductor_lowering import CodegenState

__all__ = ["signal", "wait"]


@has_side_effect
@_decorators.api(tiles_as_sizes=True)
@@ -146,3 +149,143 @@ def _(state: CodegenState) -> ast.AST:
signal=signal_expr,
update=update_expr,
)


@has_side_effect
@_decorators.api(tiles_as_sizes=True)
def signal(
signal_pad: torch.Tensor,
index: list[object],
signal: int = 1,
wait_for: int | None = None,
op: str = "atomic_xchg",
sem: str = "release",
scope: str = "gpu",
skip_sync: bool = False,
) -> torch.Tensor:
"""Set the signal_pad slice to the signal value.
Args:
signal_pad: The signal pad to signal
index: Indices to index into the signal_pad tensor
signal: the value to send
wait_for: The value to wait for before sending the signal. Only valid for op = 'atomic_cas'.
op: The memory op for acquiring the lock (default: 'atomic_xchg')
sem: The memory semantic for acquiring the lock (default: 'release')
scope: The scope of the lock (default: 'gpu')
skip_sync: Skip the syncthreads before sending the signal (default: False)
"""
raise exc.NotInsideKernel


@_decorators.prepare_args(signal)
def _(
signal_pad: torch.Tensor,
index: list[object],
signal: int = 1,
wait_for: int | None = None,
op: str = "atomic_xchg",
sem: str = "release",
scope: str = "gpu",
skip_sync: bool = False,
) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
from helion.language.tile_proxy import Tile

valid_ops = {"atomic_add", "atomic_xchg", "atomic_cas"}
valid_sems = {"relaxed", "release", "acq_rel"}
valid_scopes = {"sys", "gpu"}

if op not in valid_ops:
raise ValueError(f"Invalid signal op '{op}'. Must be one of {valid_ops}. ")

if op == "atomic_cas" and wait_for is None:
raise ValueError(
f"{op} without a wait_for value. Do you want to use 'atomic_add' or 'atomic_xchg' instead? "
)
if op in {"atomic_add", "atomic_xchg"} and wait_for is not None:
raise ValueError(
f"{op} with a wait_for value. Do you want to use 'atomic_cas' instead? "
)

if sem not in valid_sems:
raise ValueError(
f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
)

if scope not in valid_scopes:
raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")

index = Tile._prepare_index(index)
index = Tile._tiles_to_sizes(index)

return (signal_pad, index, signal, wait_for, op, sem, scope, skip_sync)


@_decorators.register_fake(signal)
def _(
signal_pad: torch.Tensor,
index: list[object],
signal: int = 1,
wait_for: int | None = None,
op: str = "atomic_xchg",
sem: str = "release",
scope: str = "gpu",
skip_sync: bool = False,
) -> torch.Tensor:
return signal_pad.new_empty(SubscriptIndexing.compute_shape(signal_pad, index))


@_decorators.codegen(signal)
def _(state: CodegenState) -> ast.AST:
import ast

from .._compiler.ast_extension import expr_from_string
from .._compiler.indexing_strategy import SubscriptIndexing

signal_pad = state.proxy_arg(0)
index = state.proxy_arg(1)
signal = state.proxy_arg(2)
wait_for = state.proxy_arg(3)
op = state.proxy_arg(4)
sem = state.proxy_arg(5)
scope = state.proxy_arg(6)
skip_sync = state.proxy_arg(7)

assert isinstance(signal_pad, torch.Tensor)
assert isinstance(index, list)

indices = SubscriptIndexing.create(state, signal_pad, index)
signal_pad_name = state.device_function.tensor_arg(signal_pad).name

signal_expr = ast.Constant(value=signal)
if wait_for is not None:
wait_for_expr = ast.Constant(value=wait_for)
else:
wait_for_expr = ast.Constant(value=0)
skip_sync_expr = ast.Constant(value=skip_sync)
assert type(op) is str
assert type(sem) is str
assert type(scope) is str

if op == "atomic_cas":
bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
is_scalar = len(bar_tensor_shape) == 0
if is_scalar:
call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
else:
call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync), sync_after=True)"

return expr_from_string(
call_triton_wait_signal,
offset=indices.index_expr,
wait_for=wait_for_expr,
signal=signal_expr,
skip_sync=skip_sync_expr,
)
call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={signal_pad_name} + offset, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)"

return expr_from_string(
call_triton_send_signal,
offset=indices.index_expr,
signal=signal_expr,
skip_sync=skip_sync_expr,
)
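
Usage note for review: below is a minimal sketch of how the new hl.signal API is meant to be called from a Helion kernel. It mirrors the generated kernels in test/test_signal_wait.expected further down; the kernel body, the hl.grid loop, and the launch line are illustrative assumptions, not code from this PR.

import torch
import helion
import helion.language as hl

# Hypothetical kernel (assumed): each program release-stores 1 into its own
# slot of signal_pad via hl.signal (op='atomic_xchg', sem='release' defaults).
@helion.kernel
def set_flags(signal_pad: torch.Tensor) -> torch.Tensor:
    (n,) = signal_pad.shape
    for i in hl.grid(n):
        hl.signal(signal_pad, [i], signal=1)
        # Gated variant (assumed): wait until the slot equals 0, then CAS it to 1.
        # hl.signal(signal_pad, [i], signal=1, wait_for=0, op="atomic_cas")
    return signal_pad

# Assumed usage:
# flags = set_flags(torch.zeros(8, dtype=torch.int32, device="cuda"))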
1 change: 1 addition & 0 deletions helion/runtime/__init__.py
@@ -8,6 +8,7 @@
from .config import Config as Config
from .kernel import Kernel as Kernel
from .kernel import kernel as kernel
from .triton_helpers import triton_send_signal as triton_send_signal
from .triton_helpers import triton_wait_signal as triton_wait_signal


77 changes: 74 additions & 3 deletions helion/runtime/triton_helpers.py
@@ -3,7 +3,56 @@
import triton
import triton.language as tl

__all__ = ["triton_wait_signal"]
__all__ = ["triton_send_signal", "triton_wait_multiple_signal", "triton_wait_signal"]


@triton.jit
def triton_send_signal(
addr: tl.tensor,
update: tl.constexpr,
sem: tl.constexpr,
scope: tl.constexpr,
op: tl.constexpr,
skip_sync: tl.constexpr,
) -> tl.tensor:
"""
Signal global memory barrier(s).

This function atomically sets global memory barriers to an update value,
signaling to other CTAs waiting on the barrier(s).

Args:
addr: Memory address of the barrier(s) to signal
update: The value to set the barrier(s) to
sem: Memory semantics for the atomic operation. Options: "release", "relaxed".
scope: Scope of the atomic operation. Options: "gpu", "sys"
op: Atomic operation type: "atomic_xchg", "atomic_add"
skip_sync: Skip CTA synchronization before setting the barrier. (default: False)
Returns:
The old value of the barrier(s) before the update.
"""
if not skip_sync:
tl.inline_asm_elementwise(
"bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
)

tl.static_assert(
sem == "release" or sem == "relaxed",
"Invalid memory semantic. options: 'release', 'relaxed'. ",
)
tl.static_assert(
scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu','sys'. "
)

if op == "atomic_xchg":
barrier_status = tl.atomic_xchg(addr, update, sem=sem, scope=scope)
elif op == "atomic_add":
barrier_status = tl.atomic_add(addr, update, sem=sem, scope=scope)
else:
raise NotImplementedError(
f"Unsupported op '{op}' for send signal on gmem barrier. "
)
return barrier_status


@triton.jit
@@ -15,6 +64,7 @@ def triton_wait_signal(
scope: tl.constexpr,
op: tl.constexpr,
skip_sync: tl.constexpr,
sync_before: tl.constexpr = False, # pyre-ignore[9]
) -> None:
"""
Wait for a global memory barrier to reach the expected value.
@@ -30,15 +80,16 @@
scope: Scope of the atomic operation. Options: "gpu", "sys"
op: Atomic operation type: "ld", "atomic_cas"
skip_sync: Skip CTA sync after acquiring the barrier (default: False)
sync_before: Add a CTA sync before the wait (default: False)
"""
tl.static_assert(
addr.type.is_ptr(),
"Barrier address must be a scalar. Do you want to use '_triton_wait_multiple_signal'? ",
)

tl.static_assert(
sem == "acquire" or sem == "relaxed",
"Invalid memory semantic. options: 'acquire', 'relaxed'. ",
(sem == "acquire" or sem == "relaxed") or sem == "release",
"Invalid memory semantic. options: 'acquire', 'relaxed', 'release'. ",
)
tl.static_assert(
scope == "gpu" or scope == "sys", "Invalid scope. options: 'gpu', 'sys'. "
@@ -48,6 +99,11 @@
"Invalid op. options: 'ld', 'atomic_cas'. ",
)

if sync_before:
tl.inline_asm_elementwise(
"bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
)

# Spin-wait loop:
# Uses atomic_add with update=0 for ld.global.{sem}.{scope}
# Triton generates smem broadcasting of tl.atomic_add return value in ptx,
@@ -71,3 +127,18 @@
"bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
)
# tl.debug_barrier() cause significant performance loss. (Perhaps breaks triton prefetching?)


@triton.jit
def triton_wait_multiple_signal(
addr: tl.tensor,
expect: tl.constexpr, # wait until lock is set to expect
update: tl.constexpr, # update the lock once it is acquired.
sem: tl.constexpr,
scope: tl.constexpr,
op: tl.constexpr,
skip_sync: tl.constexpr,
sync_before: tl.constexpr = False, # pyre-ignore[9]
) -> None:
raise NotImplementedError("Waiting on multiple barriers is not implemented yet. ")
# TODO(joydddd): waiting on multiple barriers at the same time where each thread waits on a different barrier
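
For context on the new runtime helper: triton_send_signal is a @triton.jit function, so it can be called directly from a hand-written Triton kernel as well as from Helion-generated code. A minimal standalone sketch follows; the kernel, wrapper, and launch configuration are assumptions for illustration, while the triton_send_signal call itself matches what Helion emits in test_signal_wait.expected.

import torch
import triton
import triton.language as tl

from helion.runtime import triton_send_signal

@triton.jit
def _release_flags_kernel(signal_pad_ptr, stride_0):
    # One program per flag: release-exchange 1 into signal_pad[pid].
    pid = tl.program_id(0)
    triton_send_signal(
        addr=signal_pad_ptr + pid * stride_0,
        update=1,
        sem="release",
        scope="gpu",
        op="atomic_xchg",
        skip_sync=False,
    )

def release_flags(signal_pad: torch.Tensor) -> torch.Tensor:
    # Hypothetical wrapper: launch one program per element of signal_pad.
    (n,) = signal_pad.shape
    _release_flags_kernel[(n,)](signal_pad, signal_pad.stride(0))
    return signal_pad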
75 changes: 75 additions & 0 deletions test/test_signal_wait.expected
@@ -1,6 +1,81 @@
This file is automatically generated by assertExpectedJournal calls in test_signal_wait.py.
Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.

--- assertExpectedJournal(TestWait.test_signal_basic)
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _gmem_signal_scalar_bar_kernel_kernel(signal_pad, signal_pad_stride_0):
pid_0 = tl.program_id(0)
offset_0 = pid_0
helion.runtime.triton_send_signal(addr=signal_pad + offset_0 * signal_pad_stride_0, update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)

def gmem_signal_scalar_bar_kernel(signal_pad: torch.Tensor):
n, = signal_pad.shape
_gmem_signal_scalar_bar_kernel_kernel[n,](signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3)
return signal_pad

def _gmem_signal_scalar_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
n, = signal_pad.shape
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_gmem_signal_scalar_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3)

--- assertExpectedJournal(TestWait.test_signal_cas)
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _gmem_signal_cas_kernel_kernel(signal_pad, signal_pad_stride_0):
pid_0 = tl.program_id(0)
offset_0 = pid_0
helion.runtime.triton_wait_signal(addr=signal_pad + offset_0 * signal_pad_stride_0, expect=0, update=1, sem='release', scope='gpu', op='atomic_cas', skip_sync=True, sync_before=not False)

def gmem_signal_cas_kernel(signal_pad: torch.Tensor):
n, = signal_pad.shape
_gmem_signal_cas_kernel_kernel[n,](signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3)
return signal_pad

def _gmem_signal_cas_kernel_make_precompiler(signal_pad: torch.Tensor):
n, = signal_pad.shape
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_gmem_signal_cas_kernel_kernel)(signal_pad, signal_pad.stride(0), num_warps=4, num_stages=3)

--- assertExpectedJournal(TestWait.test_signal_multiple)
from __future__ import annotations

import torch
import helion
import triton
import triton.language as tl

@triton.jit
def _gmem_signal_tensor_bar_kernel_kernel(signal_pad, signal_pad_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
helion.runtime.triton_send_signal(addr=signal_pad + indices_0 * signal_pad_stride_0, update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)

def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor):
n, = signal_pad.shape
_BLOCK_SIZE_0 = 4
_gmem_signal_tensor_bar_kernel_kernel[triton.cdiv(n, _BLOCK_SIZE_0),](signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return signal_pad

def _gmem_signal_tensor_bar_kernel_make_precompiler(signal_pad: torch.Tensor):
n, = signal_pad.shape
_BLOCK_SIZE_0 = 4
from helion.runtime.precompile_shim import make_precompiler
return make_precompiler(_gmem_signal_tensor_bar_kernel_kernel)(signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)

--- assertExpectedJournal(TestWait.test_wait_2d_tile)
from __future__ import annotations
