Commit 20c357c

One-shot all-reduce & symm mem sync
stack-info: PR: #245, branch: joydddd/stack/12
1 parent 8f3f053

File tree

2 files changed (+259 -0 lines changed)
examples/all_reduce.py

Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
from __future__ import annotations

import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import helion
import helion.language as hl


@helion.jit(
    config=helion.Config(
        block_sizes=[4096],
        num_warps=32,
    ),
    static_shapes=True,
)
def one_shot_all_reduce_kernel_8(
    signal_pad_addrs: torch.Tensor,
    local_signal_pad: torch.Tensor,
    a_shared_tuple: tuple[torch.Tensor, ...],
    my_rank: hl.constexpr,
):
    _, world_size = local_signal_pad.size()
    world_size = hl.specialize(world_size)
    out = torch.empty_like(a_shared_tuple[0])
    N = out.size(0)

    for tile_n in hl.tile(N):
        peer_bar_offset = (
            tile_n.id * world_size + my_rank
        ) * 4  # offset the barrier pointers in bytes; 4 bytes per torch.int32 barrier.
        # Entry barrier: set a flag on every peer's signal pad, then wait for
        # (and clear) the matching flags on the local pad, so no buffer is
        # read before every rank has arrived.
        for multicast_tile in hl.tile(world_size, block_size=world_size):
            hl.signal(
                signal_pad_addrs[multicast_tile] + peer_bar_offset,
                wait_for=0,
                signal=1,
                op="atomic_cas",
                sem="relaxed",
                scope="sys",
                skip_sync=True,
                as_ptrs=True,
            )
            hl.wait(
                local_signal_pad,
                [tile_n.id, multicast_tile],
                signal=1,
                update=0,
                scope="sys",
                op="atomic_cas",
            )

        acc = torch.zeros([tile_n], dtype=torch.float32, device=out.device)

        # TODO(joydddd): support indexing into a tuple with iterator from tl.static_range
        # For now, manually unroll the loop
        acc += a_shared_tuple[0][tile_n]
        acc += a_shared_tuple[1][tile_n]
        acc += a_shared_tuple[2][tile_n]
        acc += a_shared_tuple[3][tile_n]
        acc += a_shared_tuple[4][tile_n]
        acc += a_shared_tuple[5][tile_n]
        acc += a_shared_tuple[6][tile_n]
        acc += a_shared_tuple[7][tile_n]

        out[tile_n] = acc

        # Exit barrier: make sure every rank has finished reading before any
        # rank leaves, so the shared buffers and signal pads can be reused.
        for multicast_tile in hl.tile(world_size, block_size=world_size):
            hl.signal(
                signal_pad_addrs[multicast_tile] + peer_bar_offset,
                wait_for=0,
                signal=1,
                op="atomic_cas",
                sem="relaxed",
                scope="sys",
                as_ptrs=True,
            )
            hl.wait(
                local_signal_pad,
                [tile_n.id, multicast_tile],
                signal=1,
                update=0,
                scope="sys",
                op="atomic_cas",
                skip_sync=True,
            )
    return out


def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)

    # One view per rank into the symmetric-memory buffer.
    a_shared_tuple = tuple(
        [
            symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
            for i in range(symm_mem_hdl.world_size)
        ]
    )

    local_signal_pad = symm_mem_hdl.get_signal_pad(
        symm_mem_hdl.rank, dtype=torch.int32
    ).view(-1, symm_mem_hdl.world_size)

    # Device-side array holding each rank's signal-pad base pointer.
    signal_pad_addrs = torch.as_tensor(
        symm_mem_hdl.signal_pad_ptrs, dtype=torch.uint64
    ).to(a_shared.device)

    return one_shot_all_reduce_kernel_8(
        signal_pad_addrs,
        local_signal_pad,
        a_shared_tuple,
        my_rank=symm_mem_hdl.rank,
    )


def test(N: int, device: torch.device, dtype=torch.bfloat16) -> None:
    world_size = dist.get_world_size()
    a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()

    golden_a = a_shared.clone()
    a_out = helion_one_shot_all_reduce(a_shared)

    dist_group = dist.group.WORLD
    if dist_group is None:
        raise RuntimeError("No distributed group available")

    print(a_out)
    golden_o = torch.ops.symm_mem.one_shot_all_reduce(
        golden_a, "sum", dist_group.group_name
    )

    torch.testing.assert_close(a_out, golden_o, rtol=1e-1, atol=1e-1)


def main() -> None:
    rank = int(os.environ["LOCAL_RANK"])
    torch.manual_seed(42 + rank)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    test(16384, device)

    dist.destroy_process_group()


if __name__ == "__main__":
    """
    Run with:
    torchrun \
    --nnodes 1 --nproc-per-node 8 \
    --rdzv-backend c10d --rdzv-endpoint localhost:0 \
    --no_python python3 examples/all_reduce.py
    """
    main()
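Not part of this commit, but a useful sanity check: the one-shot all-reduce should agree with a plain NCCL all-reduce of each rank's local shard. A minimal sketch under that assumption (the reference_check helper below is hypothetical), meant to run under the same torchrun launch as above:

def reference_check(a_shared: torch.Tensor) -> None:
    # Hypothetical cross-check: sum the same shard with a regular NCCL
    # all-reduce and compare against the symmetric-memory kernel's output.
    expected = a_shared.clone()
    dist.all_reduce(expected, op=dist.ReduceOp.SUM)
    actual = helion_one_shot_all_reduce(a_shared)
    torch.testing.assert_close(actual, expected, rtol=1e-1, atol=1e-1)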

test/test_distributed.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
from __future__ import annotations

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.common_utils import run_tests

import helion
from helion._testing import code_and_output
import helion.language as hl


@helion.jit
def symm_mem_sync_kernel(
    remote_signal_pad_ptrs: torch.Tensor,
    local_signal_pad: torch.Tensor,
    rank: hl.constexpr,
) -> None:
    N, world_size = local_signal_pad.size()
    world_size = hl.specialize(world_size)

    assert world_size == remote_signal_pad_ptrs.size(0)
    for n in hl.grid(N):
        for multicast_tile in hl.tile(world_size, block_size=world_size):
            peer_bar_offset = (n * world_size + rank) * 4
            hl.signal(
                remote_signal_pad_ptrs[multicast_tile] + peer_bar_offset,
                wait_for=0,
                signal=1,
                op="atomic_cas",
                sem="relaxed",
                scope="sys",
                skip_sync=True,
                as_ptrs=True,
            )
            hl.wait(
                local_signal_pad,
                [n, multicast_tile],
                signal=1,
                update=0,
                scope="sys",
                op="atomic_cas",
            )
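The kernel above is the same multicast signal/wait barrier used in examples/all_reduce.py, isolated for testing: each rank CASes a flag from 0 to 1 on every peer's signal pad, then waits on its own pad and CASes each flag back from 1 to 0, so a completed barrier leaves the local pad all zeros, which is what the test class below asserts. A minimal single-process model of that per-flag lifecycle (illustration only; the helper name is hypothetical and this is not the generated device code):

def flag_lifecycle(world_size: int = 4) -> list[int]:
    # One row of a rank's signal pad: one int32-style flag per peer.
    flags = [0] * world_size
    for peer in range(world_size):
        # hl.signal(..., wait_for=0, signal=1, op="atomic_cas"):
        # the peer flips its slot on this rank's pad from 0 to 1.
        assert flags[peer] == 0
        flags[peer] = 1
    for peer in range(world_size):
        # hl.wait(..., signal=1, update=0, op="atomic_cas"):
        # this rank consumes the flag, flipping it back to 0.
        assert flags[peer] == 1
        flags[peer] = 0
    return flags  # all zeros again, matching signal_pad.eq(0) in the test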
@instantiate_parametrized_tests
class SymmMemBarrier(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        # world_size > 2 is needed to verify accumulation order
        return 4

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process(self):
        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch.manual_seed(42 + self.rank)

    @skip_if_lt_x_gpu(4)
    def test_symm_mem_barrier(self):
        self._init_process()
        t = symm_mem.empty(4096, device=self.device)
        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
        local_signal_pad_t = symm_mem_hdl.get_signal_pad(
            symm_mem_hdl.rank, (32, symm_mem_hdl.world_size), dtype=torch.int32
        )
        signal_pad_pointers_t = torch.as_tensor(
            symm_mem_hdl.signal_pad_ptrs, dtype=torch.uint64
        ).to(self.device)

        code, result = code_and_output(
            symm_mem_sync_kernel,
            (
                signal_pad_pointers_t,
                local_signal_pad_t,
                symm_mem_hdl.rank,
            ),
        )

        signal_pad = symm_mem_hdl.get_signal_pad(symm_mem_hdl.rank)
        assert signal_pad.eq(0).all().item()

        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()
