Skip to content

Commit eaf93f8

Browse files
committed
One shot all reduce & symm mem sync
stack-info: PR: #245, branch: joydddd/stack/12
1 parent 0a2f058 commit eaf93f8

File tree

3 files changed

+309
-5
lines changed

3 files changed

+309
-5
lines changed

examples/all_reduce.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
from __future__ import annotations
2+
3+
import os
4+
5+
import torch
6+
import torch.distributed as dist
7+
import torch.distributed._symmetric_memory as symm_mem
8+
import triton
9+
import triton.language as tl
10+
11+
import helion
12+
import helion.language as hl
13+
14+
15+
# Symmetric Memory Helpers
16+
@triton.jit
def triton_copy(
    inp: tl.int64,  # pyright: ignore[reportInvalidTypeForm]
    out: tl.tensor,
    SIZE: tl.constexpr,
) -> None:
    """Copy ``SIZE`` elements from a raw device address into ``out``.

    ``inp`` arrives as a 64-bit integer and is reinterpreted as a pointer to
    ``out``'s element type; ``out`` must itself be a pointer (checked at
    compile time). ``SIZE`` is used as a ``tl.arange`` extent, which Triton
    requires to be a power of two.
    """
    tl.static_assert(out.dtype.is_ptr())
    # Reinterpret the integer address as a typed pointer matching the output.
    src = inp.to(tl.pointer_type(out.dtype.element_ty))  # pyright: ignore[reportAttributeAccessIssue]
    offsets = tl.arange(0, SIZE)
    values = tl.load(src + offsets)
    tl.store(out + offsets, values)
26+
27+
28+
def dev_array_to_tensor_short(
    dev_array_ptr: int, shape: tuple[int], dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    """Materialize a short device-side array as a ``torch.Tensor``.

    Args:
        dev_array_ptr: Integer address of the array, handed to ``triton_copy``.
        shape: Shape of the tensor to create.
        dtype: Element type of the tensor to create.
        device: Device on which the destination tensor is allocated.

    Returns:
        A freshly allocated tensor that ``triton_copy`` has filled with
        ``numel()`` elements read from ``dev_array_ptr``.
    """
    result = torch.empty(shape, dtype=dtype, device=device)
    # A single program instance copies the whole array in one shot.
    triton_copy[(1,)](dev_array_ptr, result, result.numel())  # pyright: ignore[reportArgumentType]
    return result
34+
35+
36+
@helion.jit(
    config=helion.Config(
        block_sizes=[8192],
        num_warps=32,
    ),
)
def one_shot_all_reduce_kernel_8(
    signal_pad_addrs: torch.Tensor,
    local_signal_pad: torch.Tensor,
    a_shared_tuple: tuple[torch.Tensor, ...],
    my_rank: hl.constexpr,
) -> torch.Tensor:
    # One-shot sum all-reduce over exactly eight buffers.
    #
    # signal_pad_addrs: one address per peer; a byte offset is added to each
    #     and the result is passed to hl.signal with as_ptrs=True.
    # local_signal_pad: this rank's 2-D signal pad; its second dimension is
    #     specialized and used as the world size.
    # a_shared_tuple:   input buffers; entries 0 through 7 are read.
    # my_rank:          this rank's index, used to pick the signalled slot.
    #
    # Returns a tensor shaped like a_shared_tuple[0] holding the element-wise
    # sum, accumulated in float32 before being stored.
    #
    # NOTE(review): hl.signal / hl.wait are defined outside this file. Their
    # arguments suggest a compare-and-swap 0 -> 1 on every peer's pad followed
    # by a 1 -> 0 wait on the local pad -- confirm against the Helion docs.
    _, peer_count = local_signal_pad.size()
    world_size = hl.specialize(peer_count)
    result = torch.empty_like(a_shared_tuple[0])
    numel = result.size(0)

    for tile_n in hl.tile(numel):
        # Barrier before the reduction reads any buffer.
        for peer_tile in hl.tile(world_size, block_size=world_size):
            # offset the barrier pointers in bytes. 4 bytes per torch.int32 barrier.
            slot_byte_offset = (tile_n.id * world_size + my_rank) * 4  # pyright: ignore[reportOperatorIssue]
            hl.signal(
                signal_pad_addrs[peer_tile] + slot_byte_offset,
                wait_for=0,
                signal=1,
                op="atomic_cas",
                sem="relaxed",
                scope="sys",
                skip_sync=True,
                as_ptrs=True,
            )
            hl.wait(
                local_signal_pad,
                [tile_n.id, peer_tile],
                signal=1,
                update=0,
                scope="sys",
                op="atomic_cas",
            )

        total = hl.zeros([tile_n], dtype=torch.float32, device=local_signal_pad.device)

        # TODO(joydddd): support indexing into a tuple with iterator from tl.static_range
        # For now, manually unroll the loop
        total += a_shared_tuple[0][tile_n]
        total += a_shared_tuple[1][tile_n]
        total += a_shared_tuple[2][tile_n]
        total += a_shared_tuple[3][tile_n]
        total += a_shared_tuple[4][tile_n]
        total += a_shared_tuple[5][tile_n]
        total += a_shared_tuple[6][tile_n]
        total += a_shared_tuple[7][tile_n]

        result[tile_n] = total

        # Barrier after the reduction. Compared with the first barrier, the
        # skip_sync flag moves from the signal call to the wait call.
        for peer_tile in hl.tile(world_size, block_size=world_size):
            slot_byte_offset = (tile_n.id * world_size + my_rank) * 4  # pyright: ignore[reportOperatorIssue]
            hl.signal(
                signal_pad_addrs[peer_tile] + slot_byte_offset,
                wait_for=0,
                signal=1,
                op="atomic_cas",
                sem="relaxed",
                scope="sys",
                as_ptrs=True,
            )
            hl.wait(
                local_signal_pad,
                [tile_n.id, peer_tile],
                signal=1,
                update=0,
                scope="sys",
                op="atomic_cas",
                skip_sync=True,
            )
    return result
112+
113+
114+
def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
    """Sum ``a_shared`` across the default process group with the Helion kernel.

    Args:
        a_shared: This rank's input buffer. It is passed to
            ``symm_mem.rendezvous``, so it is expected to be a symmetric-memory
            allocation (the caller in this file creates it with
            ``symm_mem.empty``).

    Returns:
        The tensor produced by ``one_shot_all_reduce_kernel_8``.

    Raises:
        AssertionError: If there is no default process group, or if the world
            size is not 8.
    """
    assert dist.group.WORLD is not None

    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
    world_size = symm_mem_hdl.world_size

    # The kernel manually unrolls its sum over eight peer buffers, so reject
    # any other world size up front -- before building peer views or launching
    # the pointer-copy kernel.
    assert world_size == 8, (
        f"one_shot_all_reduce_kernel_8 requires world_size == 8, got {world_size}"
    )

    # One view per rank onto that rank's copy of the buffer.
    a_shared_tuple = tuple(
        symm_mem_hdl.get_buffer(peer, tuple(a_shared.shape), a_shared.dtype)
        for peer in range(world_size)
    )

    # View the local signal pad as rows of `world_size` int32 slots.
    local_signal_pad = symm_mem_hdl.get_signal_pad(
        symm_mem_hdl.rank, dtype=torch.int32
    ).view(-1, world_size)

    # Copy the device-side array of per-rank signal-pad addresses into a tensor
    # the kernel can index.
    signal_pad_addrs = dev_array_to_tensor_short(
        symm_mem_hdl.signal_pad_ptrs_dev,
        (world_size,),
        dtype=torch.uint64,
        device=a_shared.device,
    )

    return one_shot_all_reduce_kernel_8(
        signal_pad_addrs,
        local_signal_pad,
        a_shared_tuple,
        my_rank=symm_mem_hdl.rank,
    )
145+
146+
147+
def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
    """Compare the Helion all-reduce against ``torch.ops.symm_mem.one_shot_all_reduce``.

    Args:
        N: Total element count; each rank allocates ``N // world_size`` elements.
        device: Device on which the symmetric-memory buffers are allocated.
        dtype: Element type of the buffers.

    Raises:
        AssertionError: If there is no default process group, or if the two
            results differ beyond ``rtol=1e-1, atol=1e-1``.
    """
    group = dist.group.WORLD
    assert group is not None

    per_rank = N // dist.get_world_size()
    a_shared = symm_mem.empty(per_rank, dtype=dtype, device=device).normal_()

    # The reference op gets its own symmetric-memory copy of the same data.
    reference_input = symm_mem.empty(
        a_shared.shape,
        dtype=a_shared.dtype,
        device=a_shared.device,
    )
    symm_mem.rendezvous(reference_input, group.group_name)
    reference_input.copy_(a_shared)

    actual = helion_one_shot_all_reduce(a_shared)
    expected = torch.ops.symm_mem.one_shot_all_reduce(
        reference_input, "sum", group.group_name
    )

    torch.testing.assert_close(actual, expected, rtol=1e-1, atol=1e-1)
169+
170+
171+
def main() -> None:
    """Set up this rank's device and NCCL process group, then run ``test``.

    The rank is read from the ``LOCAL_RANK`` environment variable and selects
    both the CUDA device and the RNG seed offset.

    Raises:
        KeyError: If ``LOCAL_RANK`` is not set in the environment.
    """
    rank = int(os.environ["LOCAL_RANK"])
    torch.manual_seed(42 + rank)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    try:
        test(16384, device, torch.bfloat16)
    finally:
        # Tear the process group down even when the comparison fails, so a
        # failing rank does not exit with the group still initialized.
        dist.destroy_process_group()
180+
181+
182+
# Run with:
#   torchrun \
#     --nnodes 1 --nproc-per-node 8 \
#     --rdzv-backend c10d --rdzv-endpoint localhost:0 \
#     --no_python python3 examples/all_reduce.py
if __name__ == "__main__":
    main()

helion/language/creation_ops.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818
__all__ = ["arange", "full", "zeros"]
1919

2020

21-
def zeros(shape: list[object], dtype: torch.dtype = torch.float32) -> torch.Tensor:
21+
def zeros(
22+
shape: list[object],
23+
dtype: torch.dtype = torch.float32,
24+
device: torch.device | None = None,
25+
) -> torch.Tensor:
2226
"""
2327
Return a device-tensor filled with zeros.
2428
@@ -54,12 +58,17 @@ def process_kernel(input: torch.Tensor) -> torch.Tensor:
5458
- :func:`~helion.language.full`: For filling with arbitrary values
5559
- :func:`~helion.language.arange`: For creating sequences
5660
"""
57-
return full(shape, 0.0 if dtype.is_floating_point else 0, dtype=dtype)
61+
return full(
62+
shape, 0.0 if dtype.is_floating_point else 0, dtype=dtype, device=device
63+
)
5864

5965

6066
@_decorators.api(tiles_as_sizes=True)
6167
def full(
62-
shape: list[object], value: float, dtype: torch.dtype = torch.float32
68+
shape: list[object],
69+
value: float,
70+
dtype: torch.dtype = torch.float32,
71+
device: torch.device | None = None,
6372
) -> torch.Tensor:
6473
"""
6574
Create a device-tensor filled with a specified value.
@@ -103,6 +112,7 @@ def _full_fake(
103112
shape: list[int | torch.SymInt],
104113
value: float,
105114
dtype: torch.dtype = torch.float32,
115+
device: torch.device | None = None,
106116
) -> torch.Tensor:
107117
if not isinstance(shape, (list, tuple)):
108118
raise TypeError(f"Expected list[SymInt], got {type(shape).__name__}")
@@ -111,7 +121,7 @@ def _full_fake(
111121
return torch.empty(
112122
[*shape],
113123
dtype=dtype,
114-
device=env.device,
124+
device=env.device if device is None else device,
115125
)
116126

117127

@@ -147,6 +157,7 @@ def _(
147157
def arange(
148158
*args: int,
149159
dtype: torch.dtype | None = None,
160+
device: torch.device | None = None,
150161
**kwargs: object,
151162
) -> torch.Tensor:
152163
"""
@@ -175,5 +186,5 @@ def arange(
175186
*args,
176187
**kwargs,
177188
dtype=dtype,
178-
device=env.device,
189+
device=env.device if device is None else device,
179190
)

test/test_distributed.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from __future__ import annotations
2+
3+
import torch
4+
import torch.distributed as dist
5+
import torch.distributed._symmetric_memory as symm_mem
6+
from torch.testing._internal.common_distributed import MultiProcessTestCase
7+
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
8+
from torch.testing._internal.common_utils import instantiate_parametrized_tests
9+
from torch.testing._internal.common_utils import run_tests
10+
11+
import helion
12+
from helion._testing import code_and_output
13+
import helion.language as hl
14+
15+
16+
@helion.jit
def symm_mem_sync_kernel(
    remote_signal_pad_ptrs: torch.Tensor,
    local_signal_pad: torch.Tensor,
    rank: hl.constexpr,
) -> None:
    # Signal-then-wait exchange over every row of the signal pad.
    #
    # remote_signal_pad_ptrs: one address per peer; its length must equal the
    #     pad's second dimension (asserted below).
    # local_signal_pad:       this rank's 2-D signal pad.
    # rank:                   this rank's index, selects the signalled slot.
    #
    # NOTE(review): hl.signal / hl.wait are defined outside this file. Their
    # arguments suggest a compare-and-swap 0 -> 1 on every peer's pad followed
    # by a 1 -> 0 wait on the local pad -- confirm against the Helion docs.
    num_rows, peer_count = local_signal_pad.size()
    world_size = hl.specialize(peer_count)

    assert world_size == remote_signal_pad_ptrs.size(0)
    for row in hl.grid(num_rows):
        for peer_tile in hl.tile(world_size, block_size=world_size):
            # Byte offset of slot (row, rank); the factor 4 matches the
            # 4-byte int32 slots the test allocates for the pad.
            slot_byte_offset = (row * world_size + rank) * 4
            hl.signal(
                remote_signal_pad_ptrs[peer_tile] + slot_byte_offset,
                wait_for=0,
                signal=1,
                op="atomic_cas",
                sem="relaxed",
                scope="sys",
                skip_sync=True,
                as_ptrs=True,
            )
            hl.wait(
                local_signal_pad,
                [row, peer_tile],
                signal=1,
                update=0,
                scope="sys",
                op="atomic_cas",
            )
47+
48+
49+
@instantiate_parametrized_tests
class SymmMemBarrier(MultiProcessTestCase):
    """Multi-process test that runs ``symm_mem_sync_kernel`` on four ranks."""

    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        # world_size > 2 is needed to verify accumulation order
        # NOTE(review): this test performs no accumulation; the comment may
        # have been carried over from another test -- confirm.
        return 4

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process(self) -> None:
        """Bind this rank's CUDA device and join the NCCL process group."""
        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch.manual_seed(42 + self.rank)

    @skip_if_lt_x_gpu(4)
    def test_symm_mem_barrier(self) -> None:
        """Run the sync kernel and check the local signal pad reads all zero."""
        # Initialize outside the try block: if this fails there is no process
        # group to destroy, and the cleanup call would mask the real error.
        self._init_process()
        try:
            t = symm_mem.empty(4096, device=self.device)
            symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
            local_signal_pad_t = symm_mem_hdl.get_signal_pad(
                symm_mem_hdl.rank, (32, symm_mem_hdl.world_size), dtype=torch.int32
            )
            signal_pad_pointers_t = torch.as_tensor(
                symm_mem_hdl.signal_pad_ptrs, dtype=torch.uint64
            ).to(self.device)

            # Only the kernel's effect on the signal pad is checked; the
            # generated code and return value are not inspected.
            code_and_output(
                symm_mem_sync_kernel,
                (
                    signal_pad_pointers_t,
                    local_signal_pad_t,
                    symm_mem_hdl.rank,
                ),
            )

            signal_pad = symm_mem_hdl.get_signal_pad(symm_mem_hdl.rank)
            self.assertTrue(
                signal_pad.eq(0).all().item(),
                "signal pad contains non-zero entries after the sync kernel",
            )
        finally:
            # Tear down the process group even when the body above raises.
            dist.destroy_process_group()
100+
101+
102+
# Script entry point: hand control to the PyTorch test runner.
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)