Add indirect pointer to barrier support in hl.signal & hl.wait (as_ptrs) #261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into main
129 changes: 91 additions & 38 deletions helion/language/signal_wait.py
@@ -21,13 +21,14 @@
@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
def wait(
signal_pad: torch.Tensor,
index: list[object],
index: list[object] | None = None,
signal: int = 1,
update: int | None = None,
op: str = "ld",
sem: str = "acquire",
scope: str = "gpu",
skip_sync: bool = False,
as_ptrs: bool = False,
) -> None:
"""Wait until all entries of the signal_pad slice are equal to the signal value.
Args:
@@ -39,6 +40,7 @@ def wait(
sem: The memory semantic for acquiring the lock (default: 'acquire')
scope: The scope of the lock (default: 'gpu')
skip_sync: Skip the syncthreads after the wait (default: False)
as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)

Returns:
None
@@ -49,14 +51,15 @@
@_decorators.prepare_args(wait)
def _(
signal_pad: torch.Tensor,
index: list[object],
index: list[object] | None = None,
signal: int = 1,
update: int | None = None,
op: str = "ld",
sem: str = "acquire",
scope: str = "gpu",
skip_sync: bool = False,
) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
as_ptrs: bool = False,
) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool, bool]:
from .tile_proxy import Tile

valid_ops = {"ld", "atomic_cas"}
@@ -88,22 +91,37 @@ def _(
if scope not in valid_scopes:
raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")

if as_ptrs:
if index is not None:
raise ValueError(
f"When as_ptrs=True, signal_pad must be used without indexing. "
f"Expected 0 indices but got {len(index)}. "
)
if signal_pad.dtype not in (torch.uint64, torch.int64):
raise ValueError(
f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
f"to represent memory pointers. Got dtype {signal_pad.dtype}. "
)
if index is None:
index = []

index = Tile._prepare_index(index)
index = Tile._tiles_to_sizes(index)

return (signal_pad, index, signal, update, op, sem, scope, skip_sync)
return (signal_pad, index, signal, update, op, sem, scope, skip_sync, as_ptrs)


@_decorators.register_fake(wait)
def _(
signal_pad: torch.Tensor,
index: list[object],
index: list[object] | None = None,
signal: int = 1,
update: int | None = None,
op: str = "ld",
sem: str = "acquire",
scope: str = "sys",
skip_sync: bool = False,
as_ptrs: bool = False,
) -> None:
return None

@@ -123,35 +141,38 @@ def _(state: CodegenState) -> ast.AST:
sem = state.proxy_arg(5)
scope = state.proxy_arg(6)
skip_sync = state.proxy_arg(7)
as_ptrs = state.proxy_arg(8)

assert isinstance(signal_pad, torch.Tensor)
assert isinstance(index, (list))

indices = SubscriptIndexing.create(state, signal_pad, index)
signal_pad_name = state.device_function.tensor_arg(signal_pad).name

signal_expr = ast.Constant(value=signal) # pyright: ignore[reportArgumentType]
update_expr = ast.Constant(value=update) # pyright: ignore[reportArgumentType]

assert type(op) is str
assert type(sem) is str
assert type(scope) is str

bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
is_scalar = len(bar_tensor_shape) == 0

if is_scalar:
call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
if as_ptrs:
bar_tensor_shape = signal_pad.shape
bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
else:
indices = SubscriptIndexing.create(state, signal_pad, index)
if signal_pad.dtype not in (torch.int32, torch.uint32):
raise NotImplementedError(
f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be of torch.int32 or torch.uint32."
)
call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
signal_pad_name = state.device_function.tensor_arg(signal_pad).name
bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
bar_addrs = f"{signal_pad_name} + signal_pad_arg"

signal_expr = ast.Constant(value=signal) # pyright: ignore[reportArgumentType]
update_expr = ast.Constant(value=update) # pyright: ignore[reportArgumentType]

is_scalar = len(bar_tensor_shape) == 0

call_triton_wait_signal = f"helion.runtime.triton_wait_{'' if is_scalar else 'multiple_'}signal(addr={bar_addrs}, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"

return expr_from_string(
call_triton_wait_signal,
offset=indices.index_expr,
signal_pad_arg=state.ast_arg(0) if as_ptrs else indices.index_expr, # pyright: ignore[reportPossiblyUnboundVariable]
signal=signal_expr,
update=update_expr,
)
@@ -161,13 +182,14 @@ def _(state: CodegenState) -> ast.AST:
@_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
def signal(
signal_pad: torch.Tensor,
index: list[object],
index: list[object] | None = None,
signal: int = 1,
wait_for: int | None = None,
op: str = "atomic_xchg",
sem: str = "release",
scope: str = "gpu",
skip_sync: bool = False,
as_ptrs: bool = False,
) -> torch.Tensor:
"""Set the signal_pad slice to the signal value.
Args:
@@ -179,21 +201,25 @@ def signal(
sem: The memory semantic for acquiring the lock (default: 'release')
scope: The scope of the lock (default: 'gpu')
skip_sync: Skip the syncthreads before sending signal (default: False)
as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)
Returns:
The old value of the signal_pad slice before the update.
"""
raise exc.NotInsideKernel


@_decorators.prepare_args(signal)
def _(
signal_pad: torch.Tensor,
index: list[object],
index: list[object] | None = None,
signal: int = 1,
wait_for: int | None = None,
op: str = "atomic_xchg",
sem: str = "release",
scope: str = "gpu",
skip_sync: bool = False,
) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
as_ptrs: bool = False,
) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool, bool]:
from .tile_proxy import Tile

valid_ops = {"atomic_add", "atomic_xchg", "atomic_cas"}
@@ -220,23 +246,42 @@ def _(
if scope not in valid_scopes:
raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")

if as_ptrs:
if index is not None:
raise ValueError(
f"When as_ptrs=True, signal_pad must be used without indexing. "
f"Expected 0 indices but got {len(index)}. "
)
if signal_pad.dtype not in (torch.uint64, torch.int64):
raise ValueError(
f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
f"to represent memory pointers. Got dtype {signal_pad.dtype}. "
)
if index is None:
index = []

index = Tile._prepare_index(index)
index = Tile._tiles_to_sizes(index)

return (signal_pad, index, signal, wait_for, op, sem, scope, skip_sync)
return (signal_pad, index, signal, wait_for, op, sem, scope, skip_sync, as_ptrs)


@_decorators.register_fake(signal)
def _(
signal_pad: torch.Tensor,
index: list[object],
index: list[object] | None = None,
signal: int = 1,
wait_for: int | None = None,
op: str = "atomic_xchg",
sem: str = "release",
scope: str = "gpu",
skip_sync: bool = False,
as_ptrs: bool = False,
) -> torch.Tensor:
if index is None:
index = []
if as_ptrs:
return signal_pad.new_empty(signal_pad.shape)
return signal_pad.new_empty(SubscriptIndexing.compute_shape(signal_pad, index))


@@ -255,43 +300,51 @@ def _(state: CodegenState) -> ast.AST:
sem = state.proxy_arg(5)
scope = state.proxy_arg(6)
skip_sync = state.proxy_arg(7)
as_ptrs = state.proxy_arg(8)

assert isinstance(signal_pad, torch.Tensor)
assert isinstance(index, list)

indices = SubscriptIndexing.create(state, signal_pad, index)
signal_pad_name = state.device_function.tensor_arg(signal_pad).name
assert type(op) is str
assert type(sem) is str
assert type(scope) is str

if as_ptrs:
bar_tensor_shape = signal_pad.shape
bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
else:
indices = SubscriptIndexing.create(state, signal_pad, index)
if signal_pad.dtype not in (torch.int32, torch.uint32):
raise NotImplementedError(
f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be of torch.int32 or torch.uint32."
)
signal_pad_name = state.device_function.tensor_arg(signal_pad).name
bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
bar_addrs = f"{signal_pad_name} + signal_pad_arg"

is_scalar = len(bar_tensor_shape) == 0

signal_expr = ast.Constant(value=signal) # pyright: ignore[reportArgumentType]
if wait_for is not None:
wait_for_expr = ast.Constant(value=wait_for) # pyright: ignore[reportArgumentType]
else:
wait_for_expr = ast.Constant(value=0)
skip_sync_expr = ast.Constant(value=skip_sync) # pyright: ignore[reportArgumentType]
assert type(op) is str
assert type(sem) is str
assert type(scope) is str

if op == "atomic_cas":
bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
is_scalar = len(bar_tensor_shape) == 0
if is_scalar:
call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
else:
call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"

call_triton_wait_signal = f"helion.runtime.triton_wait_{'' if is_scalar else 'multiple_'}signal(addr={bar_addrs}, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
return expr_from_string(
call_triton_wait_signal,
offset=indices.index_expr,
signal_pad_arg=state.ast_arg(0) if as_ptrs else indices.index_expr, # pyright: ignore[reportPossiblyUnboundVariable]
wait_for=wait_for_expr,
signal=signal_expr,
skip_sync=skip_sync_expr,
)
call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={signal_pad_name} + offset, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)"
call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={bar_addrs}, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)"

return expr_from_string(
call_triton_send_signal,
offset=indices.index_expr,
signal_pad_arg=state.ast_arg(0) if as_ptrs else indices.index_expr, # pyright: ignore[reportPossiblyUnboundVariable]
signal=signal_expr,
skip_sync=skip_sync_expr,
)
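
For context, a minimal host-side sketch of how a caller might build the pointer tensor that the new as_ptrs=True path consumes. Only the uint64/int64 dtype requirement and the hl.signal/hl.wait signatures come from the diff above; the pad layout and variable names here are hypothetical, and the in-kernel calls are shown as comments because the Helion kernel body is not part of this diff:

import torch
import helion.language as hl

# Hypothetical setup: one int32 barrier pad per peer.
pads = [torch.zeros(4, dtype=torch.int32, device="cuda") for _ in range(2)]

# Pack the raw device addresses into a uint64 tensor, as required by
# the as_ptrs=True validation in prepare_args above.
signal_pad_ptrs = torch.tensor(
    [p.data_ptr() for p in pads], dtype=torch.uint64, device="cuda"
)

# Inside a Helion kernel the pointer tensor is then passed without
# indexing; each uint64 entry is dereferenced as an int32 barrier:
#   hl.signal(signal_pad_ptrs, signal=1, as_ptrs=True)
#   hl.wait(signal_pad_ptrs, signal=1, as_ptrs=True)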
5 changes: 5 additions & 0 deletions helion/runtime/triton_helpers.py
@@ -171,6 +171,11 @@ def triton_wait_multiple_signal(
"Invalid barrier value type. Only supports int32 for multi barrier signal. ",
)

if sync_before:
tl.inline_asm_elementwise(
"bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
)

addr = tl.ravel(addr)

tl.static_assert(len(addr.shape) == 1, "addr must be a 1D tensor. ")
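
The new sync_before branch mirrors the existing skip_sync handling: "bar.sync 0;" is a CTA-wide barrier, so sync_before=True makes every thread in the block rendezvous before the atomic is issued, which the signal codegen above relies on via sync_before=(not skip_sync). A minimal device-side sketch, assuming only the keyword arguments visible in the call strings above (the kernel and pointer names are hypothetical):

import triton
import triton.language as tl
import helion
import helion.runtime

@triton.jit
def _cas_signal_kernel(bar_ptr, N: tl.constexpr):
    # Hypothetical: atomically flip N int32 barriers from 0 to 1,
    # synchronizing the CTA first (sync_before=True emits `bar.sync 0;`).
    offs = tl.arange(0, N)
    helion.runtime.triton_wait_multiple_signal(
        addr=bar_ptr + offs,
        expect=0,
        update=1,
        sem="release",
        scope="gpu",
        op="atomic_cas",
        skip_sync=True,
        sync_before=True,
    )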
60 changes: 60 additions & 0 deletions test/test_signal_wait.expected
@@ -110,6 +110,35 @@ def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor, *, _launcher=_defaul
_launcher(_gmem_signal_tensor_bar_kernel_kernel, (triton.cdiv(n, _BLOCK_SIZE_0),), signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return signal_pad

--- assertExpectedJournal(TestWait.test_signal_pointers)
from __future__ import annotations

import torch
import helion
import helion.language as hl
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _gmem_signal_pointers_kernel_kernel(signal_pad_ptrs, signal_pad_ptrs_stride_0, N, _BLOCK_SIZE_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0
for offset_1 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_1):
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < N
load = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0)
symnode_0 = 4 * offset_0
v_0 = symnode_0.to(tl.uint64)
v_1 = load + v_0
helion.runtime.triton_send_signal(addr=v_1.to(tl.pointer_type(tl.int32)), update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False)

def gmem_signal_pointers_kernel(signal_pad_ptrs: torch.Tensor, pad_shape: hl.constexpr, *, _launcher=_default_launcher):
N = signal_pad_ptrs.size(0)
_BLOCK_SIZE_1 = N
_launcher(_gmem_signal_pointers_kernel_kernel, (4,), signal_pad_ptrs, signal_pad_ptrs.stride(0), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return signal_pad_ptrs

--- assertExpectedJournal(TestWait.test_wait_2d_tile)
from __future__ import annotations

@@ -215,3 +244,34 @@ def gmem_wait_multi_bar_kernel_cas(signal_pad: torch.Tensor, *, _launcher=_defau
_BLOCK_SIZE_0 = 4
_launcher(_gmem_wait_multi_bar_kernel_cas_kernel, (triton.cdiv(N, _BLOCK_SIZE_0),), signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return signal_pad

--- assertExpectedJournal(TestWait.test_wait_pointers)
from __future__ import annotations

import torch
import helion
import helion.language as hl
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _gmem_wait_pointers_kernel_kernel(signal_pad_ptrs, out, out_stride_0, signal_pad_ptrs_stride_0, N, _BLOCK_SIZE_1: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0
for offset_1 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_1):
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < N
load = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0)
symnode_0 = 4 * offset_0
v_0 = symnode_0.to(tl.uint64)
v_1 = load + v_0
helion.runtime.triton_wait_multiple_signal(addr=v_1.to(tl.pointer_type(tl.int32)), expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
tl.store(out + offset_0 * out_stride_0, offset_0, None)

def gmem_wait_pointers_kernel(signal_pad_ptrs: torch.Tensor, pad_shape: hl.constexpr, *, _launcher=_default_launcher):
out = torch.empty(4, device=signal_pad_ptrs.device, dtype=torch.int32)
N = signal_pad_ptrs.size(0)
_BLOCK_SIZE_1 = N
_launcher(_gmem_wait_pointers_kernel_kernel, (4,), signal_pad_ptrs, out, out.stride(0), signal_pad_ptrs.stride(0), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return out