diff --git a/helion/language/signal_wait.py b/helion/language/signal_wait.py
index 2c474aaa..0bb5622d 100644
--- a/helion/language/signal_wait.py
+++ b/helion/language/signal_wait.py
@@ -21,13 +21,14 @@
 @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
 def wait(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] | None = None,
     signal: int = 1,
     update: int | None = None,
     op: str = "ld",
     sem: str = "acquire",
     scope: str = "gpu",
     skip_sync: bool = False,
+    as_ptrs: bool = False,
 ) -> None:
     """Wait until all entries of the signal_pad slice are equal to the signal value.
     Args:
@@ -39,6 +40,7 @@ def wait(
         sem: The memory semantic for acquiring the lock (default: 'acquire')
         scope: The scope of the lock (default: 'gpu')
         skip_sync: Skip the syncthreads after the wait (default: False)
+        as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)
 
     Returns:
         None
@@ -49,14 +51,15 @@ def wait(
 @_decorators.prepare_args(wait)
 def _(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] | None = None,
     signal: int = 1,
     update: int | None = None,
     op: str = "ld",
     sem: str = "acquire",
     scope: str = "gpu",
     skip_sync: bool = False,
-) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
+    as_ptrs: bool = False,
+) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool, bool]:
     from .tile_proxy import Tile
 
     valid_ops = {"ld", "atomic_cas"}
@@ -88,22 +91,37 @@ def _(
     if scope not in valid_scopes:
         raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")
 
+    if as_ptrs:
+        if index is not None:
+            raise ValueError(
+                f"When as_ptrs=True, signal_pad must be used without indexing. "
+                f"Expected 0 indices but got {len(index)}. "
+            )
+        if signal_pad.dtype not in (torch.uint64, torch.int64):
+            raise ValueError(
+                f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
+                f"to represent memory pointers. Got dtype {signal_pad.dtype}. "
+            )
+    if index is None:
+        index = []
+
     index = Tile._prepare_index(index)
     index = Tile._tiles_to_sizes(index)
 
-    return (signal_pad, index, signal, update, op, sem, scope, skip_sync)
+    return (signal_pad, index, signal, update, op, sem, scope, skip_sync, as_ptrs)
 
 
 @_decorators.register_fake(wait)
 def _(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] | None = None,
     signal: int = 1,
     update: int | None = None,
     op: str = "ld",
     sem: str = "acquire",
     scope: str = "sys",
     skip_sync: bool = False,
+    as_ptrs: bool = False,
 ) -> None:
     return None
 
@@ -123,35 +141,38 @@ def _(state: CodegenState) -> ast.AST:
     sem = state.proxy_arg(5)
     scope = state.proxy_arg(6)
     skip_sync = state.proxy_arg(7)
+    as_ptrs = state.proxy_arg(8)
 
     assert isinstance(signal_pad, torch.Tensor)
     assert isinstance(index, (list))
 
-    indices = SubscriptIndexing.create(state, signal_pad, index)
-    signal_pad_name = state.device_function.tensor_arg(signal_pad).name
-
-    signal_expr = ast.Constant(value=signal)  # pyright: ignore[reportArgumentType]
-    update_expr = ast.Constant(value=update)  # pyright: ignore[reportArgumentType]
-
     assert type(op) is str
     assert type(sem) is str
     assert type(scope) is str
 
-    bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
-    is_scalar = len(bar_tensor_shape) == 0
-
-    if is_scalar:
-        call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
+    if as_ptrs:
+        bar_tensor_shape = signal_pad.shape
+        bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
     else:
+        indices = SubscriptIndexing.create(state, signal_pad, index)
         if signal_pad.dtype not in (torch.int32, torch.uint32):
             raise NotImplementedError(
                 f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be of torch.int32 or torch.uint32."
             )
-        call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
+        signal_pad_name = state.device_function.tensor_arg(signal_pad).name
+        bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
+        bar_addrs = f"{signal_pad_name} + signal_pad_arg"
+
+    signal_expr = ast.Constant(value=signal)  # pyright: ignore[reportArgumentType]
+    update_expr = ast.Constant(value=update)  # pyright: ignore[reportArgumentType]
+
+    is_scalar = len(bar_tensor_shape) == 0
+
+    call_triton_wait_signal = f"helion.runtime.triton_wait_{'' if is_scalar else 'multiple_'}signal(addr={bar_addrs}, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
 
     return expr_from_string(
         call_triton_wait_signal,
-        offset=indices.index_expr,
+        signal_pad_arg=state.ast_arg(0) if as_ptrs else indices.index_expr,  # pyright: ignore[reportPossiblyUnboundVariable]
         signal=signal_expr,
         update=update_expr,
     )
@@ -161,13 +182,14 @@ def _(state: CodegenState) -> ast.AST:
 @_decorators.api(tiles_as_sizes=True, allow_host_tensor=True)
 def signal(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] | None = None,
     signal: int = 1,
     wait_for: int | None = None,
     op: str = "atomic_xchg",
     sem: str = "release",
     scope: str = "gpu",
     skip_sync: bool = False,
+    as_ptrs: bool = False,
 ) -> torch.Tensor:
     """Set the signal_pad slice to the signal value.
     Args:
@@ -179,6 +201,9 @@ def signal(
         sem: The memory semantic for acquiring the lock (default: 'release')
         scope: The scope of the lock (default: 'gpu')
         skip_sync: Skip the syncthreads before sending signal (default: False)
+        as_ptrs: Treat signal_pad as pointers to global memory barriers (default: False)
+    Returns:
+        The old value of the signal_pad slice before the update.
     """
     raise exc.NotInsideKernel
 
@@ -186,14 +211,15 @@ def signal(
 @_decorators.prepare_args(signal)
 def _(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] | None = None,
     signal: int = 1,
     wait_for: int | None = None,
     op: str = "atomic_xchg",
     sem: str = "release",
     scope: str = "gpu",
     skip_sync: bool = False,
-) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
+    as_ptrs: bool = False,
+) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool, bool]:
     from .tile_proxy import Tile
 
     valid_ops = {"atomic_add", "atomic_xchg", "atomic_cas"}
@@ -220,23 +246,42 @@ def _(
     if scope not in valid_scopes:
         raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")
 
+    if as_ptrs:
+        if index is not None:
+            raise ValueError(
+                f"When as_ptrs=True, signal_pad must be used without indexing. "
+                f"Expected 0 indices but got {len(index)}. "
+            )
+        if signal_pad.dtype not in (torch.uint64, torch.int64):
+            raise ValueError(
+                f"When as_ptrs=True, signal_pad must have dtype torch.uint64 or torch.int64 "
+                f"to represent memory pointers. Got dtype {signal_pad.dtype}. "
+            )
+    if index is None:
+        index = []
+
     index = Tile._prepare_index(index)
     index = Tile._tiles_to_sizes(index)
 
-    return (signal_pad, index, signal, wait_for, op, sem, scope, skip_sync)
+    return (signal_pad, index, signal, wait_for, op, sem, scope, skip_sync, as_ptrs)
 
 
 @_decorators.register_fake(signal)
 def _(
     signal_pad: torch.Tensor,
-    index: list[object],
+    index: list[object] | None = None,
     signal: int = 1,
     wait_for: int | None = None,
     op: str = "atomic_xchg",
     sem: str = "release",
     scope: str = "gpu",
     skip_sync: bool = False,
+    as_ptrs: bool = False,
 ) -> torch.Tensor:
+    if index is None:
+        index = []
+    if as_ptrs:
+        return signal_pad.new_empty(signal_pad.shape)
     return signal_pad.new_empty(SubscriptIndexing.compute_shape(signal_pad, index))
 
 
@@ -255,12 +300,29 @@ def _(state: CodegenState) -> ast.AST:
     sem = state.proxy_arg(5)
     scope = state.proxy_arg(6)
     skip_sync = state.proxy_arg(7)
+    as_ptrs = state.proxy_arg(8)
 
     assert isinstance(signal_pad, torch.Tensor)
     assert isinstance(index, list)
 
-    indices = SubscriptIndexing.create(state, signal_pad, index)
-    signal_pad_name = state.device_function.tensor_arg(signal_pad).name
+    assert type(op) is str
+    assert type(sem) is str
+    assert type(scope) is str
+
+    if as_ptrs:
+        bar_tensor_shape = signal_pad.shape
+        bar_addrs = "signal_pad_arg.to(tl.pointer_type(tl.int32))"
+    else:
+        indices = SubscriptIndexing.create(state, signal_pad, index)
+        if signal_pad.dtype not in (torch.int32, torch.uint32):
+            raise NotImplementedError(
+                f"Unsupported signal pad dtype: {signal_pad.dtype}. Must be of torch.int32 or torch.uint32."
+            )
+        signal_pad_name = state.device_function.tensor_arg(signal_pad).name
+        bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
+        bar_addrs = f"{signal_pad_name} + signal_pad_arg"
+
+    is_scalar = len(bar_tensor_shape) == 0
 
     signal_expr = ast.Constant(value=signal)  # pyright: ignore[reportArgumentType]
     if wait_for is not None:
@@ -268,30 +330,21 @@ def _(state: CodegenState) -> ast.AST:
     else:
         wait_for_expr = ast.Constant(value=0)
     skip_sync_expr = ast.Constant(value=skip_sync)  # pyright: ignore[reportArgumentType]
-    assert type(op) is str
-    assert type(sem) is str
-    assert type(scope) is str
 
     if op == "atomic_cas":
-        bar_tensor_shape = SubscriptIndexing.compute_shape(signal_pad, index)
-        is_scalar = len(bar_tensor_shape) == 0
-        if is_scalar:
-            call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
-        else:
-            call_triton_wait_signal = f"helion.runtime.triton_wait_multiple_signal(addr={signal_pad_name} + offset, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
-
+        call_triton_wait_signal = f"helion.runtime.triton_wait_{'' if is_scalar else 'multiple_'}signal(addr={bar_addrs}, expect=wait_for, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=True, sync_before=(not skip_sync))"
         return expr_from_string(
             call_triton_wait_signal,
-            offset=indices.index_expr,
+            signal_pad_arg=state.ast_arg(0) if as_ptrs else indices.index_expr,  # pyright: ignore[reportPossiblyUnboundVariable]
            wait_for=wait_for_expr,
             signal=signal_expr,
             skip_sync=skip_sync_expr,
         )
-    call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={signal_pad_name} + offset, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)"
+    call_triton_send_signal = f"helion.runtime.triton_send_signal(addr={bar_addrs}, update=signal, sem='{sem}', scope='{scope}', op='{op}', skip_sync=skip_sync)"
 
     return expr_from_string(
         call_triton_send_signal,
-        offset=indices.index_expr,
+        signal_pad_arg=state.ast_arg(0) if as_ptrs else indices.index_expr,  # pyright: ignore[reportPossiblyUnboundVariable]
         signal=signal_expr,
         skip_sync=skip_sync_expr,
     )
diff --git a/helion/runtime/triton_helpers.py b/helion/runtime/triton_helpers.py
index 1e93d446..b57d2463 100644
--- a/helion/runtime/triton_helpers.py
+++ b/helion/runtime/triton_helpers.py
@@ -171,6 +171,11 @@ def triton_wait_multiple_signal(
         "Invalid barrier value type. Only supports int32 for multi barrier signal. ",
     )
 
+    if sync_before:
+        tl.inline_asm_elementwise(
+            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
+        )
+
     addr = tl.ravel(addr)
 
     tl.static_assert(len(addr.shape) == 1, "addr must be a 1D tensor. ")
") diff --git a/test/test_signal_wait.expected b/test/test_signal_wait.expected index fff0ebce..c8532bd5 100644 --- a/test/test_signal_wait.expected +++ b/test/test_signal_wait.expected @@ -110,6 +110,35 @@ def gmem_signal_tensor_bar_kernel(signal_pad: torch.Tensor, *, _launcher=_defaul _launcher(_gmem_signal_tensor_bar_kernel_kernel, (triton.cdiv(n, _BLOCK_SIZE_0),), signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) return signal_pad +--- assertExpectedJournal(TestWait.test_signal_pointers) +from __future__ import annotations + +import torch +import helion +import helion.language as hl +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _gmem_signal_pointers_kernel_kernel(signal_pad_ptrs, signal_pad_ptrs_stride_0, N, _BLOCK_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + for offset_1 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_1): + indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) + mask_1 = indices_1 < N + load = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0) + symnode_0 = 4 * offset_0 + v_0 = symnode_0.to(tl.uint64) + v_1 = load + v_0 + helion.runtime.triton_send_signal(addr=v_1.to(tl.pointer_type(tl.int32)), update=1, sem='release', scope='gpu', op='atomic_xchg', skip_sync=False) + +def gmem_signal_pointers_kernel(signal_pad_ptrs: torch.Tensor, pad_shape: hl.constexpr, *, _launcher=_default_launcher): + N = signal_pad_ptrs.size(0) + _BLOCK_SIZE_1 = N + _launcher(_gmem_signal_pointers_kernel_kernel, (4,), signal_pad_ptrs, signal_pad_ptrs.stride(0), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3) + return signal_pad_ptrs + --- assertExpectedJournal(TestWait.test_wait_2d_tile) from __future__ import annotations @@ -215,3 +244,34 @@ def gmem_wait_multi_bar_kernel_cas(signal_pad: torch.Tensor, *, _launcher=_defau _BLOCK_SIZE_0 = 4 _launcher(_gmem_wait_multi_bar_kernel_cas_kernel, (triton.cdiv(N, _BLOCK_SIZE_0),), signal_pad, signal_pad.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3) return signal_pad + +--- assertExpectedJournal(TestWait.test_wait_pointers) +from __future__ import annotations + +import torch +import helion +import helion.language as hl +import triton +import triton.language as tl +from helion.runtime import default_launcher as _default_launcher + +@triton.jit +def _gmem_wait_pointers_kernel_kernel(signal_pad_ptrs, out, out_stride_0, signal_pad_ptrs_stride_0, N, _BLOCK_SIZE_1: tl.constexpr): + pid_0 = tl.program_id(0) + offset_0 = pid_0 + for offset_1 in tl.range(0, N.to(tl.int32), _BLOCK_SIZE_1): + indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32) + mask_1 = indices_1 < N + load = tl.load(signal_pad_ptrs + indices_1 * signal_pad_ptrs_stride_0, mask_1, other=0) + symnode_0 = 4 * offset_0 + v_0 = symnode_0.to(tl.uint64) + v_1 = load + v_0 + helion.runtime.triton_wait_multiple_signal(addr=v_1.to(tl.pointer_type(tl.int32)), expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False) + tl.store(out + offset_0 * out_stride_0, offset_0, None) + +def gmem_wait_pointers_kernel(signal_pad_ptrs: torch.Tensor, pad_shape: hl.constexpr, *, _launcher=_default_launcher): + out = torch.empty(4, device=signal_pad_ptrs.device, dtype=torch.int32) + N = signal_pad_ptrs.size(0) + _BLOCK_SIZE_1 = N + _launcher(_gmem_wait_pointers_kernel_kernel, (4,), signal_pad_ptrs, out, out.stride(0), signal_pad_ptrs.stride(0), N, _BLOCK_SIZE_1, num_warps=4, num_stages=3) + return out diff --git 
index dbff6046..c0b3f7d9 100644
--- a/test/test_signal_wait.py
+++ b/test/test_signal_wait.py
@@ -230,6 +230,142 @@ def gmem_multi_bar_sync_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
         )
         self.assertIn("atomic_cas", code)
 
+    def test_wait_pointers(self):
+        @helion.kernel
+        def gmem_wait_pointers_kernel(
+            signal_pad_ptrs: torch.Tensor, pad_shape: hl.constexpr
+        ) -> torch.Tensor:
+            out = torch.empty(
+                pad_shape, device=signal_pad_ptrs.device, dtype=torch.int32
+            )
+            N = signal_pad_ptrs.size(0)
+            for i in hl.grid(pad_shape):
+                offset = i * 4  # byte offset: 4 bytes per torch.int32 element
+                for multicast_tile in hl.tile(N, block_size=N):
+                    signal_pads = (
+                        signal_pad_ptrs[multicast_tile] + offset
+                    )  # Load the pointers, and broadcast the offset
+                    hl.wait(signal_pads, signal=1, as_ptrs=True)
+                out[i] = i
+            return out
+
+        signal_pad_list = [
+            torch.ones(4, device=DEVICE, dtype=torch.int32) for _ in range(4)
+        ]
+        signal_pad_ptrs = torch.as_tensor(
+            [p.data_ptr() for p in signal_pad_list], device=DEVICE, dtype=torch.uint64
+        )
+        code, result = code_and_output(gmem_wait_pointers_kernel, (signal_pad_ptrs, 4))
+        torch.testing.assert_close(
+            result, torch.arange(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.assertExpectedJournal(code)
+
+    def test_signal_pointers(self):
+        @helion.kernel
+        def gmem_signal_pointers_kernel(
+            signal_pad_ptrs: torch.Tensor, pad_shape: hl.constexpr
+        ) -> torch.Tensor:
+            N = signal_pad_ptrs.size(0)
+            for i in hl.grid(pad_shape):
+                offset = i * 4  # byte offset: 4 bytes per torch.int32 element
+                for multicast_tile in hl.tile(N, block_size=N):
+                    signal_pads = signal_pad_ptrs[multicast_tile] + offset
+                    hl.signal(signal_pads, signal=1, as_ptrs=True)
+            return signal_pad_ptrs
+
+        signal_pad_list = [
+            torch.zeros(4, device=DEVICE, dtype=torch.int32) for _ in range(4)
+        ]
+        signal_pad_ptrs = torch.as_tensor(
+            [p.data_ptr() for p in signal_pad_list], device=DEVICE, dtype=torch.uint64
+        )
+        code, result = code_and_output(
+            gmem_signal_pointers_kernel, (signal_pad_ptrs, 4)
+        )
+
+        for tensor in signal_pad_list:
+            torch.testing.assert_close(
+                tensor, torch.ones(4, device=DEVICE, dtype=torch.int32)
+            )
+        self.assertExpectedJournal(code)
+
+    def test_global_sync_on_pointers(self):
+        @helion.kernel
+        def gmem_multi_bar_sync_ptrs_kernel(
+            signal_pad_ptrs: torch.Tensor,
+            pad_shape: hl.constexpr,
+        ) -> torch.Tensor:
+            N = hl.specialize(signal_pad_ptrs.size(0))
+            for i in hl.grid(pad_shape):
+                signal_offset = (
+                    i * 4
+                )  # 4 = number of bytes per torch.int32 element
+                for multicast_tile in hl.tile(N, block_size=N):
+                    hl.signal(
+                        signal_pad_ptrs[multicast_tile] + signal_offset,
+                        signal=1,
+                        skip_sync=True,
+                        as_ptrs=True,
+                    )
+                wait_offsets = torch.arange(N, device=DEVICE, dtype=torch.uint64) * 4
+                wait_ptrs = signal_pad_ptrs[i] + wait_offsets
+                hl.wait(wait_ptrs, signal=1, as_ptrs=True)
+
+        signal_pad_list = [
+            torch.zeros(4, device=DEVICE, dtype=torch.int32) for _ in range(4)
+        ]
+        signal_pad_ptrs = torch.as_tensor(
+            [p.data_ptr() for p in signal_pad_list], device=DEVICE, dtype=torch.uint64
+        )
+
+        code, result = code_and_output(
+            gmem_multi_bar_sync_ptrs_kernel, (signal_pad_ptrs, 4)
+        )
+        for tensor in signal_pad_list:
+            torch.testing.assert_close(
+                tensor, torch.ones(4, device=DEVICE, dtype=torch.int32)
+            )
+
+    def test_global_sync_on_pointers_cas(self):
+        @helion.kernel
+        def gmem_multi_bar_sync_ptrs_kernel(
+            signal_pad_ptrs: torch.Tensor,
+            pad_shape: hl.constexpr,
+        ) -> torch.Tensor:
+            N = hl.specialize(signal_pad_ptrs.size(0))
+            for i in hl.grid(pad_shape):
+                signal_offset = (
+                    i * 4
+                )  # 4 = number of bytes per torch.int32 element
+                for multicast_tile in hl.tile(N, block_size=N):
+                    hl.signal(
+                        signal_pad_ptrs[multicast_tile] + signal_offset,
+                        wait_for=0,
+                        signal=1,
+                        op="atomic_cas",
+                        skip_sync=True,
+                        as_ptrs=True,
+                    )
+                wait_offsets = torch.arange(N, device=DEVICE, dtype=torch.uint64) * 4
+                wait_ptrs = signal_pad_ptrs[i] + wait_offsets
+                hl.wait(wait_ptrs, signal=1, update=0, op="atomic_cas", as_ptrs=True)
+
+        signal_pad_list = [
+            torch.zeros(4, device=DEVICE, dtype=torch.int32) for _ in range(4)
+        ]
+        signal_pad_ptrs = torch.as_tensor(
+            [p.data_ptr() for p in signal_pad_list], device=DEVICE, dtype=torch.uint64
+        )
+
+        code, result = code_and_output(
+            gmem_multi_bar_sync_ptrs_kernel, (signal_pad_ptrs, 4)
+        )
+        for tensor in signal_pad_list:
+            torch.testing.assert_close(
+                tensor, torch.zeros(4, device=DEVICE, dtype=torch.int32)
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
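
Usage sketch (illustrative only, not part of the patch): with as_ptrs=True, hl.wait/hl.signal take a uint64 tensor of raw device addresses instead of an indexed int32 signal pad, as exercised by test_signal_pointers above. The snippet below mirrors that test; the "cuda" device string, the kernel name, and direct kernel invocation are assumptions for illustration, not taken from the diff.

    import torch
    import helion
    import helion.language as hl

    @helion.kernel
    def signal_remote_pads(signal_pad_ptrs: torch.Tensor, pad_shape: hl.constexpr) -> torch.Tensor:
        N = signal_pad_ptrs.size(0)
        for i in hl.grid(pad_shape):
            offset = i * 4  # 4 bytes per torch.int32 barrier element
            for tile in hl.tile(N, block_size=N):
                # signal_pad_ptrs holds raw addresses; as_ptrs=True makes hl.signal
                # treat them as pointers to int32 barriers in global memory.
                hl.signal(signal_pad_ptrs[tile] + offset, signal=1, as_ptrs=True)
        return signal_pad_ptrs

    pads = [torch.zeros(4, device="cuda", dtype=torch.int32) for _ in range(4)]
    ptrs = torch.as_tensor([p.data_ptr() for p in pads], device="cuda", dtype=torch.uint64)
    signal_remote_pads(ptrs, 4)  # afterwards every pad should hold ones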