One shot all reduce & symm mem sync #245

Draft · wants to merge 1 commit into base: joydddd/stack/16
164 changes: 164 additions & 0 deletions examples/all_reduce.py
@@ -0,0 +1,164 @@
from __future__ import annotations

import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import helion
import helion.language as hl


@helion.jit(
config=helion.Config(
block_sizes=[4096],
num_warps=32,
),
)
def one_shot_all_reduce_kernel_8(
signal_pad_addrs: torch.Tensor,
local_signal_pad: torch.Tensor,
a_shared_tuple: tuple[torch.Tensor, ...],
my_rank: hl.constexpr,
) -> torch.Tensor:
_, world_size = local_signal_pad.size()
world_size = hl.specialize(world_size)
out = torch.empty_like(a_shared_tuple[0])
N = out.size(0)

for tile_n in hl.tile(N):
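        # Entry barrier: set this rank's slot for this tile on every peer's signal pad,
        # then wait until all peers have set their slots on the local pad before reading.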
for multicast_tile in hl.tile(world_size, block_size=world_size):
            # Offset the barrier pointers in bytes; each torch.int32 barrier slot is 4 bytes.
peer_bar_offset = (tile_n.id * world_size + my_rank) * 4 # pyright: ignore[reportOperatorIssue]
hl.signal(
signal_pad_addrs[multicast_tile] + peer_bar_offset,
wait_for=0,
signal=1,
op="atomic_cas",
sem="relaxed",
scope="sys",
skip_sync=True,
as_ptrs=True,
)
hl.wait(
local_signal_pad,
[tile_n.id, multicast_tile],
signal=1,
update=0,
scope="sys",
op="atomic_cas",
)

acc = hl.zeros([tile_n], dtype=torch.float32, device=local_signal_pad.device)

# TODO(joydddd): support indexing into a tuple with iterator from tl.static_range
# For now, manually unroll the loop
acc += a_shared_tuple[0][tile_n]
acc += a_shared_tuple[1][tile_n]
acc += a_shared_tuple[2][tile_n]
acc += a_shared_tuple[3][tile_n]
acc += a_shared_tuple[4][tile_n]
acc += a_shared_tuple[5][tile_n]
acc += a_shared_tuple[6][tile_n]
acc += a_shared_tuple[7][tile_n]

out[tile_n] = acc

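        # Exit barrier: signal completion for this tile to every peer and wait for all
        # peers, so no rank's shared buffer is reused while others may still be reading it.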
for multicast_tile in hl.tile(world_size, block_size=world_size):
peer_bar_offset = (tile_n.id * world_size + my_rank) * 4 # pyright: ignore[reportOperatorIssue]
hl.signal(
signal_pad_addrs[multicast_tile] + peer_bar_offset,
wait_for=0,
signal=1,
op="atomic_cas",
sem="relaxed",
scope="sys",
as_ptrs=True,
)
hl.wait(
local_signal_pad,
[tile_n.id, multicast_tile],
signal=1,
update=0,
scope="sys",
op="atomic_cas",
skip_sync=True,
)
return out


def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
assert dist.group.WORLD is not None

symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)

a_shared_tuple = tuple(
[
symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
for i in range(symm_mem_hdl.world_size)
]
)

local_signal_pad = symm_mem_hdl.get_signal_pad(
symm_mem_hdl.rank, dtype=torch.int32
).view(-1, symm_mem_hdl.world_size)

signal_pad_addrs = torch.as_tensor(
symm_mem_hdl.signal_pad_ptrs, dtype=torch.uint64
).to(a_shared.device)

assert symm_mem_hdl.world_size == 8

return one_shot_all_reduce_kernel_8(
signal_pad_addrs,
local_signal_pad,
a_shared_tuple,
my_rank=symm_mem_hdl.rank,
)


def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
dist_group = dist.group.WORLD
assert dist_group is not None

world_size = dist.get_world_size()
a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()

a_shared_clone = symm_mem.empty(
a_shared.shape,
dtype=a_shared.dtype,
device=a_shared.device,
)
symm_mem.rendezvous(a_shared_clone, dist_group.group_name)
a_shared_clone.copy_(a_shared)

a_out = helion_one_shot_all_reduce(a_shared)

    golden_o = torch.ops.symm_mem.one_shot_all_reduce(
a_shared_clone, "sum", dist_group.group_name
)

    torch.testing.assert_close(a_out, golden_o, rtol=1e-1, atol=1e-1)


def main() -> None:
rank = int(os.environ["LOCAL_RANK"])
torch.manual_seed(42 + rank)
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
dist.init_process_group("nccl")
test(16384, device, torch.bfloat16)

dist.destroy_process_group()


if __name__ == "__main__":
"""
Run with:
torchrun \
--nnodes 1 --nproc-per-node 8 \
--rdzv-backend c10d --rdzv-endpoint localhost:0 \
--no_python python3 examples/all_reduce.py
"""
main()
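
Ignoring the entry and exit barriers that hl.signal/hl.wait provide, the arithmetic above amounts to summing the per-rank symmetric-memory buffers in float32 and casting back to the buffer dtype. A minimal single-process sketch for reference, assuming a_shared_tuple is gathered as in helion_one_shot_all_reduce; the function name is illustrative only:

import torch


def one_shot_all_reduce_reference(a_shared_tuple: tuple[torch.Tensor, ...]) -> torch.Tensor:
    # Accumulate in float32, mirroring the kernel's fp32 accumulator.
    acc = torch.zeros(
        a_shared_tuple[0].shape, dtype=torch.float32, device=a_shared_tuple[0].device
    )
    for buf in a_shared_tuple:
        acc += buf
    # Cast back to the buffer dtype, as the store to `out` does in the kernel.
    return acc.to(a_shared_tuple[0].dtype)
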
21 changes: 16 additions & 5 deletions helion/language/creation_ops.py
@@ -18,7 +18,11 @@
__all__ = ["arange", "full", "zeros"]


def zeros(shape: list[object], dtype: torch.dtype = torch.float32) -> torch.Tensor:
def zeros(
shape: list[object],
dtype: torch.dtype = torch.float32,
device: torch.device | None = None,
) -> torch.Tensor:
"""
Return a device-tensor filled with zeros.

@@ -54,12 +58,17 @@ def process_kernel(input: torch.Tensor) -> torch.Tensor:
- :func:`~helion.language.full`: For filling with arbitrary values
- :func:`~helion.language.arange`: For creating sequences
"""
return full(shape, 0.0 if dtype.is_floating_point else 0, dtype=dtype)
return full(
shape, 0.0 if dtype.is_floating_point else 0, dtype=dtype, device=device
)


@_decorators.api(tiles_as_sizes=True)
def full(
shape: list[object], value: float, dtype: torch.dtype = torch.float32
shape: list[object],
value: float,
dtype: torch.dtype = torch.float32,
device: torch.device | None = None,
) -> torch.Tensor:
"""
Create a device-tensor filled with a specified value.
@@ -103,6 +112,7 @@ def _full_fake(
shape: list[int | torch.SymInt],
value: float,
dtype: torch.dtype = torch.float32,
device: torch.device | None = None,
) -> torch.Tensor:
if not isinstance(shape, (list, tuple)):
raise TypeError(f"Expected list[SymInt], got {type(shape).__name__}")
@@ -111,7 +121,7 @@ def _full_fake(
return torch.empty(
[*shape],
dtype=dtype,
device=env.device,
device=env.device if device is None else device,
)


@@ -147,6 +157,7 @@ def _(
def arange(
*args: int,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
**kwargs: object,
) -> torch.Tensor:
"""
@@ -175,5 +186,5 @@ def arange(
*args,
**kwargs,
dtype=dtype,
device=env.device,
device=env.device if device is None else device,
)
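
A small usage sketch of the new device= argument (not part of this diff), modeled on the accumulator in examples/all_reduce.py; double_kernel is a hypothetical name, and omitting device= keeps the previous behavior of allocating on the kernel environment's device:

import torch

import helion
import helion.language as hl


@helion.jit
def double_kernel(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.size(0)
    for tile_n in hl.tile(n):
        # Pin the accumulator to x's device explicitly; without device= it
        # falls back to env.device as before.
        acc = hl.zeros([tile_n], dtype=torch.float32, device=x.device)
        acc += x[tile_n]
        acc += x[tile_n]
        out[tile_n] = acc
    return out
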
103 changes: 103 additions & 0 deletions test/test_distributed.py
@@ -0,0 +1,103 @@
from __future__ import annotations

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.common_utils import run_tests

import helion
from helion._testing import code_and_output
import helion.language as hl


@helion.jit
def symm_mem_sync_kernel(
remote_signal_pad_ptrs: torch.Tensor,
local_signal_pad: torch.Tensor,
rank: hl.constexpr,
) -> None:
N, world_size = local_signal_pad.size()
world_size = hl.specialize(world_size)

assert world_size == remote_signal_pad_ptrs.size(0)
for n in hl.grid(N):
for multicast_tile in hl.tile(world_size, block_size=world_size):
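            # Offsets are in bytes; each barrier slot is one 4-byte torch.int32.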
peer_bar_offset = (n * world_size + rank) * 4
hl.signal(
remote_signal_pad_ptrs[multicast_tile] + peer_bar_offset,
wait_for=0,
signal=1,
op="atomic_cas",
sem="relaxed",
scope="sys",
skip_sync=True,
as_ptrs=True,
)
hl.wait(
local_signal_pad,
[n, multicast_tile],
signal=1,
update=0,
scope="sys",
op="atomic_cas",
)


@instantiate_parametrized_tests
class SymmMemBarrier(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()

@property
def world_size(self) -> int:
# world_size > 2 is needed to verify accumulation order
return 4

@property
def device(self) -> torch.device:
return torch.device(f"cuda:{self.rank}")

def _init_process(self):
torch.cuda.set_device(self.device)
store = dist.FileStore(self.file_name, self.world_size)
dist.init_process_group(
backend="nccl",
world_size=self.world_size,
rank=self.rank,
store=store,
)
torch.manual_seed(42 + self.rank)

@skip_if_lt_x_gpu(4)
def test_symm_mem_barrier(self):
self._init_process()
t = symm_mem.empty(4096, device=self.device)
symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
local_signal_pad_t = symm_mem_hdl.get_signal_pad(
symm_mem_hdl.rank, (32, symm_mem_hdl.world_size), dtype=torch.int32
)
signal_pad_pointers_t = torch.as_tensor(
symm_mem_hdl.signal_pad_ptrs, dtype=torch.uint64
).to(self.device)

code, result = code_and_output(
symm_mem_sync_kernel,
(
signal_pad_pointers_t,
local_signal_pad_t,
symm_mem_hdl.rank,
),
)

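        # After the barrier, every slot should have been consumed back to 0.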
signal_pad = symm_mem_hdl.get_signal_pad(symm_mem_hdl.rank)
assert signal_pad.eq(0).all().item()

dist.destroy_process_group()


if __name__ == "__main__":
run_tests()
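
The offset arithmetic in the kernel treats the signal pad as an (N, world_size) grid of int32 slots: for grid index n, rank r sets slot [n, r] on every peer's pad and then waits until all world_size slots of its own row n are set. A plain-Python sketch of the byte offsets, with example values only (no GPU required):

world_size = 4  # matches the test's world size
int32_bytes = 4

for rank in range(world_size):
    for n in range(2):  # first two grid indices, for illustration
        peer_bar_offset = (n * world_size + rank) * int32_bytes
        print(f"rank {rank}, n={n}: sets byte offset {peer_bar_offset} on every peer's pad")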