Commit 7af5d2b

One shot all reduce & symm mem sync
stack-info: PR: #245, branch: joydddd/stack/12
Parent: 9f46f56

File tree

3 files changed: +209 -0 lines changed

examples/all_reduce.py
helion/_compiler/type_propagation.py
test/test_distributed.py

examples/all_reduce.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
from __future__ import annotations

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import triton
import triton.language as tl

import helion

# NOTE: ptx_utils (which provides symm_mem_sync) is assumed to be importable in
# this example; its definition is not part of this diff.


@helion.jit(
    config=helion.Config(
        block_sizes=[24],
        num_warps=32,
        indexing="pointers",
    ),
    static_shapes=True,
)
def one_shot_all_reduce_kernel(
    buffer_ptr_addrs,
    signal_pad_ptrs,
    output_ptr,
    numel: tl.constexpr,
    rank: tl.constexpr,
    world_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Entry barrier: every rank's input buffer must be ready before any peer reads it.
    ptx_utils.symm_mem_sync(
        signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True
    )

    pid = tl.program_id(axis=0)
    buffer_ptr_addrs = buffer_ptr_addrs.to(tl.pointer_type(tl.uint64))
    output_ptr = output_ptr.to(tl.pointer_type(tl.bfloat16))
    block_start = pid * BLOCK_SIZE

    while block_start < numel:
        # Each thread processes 128 bits.
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < numel

        acc = tl.zeros((BLOCK_SIZE,), dtype=tl.bfloat16)
        for i in range(world_size):
            buffer_ptr = tl.load(buffer_ptr_addrs + i).to(tl.pointer_type(tl.bfloat16))
            tl.multiple_of(buffer_ptr, 16)
            x = tl.load(buffer_ptr + offsets, mask=mask)
            acc += x
        tl.store(output_ptr + offsets, acc, mask=mask)
        block_start += tl.num_programs(axis=0) * BLOCK_SIZE

    # Exit barrier: no rank may reuse its buffer while peers might still be reading it.
    ptx_utils.symm_mem_sync(
        signal_pad_ptrs, None, rank, world_size, hasPreviousMemAccess=True
    )


def one_shot_all_reduce(tensor: torch.Tensor, **kwargs) -> torch.Tensor:
    config = {
        "max_num_blocks": kwargs.get("max_num_blocks", 24),
        "num_warps": kwargs.get("num_warps", 32),
        "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 8192),
    }

    assert tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now."
    assert tensor.numel() % 8 == 0, "The number of elements must be 128-bit aligned."
    assert config["BLOCK_SIZE"] % (config["num_warps"] * 32) == 0, (
        "BLOCK_SIZE must be a multiple of num_warps * 32"
    )

    num_blocks = min(
        triton.cdiv(tensor.numel(), config["BLOCK_SIZE"]), config["max_num_blocks"]
    )

    symm_mem_hdl = symm_mem.rendezvous(tensor, group=dist.group.WORLD)
    output = torch.empty_like(tensor)

    signal_pads = tuple(
        symm_mem_hdl.get_signal_pad(i, dtype=torch.int32)
        for i in range(symm_mem_hdl.world_size)
    )

    one_shot_all_reduce_kernel[(num_blocks, 1, 1)](
        symm_mem_hdl.buffer_ptrs_dev,
        signal_pads,
        output,
        numel=tensor.numel(),
        rank=symm_mem_hdl.rank,
        world_size=symm_mem_hdl.world_size,
        BLOCK_SIZE=config["BLOCK_SIZE"],
        num_warps=config["num_warps"],
    )

    return output
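
For orientation, here is a minimal sketch of how this example could be driven from a multi-process launcher. The driver below is not part of the commit: the script name, tensor size, and the torchrun invocation in the comment are illustrative assumptions; only one_shot_all_reduce, symm_mem.empty, and the standard torch.distributed calls come from the code above or from PyTorch itself.

# Hypothetical driver (not part of this commit). Launch with, e.g.:
#   torchrun --nproc-per-node=4 all_reduce_driver.py
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

from examples.all_reduce import one_shot_all_reduce


def main() -> None:
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # The input must be allocated in symmetric memory so peers can read it directly.
    t = symm_mem.empty(16384, dtype=torch.bfloat16, device=device)
    t.fill_(rank + 1.0)

    out = one_shot_all_reduce(t)

    # Every rank should now hold 1 + 2 + ... + world_size in every element.
    expected = torch.full_like(out, float(sum(range(1, world_size + 1))))
    torch.testing.assert_close(out, expected)

    dist.destroy_process_group()


if __name__ == "__main__":
    main()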

helion/_compiler/type_propagation.py

Lines changed: 12 additions & 0 deletions
@@ -1203,6 +1203,18 @@ def populate_symbol_origins(self, origin: Origin) -> None:
        for i, subtype in enumerate(self.element_types):
            subtype.populate_symbol_origins(GetItemOrigin(origin, i))

    def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
        # Check if all elements have the same type
        first_type = self.element_types[0]
        for element_type in self.element_types[1:]:
            if type(element_type) != type(first_type):
                raise exc.TypeInferenceError(
                    f"Sequence contains mixed types: cannot safely index. "
                    f"Found {type(first_type).__name__} and {type(element_type).__name__}"
                )
        return first_type

    def merge(self, other: TypeInfo) -> TypeInfo:
        if isinstance(other, SequenceType):
            self_elements = self.element_types
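
The new propagate_getitem makes indexing into a sequence well-typed only when every element shares the same type, so the result type cannot depend on the runtime index. A standalone sketch of that behavior, using illustrative stand-ins (SequenceTypeSketch, TypeInferenceError, TensorLike, IntLike) rather than Helion's actual SequenceType, TypeInfo, or exc module:

# Illustrative stand-ins, not Helion internals.
from __future__ import annotations


class TypeInferenceError(TypeError):
    pass


class TensorLike:  # stands in for a tensor TypeInfo
    pass


class IntLike:  # stands in for an int TypeInfo
    pass


class SequenceTypeSketch:
    def __init__(self, element_types: list[object]) -> None:
        self.element_types = element_types

    def propagate_getitem(self) -> object:
        # Indexing is only safe when all elements have the same type.
        first_type = self.element_types[0]
        for element_type in self.element_types[1:]:
            if type(element_type) is not type(first_type):
                raise TypeInferenceError(
                    f"Sequence contains mixed types: cannot safely index. "
                    f"Found {type(first_type).__name__} and {type(element_type).__name__}"
                )
        return first_type


homogeneous = SequenceTypeSketch([TensorLike(), TensorLike()])
element = homogeneous.propagate_getitem()  # the common element type is propagated

mixed = SequenceTypeSketch([TensorLike(), IntLike()])
try:
    mixed.propagate_getitem()
except TypeInferenceError as e:
    print(e)  # Sequence contains mixed types: cannot safely index. Found TensorLike and IntLike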

test/test_distributed.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
from __future__ import annotations

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import instantiate_parametrized_tests
from torch.testing._internal.common_utils import run_tests

import helion
from helion._testing import code_and_output
import helion.language as hl


@helion.jit
def symm_mem_sync_kernel(
    remote_signal_pad_ptrs: torch.Tensor,
    local_signal_pad: torch.Tensor,
    rank: hl.constexpr,
) -> None:
    N, world_size = local_signal_pad.size()
    world_size = hl.specialize(world_size)

    assert world_size == remote_signal_pad_ptrs.size(0)
    for n in hl.grid(N):
        for multicast_tile in hl.tile(world_size, block_size=world_size):
            # Signal this rank's slot on every peer's signal pad...
            peer_bar_offset = (n * world_size + rank) * 4
            hl.signal(
                remote_signal_pad_ptrs[multicast_tile] + peer_bar_offset,
                wait_for=0,
                signal=1,
                op="atomic_cas",
                sem="relaxed",
                scope="sys",
                skip_sync=True,
                as_ptrs=True,
            )
            # ...then wait for every peer's signal on the local pad and
            # consume it (1 -> 0) so the pad is clean for reuse.
            hl.wait(
                local_signal_pad,
                [n, multicast_tile],
                signal=1,
                update=0,
                scope="sys",
                op="atomic_cas",
            )


@instantiate_parametrized_tests
class SymmMemBarrier(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        # world_size > 2 is needed to verify accumulation order
        return 4

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process(self):
        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch.manual_seed(42 + self.rank)

    @skip_if_lt_x_gpu(4)
    def test_symm_mem_barrier(self):
        self._init_process()
        t = symm_mem.empty(4096, device=self.device)
        symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
        local_signal_pad_t = symm_mem_hdl.get_signal_pad(
            symm_mem_hdl.rank, (32, symm_mem_hdl.world_size), dtype=torch.int32
        )
        signal_pad_pointers_t = torch.as_tensor(
            symm_mem_hdl.signal_pad_ptrs, dtype=torch.uint64
        ).to(self.device)

        code, result = code_and_output(
            symm_mem_sync_kernel,
            (
                signal_pad_pointers_t,
                local_signal_pad_t,
                symm_mem_hdl.rank,
            ),
        )

        # After the barrier, every signal-pad slot must be consumed back to zero.
        signal_pad = symm_mem_hdl.get_signal_pad(symm_mem_hdl.rank)
        assert signal_pad.eq(0).all().item()

        dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()
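
As a conceptual aid, the handshake symm_mem_sync_kernel performs (write a signal into each peer's pad with a compare-and-swap, then wait until every peer's signal has landed on the local pad and consume it) can be mimicked on the CPU with threads. The sketch below is an analogy only: the thread- and lock-based atomic_cas stands in for the GPU-side atomics, and none of these names come from the commit.

# CPU analogy of the symmetric-memory barrier; an illustration, not the GPU code.
from __future__ import annotations

import threading
import time

WORLD_SIZE = 4

# signal_pads[dst][src] plays the role of rank dst's local signal-pad slot for peer src.
signal_pads = [[0] * WORLD_SIZE for _ in range(WORLD_SIZE)]
pad_locks = [[threading.Lock() for _ in range(WORLD_SIZE)] for _ in range(WORLD_SIZE)]


def atomic_cas(dst: int, src: int, expect: int, value: int) -> bool:
    """Compare-and-swap on one signal-pad slot, emulated with a lock."""
    with pad_locks[dst][src]:
        if signal_pads[dst][src] == expect:
            signal_pads[dst][src] = value
            return True
        return False


def barrier(rank: int) -> None:
    # Phase 1 (like hl.signal): flip our slot on every peer's pad from 0 to 1.
    for peer in range(WORLD_SIZE):
        while not atomic_cas(peer, rank, expect=0, value=1):
            time.sleep(0)  # spin until the peer has consumed the previous signal
    # Phase 2 (like hl.wait): wait for each peer's signal on our own pad and consume it (1 -> 0).
    for peer in range(WORLD_SIZE):
        while not atomic_cas(rank, peer, expect=1, value=0):
            time.sleep(0)


threads = [threading.Thread(target=barrier, args=(r,)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Matches the test's final check: after the barrier, every pad slot is back to zero.
assert all(v == 0 for row in signal_pads for v in row)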
