diff --git a/examples/all_gather_matmul.py b/examples/all_gather_matmul.py
new file mode 100644
index 00000000..e0a87ab4
--- /dev/null
+++ b/examples/all_gather_matmul.py
@@ -0,0 +1,211 @@
+from __future__ import annotations
+
+import os
+from typing import Any
+
+import torch
+import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
+
+import helion
+import helion.language as hl
+
+
+def copy_engine_all_gather_w_progress(
+    output: torch.Tensor,
+    inp: torch.Tensor,  # Must be a symmetric memory tensor
+    progress: torch.Tensor,
+    splits_per_rank: int,
+    backend_stream: torch.cuda.Stream | None = None,
+) -> torch.cuda.Stream:
+    backend_stream = symm_mem._get_backend_stream(priority=-1)
+    assert inp.is_contiguous()
+    symm_mem_group = dist.group.WORLD
+    if symm_mem_group is None:
+        raise RuntimeError("No symmetric memory group available")
+    symm_mem_hdl = symm_mem.rendezvous(inp, group=symm_mem_group)
+    assert symm_mem_hdl is not None
+
+    rank = symm_mem_hdl.rank
+    world_size = symm_mem_hdl.world_size
+
+    assert inp.numel() % splits_per_rank == 0
+    assert progress.numel() >= world_size * splits_per_rank
+
+    output_shape = list(inp.shape)
+    output_shape[0] *= world_size
+    assert list(output.shape) == output_shape, (list(output.shape), output_shape)
+
+    chunks = output.chunk(world_size * splits_per_rank)
+
+    symm_mem_hdl.barrier()
+    backend_stream.wait_stream(torch.cuda.current_stream())
+
+    with torch.cuda.stream(backend_stream):
+        for step in range(world_size):
+            src_rank = (rank + step + 1) % world_size
+            for split_id in range(splits_per_rank):
+                src_buf = symm_mem_hdl.get_buffer(
+                    src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id
+                )
+                chunks[src_rank * splits_per_rank + split_id].copy_(src_buf)
+                # cuStreamWriteValue32 issues a system-level fence before the write
+                symm_mem_hdl.stream_write_value32(
+                    progress,
+                    offset=src_rank * splits_per_rank + split_id,
+                    val=1,
+                )
+        symm_mem_hdl.barrier()
+
+    return backend_stream
+
+
+# TODO(joydddd): add support for auto-tuning on multi-process runs.
+# Please hardcode the helion config for multiprocess runs initiated by torchrun.
+@helion.jit(
+    config=helion.Config(
+        block_sizes=[128, 256, 64],
+        num_warps=8,
+        num_stages=3,
+        indexing="block_ptr",
+    ),
+    static_shapes=True,
+)
+def helion_matmul_w_progress(
+    a: torch.Tensor,
+    a_shared: torch.Tensor,
+    b: torch.Tensor,
+    progress: torch.Tensor,
+    SPLITS_PER_RANK: int,
+    RANK: int,
+) -> torch.Tensor:
+    M, K = a.size()
+    K2, N = b.size()
+    assert K2 == K, f"size mismatch {K2} != {K}"
+
+    out = torch.empty(
+        [M, N], dtype=torch.promote_types(a.dtype, b.dtype), device=a.device
+    )
+
+    M_per_rank = a_shared.size(0)
+
+    for tile_m, tile_n in hl.tile([M, N]):
+        acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+        hl.wait(
+            progress,
+            [
+                tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
+            ],
+            signal=1,
+            update=None,
+            op="ld",
+            scope="gpu",
+            sem="acquire",
+        )
+        for tile_k in hl.tile(K):
+            # TODO(joydddd): use a_shared and skip the barrier when the data is already available on the local rank.
+            # if tile_k.begin // M_per_rank == RANK:
+            #     acc = torch.addmm(acc, a_shared[tile_m.index - RANK * M_per_rank, tile_k], b[tile_k, tile_n])
+            # else:
+            #     hl.wait(progress, [tile_m.begin // (M_per_rank // SPLITS_PER_RANK),], signal=1, update=None, op="ld", scope="gpu", sem="acquire")
+            acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
+        out[tile_m, tile_n] = acc
+    return out
+
+
+def helion_all_gather_matmul(
+    a_shared: torch.Tensor,
+    b: torch.Tensor,
+    a_out: torch.Tensor | None = None,
+    progress: torch.Tensor | None = None,
+    **kwargs: Any,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    configs = {
+        "SPLITS_PER_RANK": kwargs.get("splits_per_rank", 1),
+    }
+
+    symm_mem_group = dist.group.WORLD
+    if symm_mem_group is None:
+        raise RuntimeError("No symmetric memory group available")
+
+    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=symm_mem_group)
+
+    a_shape = list(a_shared.shape)
+    a_shape[0] *= symm_mem_hdl.world_size
+
+    configs["RANK"] = symm_mem_hdl.rank
+    configs["WORLD_SIZE"] = symm_mem_hdl.world_size
+
+    if a_out is None:
+        a_out = torch.empty(a_shape, dtype=a_shared.dtype, device=a_shared.device)
+
+    if progress is None:
+        progress = torch.zeros(
+            symm_mem_hdl.world_size * configs["SPLITS_PER_RANK"],
+            dtype=torch.uint32,
+            device=a_shared.device,
+        )
+    else:
+        progress.fill_(
+            0
+        )  # Reset progress to 0. Maybe we should reset inside the kernel using cas?
+
+    backend_stream = copy_engine_all_gather_w_progress(
+        a_out, a_shared, progress, configs["SPLITS_PER_RANK"]
+    )
+
+    c = helion_matmul_w_progress(
+        a_out,
+        a_shared,
+        b,
+        progress,
+        SPLITS_PER_RANK=configs["SPLITS_PER_RANK"],
+        RANK=configs["RANK"],
+    )
+    assert type(c) is torch.Tensor
+
+    torch.cuda.current_stream().wait_stream(backend_stream)
+
+    return a_out, c
+
+
+def test(M: int, N: int, K: int, world_size: int, device: torch.device) -> None:
+    a_shared = symm_mem.empty(
+        M // world_size, K, dtype=torch.bfloat16, device=device
+    ).normal_()
+    b = torch.randn((K, N), device="cuda", dtype=torch.bfloat16).T.contiguous().T
+
+    a_out, c = helion_all_gather_matmul(a_shared, b)
+
+    golden_a = a_shared.clone()
+    dist_group = dist.group.WORLD
+    if dist_group is None:
+        raise RuntimeError("No distributed group available")
+    ag_golden, mm_golden = torch.ops.symm_mem.fused_all_gather_matmul(
+        golden_a, [b], gather_dim=0, group_name=dist_group.group_name
+    )
+    torch.testing.assert_close(c, mm_golden[0], rtol=1e-1, atol=1e-1)
+    torch.testing.assert_close(a_out, ag_golden)
+
+
+def main() -> None:
+    rank = int(os.environ["LOCAL_RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    torch.manual_seed(42 + rank)
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    dist.init_process_group("nccl")
+    test(4096, 6656, 16384, world_size, device)
+
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    """
+    Run with:
+    torchrun \
+    --nnodes 1 --nproc-per-node 8 \
+    --rdzv-backend c10d --rdzv-endpoint localhost:0 \
+    --no_python python3 examples/all_gather_matmul.py
+    """
+    main()
diff --git a/helion/language/__init__.py b/helion/language/__init__.py
index 884d845b..277f1cc5 100644
--- a/helion/language/__init__.py
+++ b/helion/language/__init__.py
@@ -15,6 +15,7 @@
 from .scan_ops import associative_scan as associative_scan
 from .scan_ops import cumprod as cumprod
 from .scan_ops import cumsum as cumsum
+from .signal_wait import wait as wait
 from .tile_ops import tile_begin as tile_begin
 from .tile_ops import tile_block_size as tile_block_size
 from .tile_ops import tile_end as tile_end
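Note (illustrative, not part of the diff): the tile-to-progress-slot mapping used by helion_matmul_w_progress above can be sanity-checked with a small sketch. progress_slot below is a hypothetical helper, and the shapes are taken from test() in the example (M=4096, world_size=8, hence M_per_rank=512) with the default SPLITS_PER_RANK=1.

def progress_slot(row_begin: int, m_per_rank: int = 512, splits_per_rank: int = 1) -> int:
    # Mirrors: tile_m.begin // (M_per_rank // SPLITS_PER_RANK)
    return row_begin // (m_per_rank // splits_per_rank)

assert progress_slot(0) == 0      # tile starting at row 0 waits on the chunk gathered from rank 0
assert progress_slot(1536) == 3   # tile starting at row 1536 waits on rank 3's chunk
assert progress_slot(4095) == 7   # the last row block waits on rank 7's chunk

The copy engine marks slot src_rank * splits_per_rank + split_id as ready right after the corresponding chunk lands in a_out, so each matmul tile only spins until the rows it reads have actually been gathered.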
diff --git a/helion/language/signal_wait.py b/helion/language/signal_wait.py
new file mode 100644
index 00000000..7d8b81cd
--- /dev/null
+++ b/helion/language/signal_wait.py
@@ -0,0 +1,148 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from torch.fx import has_side_effect
+
+from .. import exc
+from . import _decorators
+
+if TYPE_CHECKING:
+    import ast
+
+    from .._compiler.inductor_lowering import CodegenState
+
+
+@has_side_effect
+@_decorators.api(tiles_as_sizes=True)
+def wait(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    update: int | None = None,
+    op: str = "ld",
+    sem: str = "acquire",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> None:
+    """Wait until all entries of the signal_pad slice are equal to the signal value.
+    Args:
+        signal_pad: The signal pad tensor to wait on
+        index: Indices into the signal_pad tensor
+        signal: The value to wait for
+        update: Atomically update the signal_pad tensor with this value once the signal is observed. (default: None)
+        op: The memory op for acquiring the lock (default: 'ld')
+        sem: The memory semantic for acquiring the lock (default: 'acquire')
+        scope: The scope of the lock (default: 'gpu')
+        skip_sync: Skip the syncthreads after the wait (default: False)
+
+    Returns:
+        None
+    """
+    raise exc.NotInsideKernel
+
+
+@_decorators.prepare_args(wait)
+def _(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    update: int | None = None,
+    op: str = "ld",
+    sem: str = "acquire",
+    scope: str = "gpu",
+    skip_sync: bool = False,
+) -> tuple[torch.Tensor, object, int, int | None, str, str, str, bool]:
+    from helion.language.tile_proxy import Tile
+
+    valid_ops = {"ld", "atomic_cas"}
+    valid_sems = {"relaxed", "acquire", "acq_rel"}
+    valid_scopes = {"sys", "gpu"}
+
+    if op not in valid_ops:
+        raise ValueError(f"Invalid wait op '{op}'. Must be one of {valid_ops}.")
+
+    if sem == "release":
+        raise ValueError(
+            f"Do not use '{sem}' for wait patterns. Wait sem must be one of {valid_sems}."
+        )
+
+    if sem not in valid_sems:
+        raise ValueError(
+            f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
+        )
+
+    if op == "atomic_cas" and update is None:
+        raise ValueError(
+            f"{op} requires an update value. Did you mean to use 'ld' instead?"
+        )
+
+    if op == "ld":
+        assert update is None
+        update = 0
+
+    if scope not in valid_scopes:
+        raise ValueError(f"Invalid scope '{scope}'. Must be one of {valid_scopes}.")
+
+    # TODO(joydddd): add support for non-scalar indices into signal_pad
+    for i in index:
+        assert isinstance(i, int | torch.SymInt)
+
+    index = Tile._prepare_index(index)
+    index = Tile._tiles_to_sizes(index)
+
+    return (signal_pad, index, signal, update, op, sem, scope, skip_sync)
+
+
+@_decorators.register_fake(wait)
+def _(
+    signal_pad: torch.Tensor,
+    index: list[object],
+    signal: int = 1,
+    update: int | None = None,
+    op: str = "ld",
+    sem: str = "acquire",
+    scope: str = "sys",
+    skip_sync: bool = False,
+) -> None:
+    return None
+
+
+@_decorators.codegen(wait)
+def _(state: CodegenState) -> ast.AST:
+    import ast
+
+    from .._compiler.ast_extension import expr_from_string
+    from .._compiler.indexing_strategy import SubscriptIndexing
+
+    signal_pad = state.proxy_arg(0)
+    index = state.proxy_arg(1)
+    signal = state.proxy_arg(2)
+    update = state.proxy_arg(3)
+    op = state.proxy_arg(4)
+    sem = state.proxy_arg(5)
+    scope = state.proxy_arg(6)
+    skip_sync = state.proxy_arg(7)
+
+    assert isinstance(signal_pad, torch.Tensor)
+    assert isinstance(index, list)
+
+    indices = SubscriptIndexing.create(state, signal_pad, index)
+    signal_pad_name = state.device_function.tensor_arg(signal_pad).name
+
+    signal_expr = ast.Constant(value=signal)
+    update_expr = ast.Constant(value=update)
+
+    assert type(op) is str
+    assert type(sem) is str
+    assert type(scope) is str
+
+    call_triton_wait_signal = f"helion.runtime.triton_wait_signal(addr={signal_pad_name} + offset, expect=signal, update=update, sem='{sem}', scope='{scope}', op='{op}', skip_sync={skip_sync})"
+
+    return expr_from_string(
+        call_triton_wait_signal,
+        offset=indices.index_expr,
+        signal=signal_expr,
+        update=update_expr,
+    )
diff --git a/helion/runtime/__init__.py b/helion/runtime/__init__.py
index 98ce513b..4b7259fb 100644
--- a/helion/runtime/__init__.py
+++ b/helion/runtime/__init__.py
@@ -8,6 +8,7 @@
 from .config import Config as Config
 from .kernel import Kernel as Kernel
 from .kernel import kernel as kernel
+from .triton_helpers import triton_wait_signal as triton_wait_signal


 def _alloc_fn(size: int, alignment: int, stream: int | None) -> torch.Tensor:
diff --git a/helion/runtime/triton_helpers.py b/helion/runtime/triton_helpers.py
new file mode 100644
index 00000000..a96e2d41
--- /dev/null
+++ b/helion/runtime/triton_helpers.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+import triton
+import triton.language as tl
+
+__all__ = ["triton_wait_signal"]
+
+
+@triton.jit
+def triton_wait_signal(
+    addr: tl.tensor,
+    expect: tl.constexpr,
+    update: tl.constexpr,
+    sem: tl.constexpr,
+    scope: tl.constexpr,
+    op: tl.constexpr,
+    skip_sync: tl.constexpr,
+) -> None:
+    """
+    Wait for a global memory barrier to reach the expected value.
+
+    This function implements a spin-wait loop that continuously checks a memory location
+    until it reaches the expected value, providing synchronization across CTAs.
+
+    Args:
+        addr: Memory address of the barrier to wait on (must be a scalar)
+        expect: Expected value to wait for
+        update: Value to update the barrier with once acquired
+        sem: Memory semantics for the atomic operation. Options: "acquire", "relaxed".
+        scope: Scope of the atomic operation. Options: "gpu", "sys"
+        op: Atomic operation type: "ld", "atomic_cas"
+        skip_sync: Skip the CTA sync after acquiring the barrier (default: False)
+    """
+    tl.static_assert(
+        addr.type.is_ptr(),
+        "Barrier address must be a scalar. Do you want to use '_triton_wait_multiple_signal'?",
+    )
+
+    tl.static_assert(
+        sem == "acquire" or sem == "relaxed",
+        "Invalid memory semantic. Options: 'acquire', 'relaxed'.",
+    )
+    tl.static_assert(
+        scope == "gpu" or scope == "sys", "Invalid scope. Options: 'gpu', 'sys'."
+    )
+    tl.static_assert(
+        op == "ld" or op == "atomic_cas",
+        "Invalid op. Options: 'ld', 'atomic_cas'.",
+    )
+
+    # Spin-wait loop:
+    # Uses atomic_add with update=0 for ld.global.{sem}.{scope}
+    # Triton generates smem broadcasting of the tl.atomic_add return value in PTX,
+    # but it is optimized away by ptxas in SASS, hence no performance overhead.
+    if op == "ld":
+        tl.static_assert(
+            update == 0, "An ld wait on a gmem barrier cannot update the lock."
+        )
+        while tl.atomic_add(addr, 0, sem=sem, scope=scope) != expect:
+            pass
+    elif op == "atomic_cas":
+        while tl.atomic_cas(addr, expect, update, sem=sem, scope=scope) != expect:
+            pass
+    else:
+        raise NotImplementedError(
+            f"Unsupported op '{op}' for wait signal on gmem barrier."
+        )
+
+    if not skip_sync:
+        tl.inline_asm_elementwise(
+            "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1
+        )
+        # tl.debug_barrier() causes significant performance loss. (Perhaps it breaks Triton prefetching?)
diff --git a/test/test_signal_wait.expected b/test/test_signal_wait.expected
new file mode 100644
index 00000000..f9f30a22
--- /dev/null
+++ b/test/test_signal_wait.expected
@@ -0,0 +1,71 @@
+This file is automatically generated by assertExpectedJournal calls in test_signal_wait.py.
+Update expected outputs by running tests with the EXPECTTEST_ACCEPT=1 environment variable set.
+
+--- assertExpectedJournal(TestWait.test_wait_2d_tile)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+
+@triton.jit
+def _wait_for_2d_tile_kernel_kernel(signal_pad, x, out, out_stride_0, out_stride_1, signal_pad_stride_0, signal_pad_stride_1, x_stride_0, x_stride_1, n, m, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    num_blocks_0 = tl.cdiv(n, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < n
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    mask_1 = indices_1 < m
+    tile_id = offset_0 // _BLOCK_SIZE_0
+    tile_id_1 = offset_1 // _BLOCK_SIZE_1
+    helion.runtime.triton_wait_signal(addr=signal_pad + (tile_id * signal_pad_stride_0 + tile_id_1 * signal_pad_stride_1), expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
+    load = tl.load(x + (indices_0[:, None] * x_stride_0 + indices_1[None, :] * x_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+    tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), load, mask_0[:, None] & mask_1[None, :])
+
+def wait_for_2d_tile_kernel(signal_pad: torch.Tensor, x: torch.Tensor):
+    out = torch.empty_like(x)
+    n, m = x.shape
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _wait_for_2d_tile_kernel_kernel[triton.cdiv(n, _BLOCK_SIZE_0) * triton.cdiv(m, _BLOCK_SIZE_1),](signal_pad, x, out, out.stride(0), out.stride(1), signal_pad.stride(0), signal_pad.stride(1), x.stride(0), x.stride(1), n, m, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+    return out
+
+def _wait_for_2d_tile_kernel_make_precompiler(signal_pad: torch.Tensor, x: torch.Tensor):
+    out = torch.empty_like(x)
+    n, m = x.shape
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_wait_for_2d_tile_kernel_kernel)(signal_pad, x, out, out.stride(0), out.stride(1), signal_pad.stride(0), signal_pad.stride(1), x.stride(0), x.stride(1), n, m, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
+
+--- assertExpectedJournal(TestWait.test_wait_basic)
+from __future__ import annotations
+
+import torch
+import helion
+import triton
+import triton.language as tl
+
+@triton.jit
+def _gmem_wait_kernel_kernel(signal_pad, out, out_stride_0, signal_pad_stride_0):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0
+    helion.runtime.triton_wait_signal(addr=signal_pad + offset_0 * signal_pad_stride_0, expect=1, update=0, sem='acquire', scope='gpu', op='ld', skip_sync=False)
+    tl.store(out + offset_0 * out_stride_0, offset_0, None)
+
+def gmem_wait_kernel(signal_pad: torch.Tensor):
+    out = torch.empty_like(signal_pad)
+    n, = signal_pad.shape
+    _gmem_wait_kernel_kernel[n,](signal_pad, out, out.stride(0), signal_pad.stride(0), num_warps=4, num_stages=3)
+    return out
+
+def _gmem_wait_kernel_make_precompiler(signal_pad: torch.Tensor):
+    out = torch.empty_like(signal_pad)
+    n, = signal_pad.shape
+    from helion.runtime.precompile_shim import make_precompiler
+    return make_precompiler(_gmem_wait_kernel_kernel)(signal_pad, out, out.stride(0), signal_pad.stride(0), num_warps=4, num_stages=3)
+
diff --git a/test/test_signal_wait.py b/test/test_signal_wait.py
new file mode 100644
index 00000000..fbb5ed36
--- /dev/null
+++ b/test/test_signal_wait.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+import unittest
+
+import torch
+
+import helion
+from helion._testing import DEVICE
+from helion._testing import TestCase
+from helion._testing import code_and_output
+import helion.language as hl
+
+
+class TestWait(TestCase):
+    def test_wait_basic(self):
+        @helion.kernel
+        def gmem_wait_kernel(signal_pad: torch.Tensor) -> torch.Tensor:
+            out = torch.empty_like(signal_pad)
+            (n,) = signal_pad.shape
+            for i in hl.grid(n):
+                hl.wait(signal_pad, [i], signal=1)
+                out[i] = i
+
+            return out
+
+        signal_pad = torch.ones(4, device=DEVICE, dtype=torch.int32)
+        code, result = code_and_output(gmem_wait_kernel, (signal_pad,))
+        torch.testing.assert_close(
+            result, torch.arange(4, device=DEVICE, dtype=torch.int32)
+        )
+        self.maxDiff = None
+        self.assertExpectedJournal(code)
+
+    def test_wait_2d_tile(self):
+        @helion.kernel
+        def wait_for_2d_tile_kernel(
+            signal_pad: torch.Tensor, x: torch.Tensor
+        ) -> torch.Tensor:
+            out = torch.empty_like(x)
+            (n, m) = x.shape
+            for tile_n, tile_m in hl.tile([n, m]):
+                hl.wait(signal_pad, [tile_n.id, tile_m.id], signal=1)
+                out[tile_n, tile_m] = x[tile_n, tile_m]
+            return out
+
+        signal_pad = torch.ones([4, 4], device=DEVICE, dtype=torch.int32)
+        x = torch.randn([64, 64], device=DEVICE, dtype=torch.bfloat16)
+        code, result = code_and_output(
+            wait_for_2d_tile_kernel,
+            (signal_pad, x),
+            block_size=[16, 16],
+        )
+
+        torch.testing.assert_close(result, x)
+        self.assertExpectedJournal(code)
+
+
+if __name__ == "__main__":
+    unittest.main()
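For reference, a minimal sketch (not part of the diff) of calling the exported helion.runtime.triton_wait_signal helper directly from a hand-written Triton kernel. It mirrors the generated code in test_signal_wait.expected; the kernel and run() names here are hypothetical, and the flags are pre-set to 1 so the wait returns immediately.

import torch
import triton
import triton.language as tl

from helion.runtime import triton_wait_signal


@triton.jit
def _consumer_kernel(signal_pad, out):
    # One program per flag: spin until signal_pad[pid] == 1, then write the output slot.
    pid = tl.program_id(0)
    triton_wait_signal(
        addr=signal_pad + pid,
        expect=1,
        update=0,
        sem="acquire",
        scope="gpu",
        op="ld",
        skip_sync=False,
    )
    tl.store(out + pid, pid)


def run(n: int = 4) -> torch.Tensor:
    signal_pad = torch.ones(n, device="cuda", dtype=torch.int32)  # flags already set
    out = torch.empty(n, device="cuda", dtype=torch.int32)
    _consumer_kernel[(n,)](signal_pad, out)
    return out  # expected: tensor([0, 1, ..., n-1])

In a producer/consumer setting the flags would instead start at 0 and be set by another kernel or, as in the example above, by the copy engine via stream_write_value32.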