Commit 2c79dd9

One shot all reduce & symm mem sync
stack-info: PR: #245, branch: joydddd/stack/12
1 parent dbb9579 commit 2c79dd9

2 files changed: 75 additions & 1 deletion

helion/language/signal_wait.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def _(
             f"Invalid memory semantic '{sem}'. Must be one of {valid_sems}."
         )

-    if op == "atomic_cas" and not update:
+    if op == "atomic_cas" and update is None:
         raise ValueError(
             f"{op} without an update value. Do you want to use 'ld' instead? "
         )
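
This guard appears to live in the wait-op argument validation. The change matters because update=0 is a legitimate compare-and-swap value but is falsy in Python, so the old `not update` check rejected it; `update is None` only rejects a truly missing value. A minimal standalone sketch of the pitfall (the check() helper below is hypothetical and only mirrors the shape of the validation; it is not helion API):

from __future__ import annotations

def check(op: str, update: int | None) -> str:
    # Mirrors the fixed validation: only a *missing* update is an error.
    if op == "atomic_cas" and update is None:
        raise ValueError(f"{op} without an update value. Do you want to use 'ld' instead? ")
    return "ok"

assert check("atomic_cas", update=0) == "ok"  # 0 is a valid update; `not update` would have rejected it
try:
    check("atomic_cas", update=None)          # a genuinely missing update still raises
except ValueError:
    pass

The test added below exercises exactly this case: it calls hl.wait(..., update=0, op="atomic_cas").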

test/test_distributed.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
from __future__ import annotations

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.common_utils import run_tests

import helion
import helion.language as hl

from helion._testing import code_and_output


@helion.jit
def symm_mem_sync_kernel(
    remote_signal_pad_pointer: torch.Tensor,  # shape[world_size]
    local_signal_pad: torch.Tensor,
    rank: int,
) -> None:
    N, world_size = local_signal_pad.size()
    for n in hl.grid(N):
        for tile in hl.tile(world_size, block_size=world_size):
            peer_bars = remote_signal_pad_pointer[tile] + n * world_size + rank
            hl.signal(peer_bars, [tile], signal=1, scope="sys", skip_sync=True)
            hl.wait(local_signal_pad, [n, tile], signal=1, update=0, scope="sys", op="atomic_cas")


@instantiate_parametrized_tests
class SymmMemBarrier(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        # world_size > 2 is needed to verify accumulation order
        return 4

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process(self):
        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch.manual_seed(42 + self.rank)

    @skip_if_lt_x_gpu(4)
    def test_symm_mem_barrier(self):
        self._init_process()
        t = symm_mem.empty(4096, device=self.device)
        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
        local_signal_pad_t = symm_mem_hdl.get_buffer(symm_mem_hdl.rank, (32, symm_mem_hdl.world_size), dtype=torch.int32)
        signal_pad_pointers_t = torch.as_tensor(symm_mem_hdl.signal_pad_ptrs, dtype=torch.uint64).to(self.device)

        code, result = code_and_output(symm_mem_sync_kernel, (signal_pad_pointers_t, local_signal_pad_t, symm_mem_hdl.rank))

        signal_pad = symm_mem_hdl.get_signal_pad(symm_mem_hdl.rank)
        assert signal_pad.eq(0).all().item()

        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()
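
For context, the synchronization the Helion kernel implements is the device-side counterpart of the barrier the symmetric-memory handle already exposes in eager mode: each rank sets a slot in every peer's signal pad, then waits on (and clears, via atomic_cas with update=0) its own slots. A minimal eager-mode sketch for comparison, assuming a NCCL process group is already initialized on every rank and that the installed PyTorch build exposes the handle's barrier() method:

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

# Assumes dist.init_process_group("nccl", ...) has already run and
# torch.cuda.set_device(...) points at this rank's GPU.
t = symm_mem.empty(4096, device=f"cuda:{torch.cuda.current_device()}")
hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
hdl.barrier(channel=0)  # assumed available on this build; every rank blocks until all peers have signaled

Because the test class derives from MultiProcessTestCase and spawns its own workers in setUp(), it can be run directly (python test/test_distributed.py); skip_if_lt_x_gpu(4) skips it on hosts with fewer than four GPUs.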
