From eaf93f8b796c9ad8cecb61ce590635c9eee256dd Mon Sep 17 00:00:00 2001
From: joydddd
Date: Tue, 8 Jul 2025 12:07:02 -0700
Subject: [PATCH] One shot all reduce & symm mem sync

stack-info: PR: https://github.com/pytorch-labs/helion/pull/245, branch: joydddd/stack/12
---
 examples/all_reduce.py          | 190 ++++++++++++++++++++++++++++++++
 helion/language/creation_ops.py |  21 +++-
 test/test_distributed.py        | 103 +++++++++++++++++
 3 files changed, 309 insertions(+), 5 deletions(-)
 create mode 100644 examples/all_reduce.py
 create mode 100644 test/test_distributed.py

diff --git a/examples/all_reduce.py b/examples/all_reduce.py
new file mode 100644
index 00000000..b6e315a7
--- /dev/null
+++ b/examples/all_reduce.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+import os
+
+import torch
+import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
+import triton
+import triton.language as tl
+
+import helion
+import helion.language as hl
+
+
+# Symmetric Memory Helpers
+@triton.jit
+def triton_copy(
+    inp: tl.int64,  # pyright: ignore[reportInvalidTypeForm]
+    out: tl.tensor,
+    SIZE: tl.constexpr,
+) -> None:
+    tl.static_assert(out.dtype.is_ptr())
+    inp = inp.to(tl.pointer_type(out.dtype.element_ty))  # pyright: ignore[reportAttributeAccessIssue]
+    addrs = tl.load(inp + tl.arange(0, SIZE))
+    tl.store(out + tl.arange(0, SIZE), addrs)
+
+
+def dev_array_to_tensor_short(
+    dev_array_ptr: int, shape: tuple[int], dtype: torch.dtype, device: torch.device
+) -> torch.Tensor:
+    tensor = torch.empty(shape, dtype=dtype, device=device)
+    triton_copy[1,](dev_array_ptr, tensor, tensor.numel())  # pyright: ignore[reportArgumentType]
+    return tensor
+
+
+@helion.jit(
+    config=helion.Config(
+        block_sizes=[8192],
+        num_warps=32,
+    ),
+)
+def one_shot_all_reduce_kernel_8(
+    signal_pad_addrs: torch.Tensor,
+    local_signal_pad: torch.Tensor,
+    a_shared_tuple: tuple[torch.Tensor, ...],
+    my_rank: hl.constexpr,
+) -> torch.Tensor:
+    _, world_size = local_signal_pad.size()
+    world_size = hl.specialize(world_size)
+    out = torch.empty_like(a_shared_tuple[0])
+    N = out.size(0)
+
+    for tile_n in hl.tile(N):
+        for multicast_tile in hl.tile(world_size, block_size=world_size):
+            # Offset the barrier pointers in bytes: 4 bytes per torch.int32 barrier.
+            peer_bar_offset = (tile_n.id * world_size + my_rank) * 4  # pyright: ignore[reportOperatorIssue]
+            hl.signal(
+                signal_pad_addrs[multicast_tile] + peer_bar_offset,
+                wait_for=0,
+                signal=1,
+                op="atomic_cas",
+                sem="relaxed",
+                scope="sys",
+                skip_sync=True,
+                as_ptrs=True,
+            )
+            hl.wait(
+                local_signal_pad,
+                [tile_n.id, multicast_tile],
+                signal=1,
+                update=0,
+                scope="sys",
+                op="atomic_cas",
+            )
+
+        acc = hl.zeros([tile_n], dtype=torch.float32, device=local_signal_pad.device)
+
+        # TODO(joydddd): support indexing into a tuple with iterator from tl.static_range
+        # For now, manually unroll the loop
+        acc += a_shared_tuple[0][tile_n]
+        acc += a_shared_tuple[1][tile_n]
+        acc += a_shared_tuple[2][tile_n]
+        acc += a_shared_tuple[3][tile_n]
+        acc += a_shared_tuple[4][tile_n]
+        acc += a_shared_tuple[5][tile_n]
+        acc += a_shared_tuple[6][tile_n]
+        acc += a_shared_tuple[7][tile_n]
+
+        out[tile_n] = acc
+
+        for multicast_tile in hl.tile(world_size, block_size=world_size):
+            peer_bar_offset = (tile_n.id * world_size + my_rank) * 4  # pyright: ignore[reportOperatorIssue]
+            hl.signal(
+                signal_pad_addrs[multicast_tile] + peer_bar_offset,
+                wait_for=0,
+                signal=1,
+                op="atomic_cas",
+                sem="relaxed",
+                scope="sys",
+                as_ptrs=True,
+            )
+            hl.wait(
+                local_signal_pad,
+                [tile_n.id, multicast_tile],
+                signal=1,
+                update=0,
+                scope="sys",
+                op="atomic_cas",
+                skip_sync=True,
+            )
+    return out
+
+
+def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
+    assert dist.group.WORLD is not None
+
+    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
+
+    a_shared_tuple = tuple(
+        [
+            symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
+            for i in range(symm_mem_hdl.world_size)
+        ]
+    )
+
+    local_signal_pad = symm_mem_hdl.get_signal_pad(
+        symm_mem_hdl.rank, dtype=torch.int32
+    ).view(-1, symm_mem_hdl.world_size)
+
+    signal_pad_addrs = dev_array_to_tensor_short(
+        symm_mem_hdl.signal_pad_ptrs_dev,
+        (symm_mem_hdl.world_size,),
+        dtype=torch.uint64,
+        device=a_shared.device,
+    )
+
+    assert symm_mem_hdl.world_size == 8
+
+    return one_shot_all_reduce_kernel_8(
+        signal_pad_addrs,
+        local_signal_pad,
+        a_shared_tuple,
+        my_rank=symm_mem_hdl.rank,
+    )
+
+
+def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
+    dist_group = dist.group.WORLD
+    assert dist_group is not None
+
+    world_size = dist.get_world_size()
+    a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()
+
+    a_shared_clone = symm_mem.empty(
+        a_shared.shape,
+        dtype=a_shared.dtype,
+        device=a_shared.device,
+    )
+    symm_mem.rendezvous(a_shared_clone, dist_group.group_name)
+    a_shared_clone.copy_(a_shared)
+
+    a_out = helion_one_shot_all_reduce(a_shared)
+
+    golden_o = torch.ops.symm_mem.one_shot_all_reduce(
+        a_shared_clone, "sum", dist_group.group_name
+    )
+
+    torch.testing.assert_close(a_out, golden_o, rtol=1e-1, atol=1e-1)
+
+
+def main() -> None:
+    rank = int(os.environ["LOCAL_RANK"])
+    torch.manual_seed(42 + rank)
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    dist.init_process_group("nccl")
+    test(16384, device, torch.bfloat16)
+
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    """
+    Run with:
+    torchrun \
+    --nnodes 1 --nproc-per-node 8 \
+    --rdzv-backend c10d --rdzv-endpoint localhost:0 \
+    --no_python python3 examples/all_reduce.py
+    """
+    main()
diff --git a/helion/language/creation_ops.py b/helion/language/creation_ops.py
index b22dc090..3352ca2b 100644
--- a/helion/language/creation_ops.py
+++ b/helion/language/creation_ops.py
@@ -18,7 +18,11 @@ __all__ = ["arange", "full", "zeros"]
-def zeros(shape: list[object], dtype: torch.dtype = torch.float32) -> torch.Tensor:
+def zeros(
+    shape: list[object],
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | None = None,
+) -> torch.Tensor:
     """
     Return a device-tensor filled with zeros.
@@ -54,12 +58,17 @@ def process_kernel(input: torch.Tensor) -> torch.Tensor:
     - :func:`~helion.language.full`: For filling with arbitrary values
     - :func:`~helion.language.arange`: For creating sequences
     """
-    return full(shape, 0.0 if dtype.is_floating_point else 0, dtype=dtype)
+    return full(
+        shape, 0.0 if dtype.is_floating_point else 0, dtype=dtype, device=device
+    )
 
 
 @_decorators.api(tiles_as_sizes=True)
 def full(
-    shape: list[object], value: float, dtype: torch.dtype = torch.float32
+    shape: list[object],
+    value: float,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device | None = None,
 ) -> torch.Tensor:
     """
     Create a device-tensor filled with a specified value.
@@ -103,6 +112,7 @@ def _full_fake(
     shape: list[int | torch.SymInt],
     value: float,
     dtype: torch.dtype = torch.float32,
+    device: torch.device | None = None,
 ) -> torch.Tensor:
     if not isinstance(shape, (list, tuple)):
         raise TypeError(f"Expected list[SymInt], got {type(shape).__name__}")
@@ -111,7 +121,7 @@
     return torch.empty(
         [*shape],
         dtype=dtype,
-        device=env.device,
+        device=env.device if device is None else device,
     )
@@ -147,6 +157,7 @@ def _(
 def arange(
     *args: int,
     dtype: torch.dtype | None = None,
+    device: torch.device | None = None,
     **kwargs: object,
 ) -> torch.Tensor:
     """
@@ -175,5 +186,5 @@ def arange(
         *args,
         **kwargs,
         dtype=dtype,
-        device=env.device,
+        device=env.device if device is None else device,
     )
diff --git a/test/test_distributed.py b/test/test_distributed.py
new file mode 100644
index 00000000..c5533776
--- /dev/null
+++ b/test/test_distributed.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+import torch
+import torch.distributed as dist
+import torch.distributed._symmetric_memory as symm_mem
+from torch.testing._internal.common_distributed import MultiProcessTestCase
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_utils import instantiate_parametrized_tests
+from torch.testing._internal.common_utils import run_tests
+
+import helion
+from helion._testing import code_and_output
+import helion.language as hl
+
+
+@helion.jit
+def symm_mem_sync_kernel(
+    remote_signal_pad_ptrs: torch.Tensor,
+    local_signal_pad: torch.Tensor,
+    rank: hl.constexpr,
+) -> None:
+    N, world_size = local_signal_pad.size()
+    world_size = hl.specialize(world_size)
+
+    assert world_size == remote_signal_pad_ptrs.size(0)
+    for n in hl.grid(N):
+        for multicast_tile in hl.tile(world_size, block_size=world_size):
+            peer_bar_offset = (n * world_size + rank) * 4
+            hl.signal(
+                remote_signal_pad_ptrs[multicast_tile] + peer_bar_offset,
+                wait_for=0,
+                signal=1,
+                op="atomic_cas",
+                sem="relaxed",
+                scope="sys",
+                skip_sync=True,
+                as_ptrs=True,
+            )
+            hl.wait(
+                local_signal_pad,
+                [n, multicast_tile],
+                signal=1,
+                update=0,
+                scope="sys",
+                op="atomic_cas",
+            )
+
+
+@instantiate_parametrized_tests
+class SymmMemBarrier(MultiProcessTestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    @property
+    def world_size(self) -> int:
+        # world_size > 2 is needed to verify accumulation order
+        return 4
+
+    @property
+    def device(self) -> torch.device:
+        return torch.device(f"cuda:{self.rank}")
+
+    def _init_process(self):
+        torch.cuda.set_device(self.device)
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            backend="nccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        torch.manual_seed(42 + self.rank)
+
+    @skip_if_lt_x_gpu(4)
+    def test_symm_mem_barrier(self):
+        self._init_process()
+        t = symm_mem.empty(4096, device=self.device)
+        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
+        local_signal_pad_t = symm_mem_hdl.get_signal_pad(
+            symm_mem_hdl.rank, (32, symm_mem_hdl.world_size), dtype=torch.int32
+        )
+        signal_pad_pointers_t = torch.as_tensor(
+            symm_mem_hdl.signal_pad_ptrs, dtype=torch.uint64
+        ).to(self.device)
+
+        code, result = code_and_output(
+            symm_mem_sync_kernel,
+            (
+                signal_pad_pointers_t,
+                local_signal_pad_t,
+                symm_mem_hdl.rank,
+            ),
+        )
+
+        signal_pad = symm_mem_hdl.get_signal_pad(symm_mem_hdl.rank)
+        assert signal_pad.eq(0).all().item()
+
+        dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    run_tests()
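
Usage sketch for the new `device=` argument on `hl.zeros` / `hl.full` / `hl.arange` (illustrative only; `scale_add_kernel` and its tensors below are hypothetical and not part of this patch). The all-reduce kernel above relies on this argument to place its accumulator via `device=local_signal_pad.device`; when `device=` is omitted, the allocation falls back to the compile environment's device (`env.device`), as the `_full_fake` change shows.

import torch

import helion
import helion.language as hl


@helion.jit
def scale_add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.size(0)
    for tile in hl.tile(n):
        # Accumulate in fp32 on an explicitly chosen device; without device=,
        # hl.zeros would allocate on the kernel's default (env.device).
        acc = hl.zeros([tile], dtype=torch.float32, device=x.device)
        acc += x[tile]
        acc += y[tile]
        out[tile] = acc
    return out

The keyword only overrides the default placement; it does not change the semantics of the created tensor, so existing kernels that omit it behave exactly as before.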