Skip to content

Commit 2f52133

Browse files
committed
One shot all reduce & symm mem sync
stack-info: PR: #245, branch: joydddd/stack/12
1 parent 93dc903 commit 2f52133

File tree

3 files changed

+195
-0
lines changed

3 files changed

+195
-0
lines changed

examples/all_reduce.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from __future__ import annotations
2+
3+
import torch
4+
import torch.distributed as dist
5+
import torch.distributed._symmetric_memory as symm_mem
6+
7+
8+
import helion
9+
10+
11+
@helion.jit(
12+
config=helion.Config(
13+
block_sizes=[24],
14+
num_warps=32,
15+
indexing="pointers",
16+
),
17+
static_shapes=True,
18+
)
19+
def one_shot_all_reduce_kernel(
20+
buffer_ptr_addrs,
21+
signal_pad_ptrs,
22+
output_ptr,
23+
numel: tl.constexpr,
24+
rank: tl.constexpr,
25+
world_size: tl.constexpr,
26+
BLOCK_SIZE: tl.constexpr,
27+
):
28+
output = torch.empty_like(x)
29+
ptx_utils.symm_mem_sync(
30+
signal_pad_ptrs, None, rank, world_size, hasSubsequenceMemAccess=True
31+
)
32+
33+
pid = tl.program_id(axis=0)
34+
buffer_ptr_addrs = buffer_ptr_addrs.to(tl.pointer_type(tl.uint64))
35+
output_ptr = output_ptr.to(tl.pointer_type(tl.bfloat16))
36+
block_start = pid * BLOCK_SIZE
37+
38+
while block_start < numel:
39+
# Each thread processes 128 bits.
40+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
41+
mask = offsets < numel
42+
43+
acc = tl.zeros((BLOCK_SIZE,), dtype=tl.bfloat16)
44+
for i in range(world_size):
45+
buffer_ptr = tl.load(buffer_ptr_addrs + i).to(tl.pointer_type(tl.bfloat16))
46+
tl.multiple_of(buffer_ptr, 16)
47+
x = tl.load(buffer_ptr + offsets, mask=mask)
48+
acc += x
49+
tl.store(output_ptr + offsets, acc, mask=mask)
50+
block_start += tl.num_programs(axis=0) * BLOCK_SIZE
51+
52+
ptx_utils.symm_mem_sync(
53+
signal_pad_ptrs, None, rank, world_size, hasPreviousMemAccess=True
54+
)
55+
56+
57+
def one_shot_all_reduce(tensor: torch.Tensor, **kwargs) -> torch.Tensor:
58+
config = {
59+
"max_num_blocks": kwargs.get("max_num_blocks", 24),
60+
"num_warps": kwargs.get("num_warps", 32),
61+
"BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 8192),
62+
}
63+
64+
assert tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now."
65+
assert tensor.numel() % 8 == 0, "The number of elements must be 128-bit aligned."
66+
assert config["BLOCK_SIZE"] % (config["num_warps"] * 32) == 0, (
67+
"BLOCK_SIZE must be a multiple of num_warps * 32"
68+
)
69+
70+
num_blocks = min(
71+
triton.cdiv(tensor.numel(), config["BLOCK_SIZE"]), config["max_num_blocks"]
72+
)
73+
74+
symm_mem_hdl = symm_mem.rendezvous(tensor, group=dist.group.WORLD)
75+
output = torch.empty_like(tensor)
76+
77+
signal_pads = tuple(
78+
[symm_mem_hdl.get_signal_pad(i, dtype=torch.int32) for i in range(symm_mem_hdl.world_size)]
79+
)
80+
81+
one_shot_all_reduce_kernel[(num_blocks, 1, 1)](
82+
symm_mem_hdl.buffer_ptrs_dev,
83+
signal_pads,
84+
output,
85+
numel=tensor.numel(),
86+
rank=symm_mem_hdl.rank,
87+
world_size=symm_mem_hdl.world_size,
88+
BLOCK_SIZE=config["BLOCK_SIZE"],
89+
num_warps=config["num_warps"],
90+
)
91+
92+
return output

helion/_compiler/type_propagation.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,6 +1210,18 @@ def populate_symbol_origins(self, origin: Origin) -> None:
12101210
for i, subtype in enumerate(self.element_types):
12111211
subtype.populate_symbol_origins(GetItemOrigin(origin, i))
12121212

1213+
def propagate_getitem(self, key: TypeInfo, origin: Origin) -> TypeInfo:
1214+
# Check if all elements have the same type
1215+
first_type = self.element_types[0]
1216+
for element_type in self.element_types[1:]:
1217+
if type(element_type) != type(first_type):
1218+
raise exc.TypeInferenceError(
1219+
f"Sequence contains mixed types: cannot safely index. "
1220+
f"Found {type(first_type).__name__} and {type(element_type).__name__}"
1221+
)
1222+
1223+
return first_type
1224+
12131225
def merge(self, other: TypeInfo) -> TypeInfo:
12141226
if isinstance(other, SequenceType):
12151227
self_elements = self.element_types

test/test_distributed.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from __future__ import annotations
2+
3+
import torch
4+
import torch.distributed as dist
5+
import torch.distributed._symmetric_memory as symm_mem
6+
from torch.testing._internal.common_distributed import MultiProcessTestCase
7+
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
8+
from torch.testing._internal.common_utils import instantiate_parametrized_tests
9+
from torch.testing._internal.common_utils import run_tests
10+
11+
import helion
12+
from helion._testing import code_and_output
13+
import helion.language as hl
14+
15+
16+
@helion.jit
17+
def symm_mem_sync_kernel(
18+
remote_signal_pad_pointer: torch.Tensor, # shape[world_size]
19+
local_signal_pad: torch.Tensor,
20+
rank: int,
21+
) -> None:
22+
N, world_size = local_signal_pad.size()
23+
for n in hl.grid(N):
24+
for tile in hl.tile(world_size, block_size=world_size):
25+
peer_bars = remote_signal_pad_pointer[tile] + n * world_size + rank
26+
hl.signal(peer_bars, [tile], signal=1, scope="sys", skip_sync=True)
27+
hl.wait(
28+
local_signal_pad,
29+
[n, tile],
30+
signal=1,
31+
update=0,
32+
scope="sys",
33+
op="atomic_cas",
34+
)
35+
36+
37+
@instantiate_parametrized_tests
38+
class SymmMemBarrier(MultiProcessTestCase):
39+
def setUp(self) -> None:
40+
super().setUp()
41+
self._spawn_processes()
42+
43+
@property
44+
def world_size(self) -> int:
45+
# world_size > 2 is needed to verify accumulation order
46+
return 4
47+
48+
@property
49+
def device(self) -> torch.device:
50+
return torch.device(f"cuda:{self.rank}")
51+
52+
def _init_process(self):
53+
torch.cuda.set_device(self.device)
54+
store = dist.FileStore(self.file_name, self.world_size)
55+
dist.init_process_group(
56+
backend="nccl",
57+
world_size=self.world_size,
58+
rank=self.rank,
59+
store=store,
60+
)
61+
torch.manual_seed(42 + self.rank)
62+
63+
@skip_if_lt_x_gpu(4)
64+
def test_symm_mem_barrier(self):
65+
self._init_process()
66+
t = symm_mem.empty(4096, device=self.device)
67+
symm_mem_hdl = symm_mem.rendezvous(t, group=dist.group.WORLD)
68+
local_signal_pad_t = symm_mem_hdl.get_buffer(
69+
symm_mem_hdl.rank, (32, symm_mem_hdl.world_size), dtype=torch.int32
70+
)
71+
signa_pad_pointers_t = torch.as_tensor(
72+
symm_mem_hdl.signal_pad_ptrs, dtype=torch.uint64
73+
).to(self.device)
74+
75+
code, result = code_and_output(
76+
symm_mem_sync_kernel,
77+
(
78+
signa_pad_pointers_t,
79+
local_signal_pad_t,
80+
symm_mem_hdl.rank,
81+
),
82+
)
83+
84+
signal_pad = symm_mem_hdl.get_signal_pad(symm_mem_hdl.rank)
85+
assert signal_pad.eq(0).all().item()
86+
87+
dist.destroy_process_group()
88+
89+
90+
if __name__ == "__main__":
91+
run_tests()

0 commit comments

Comments
 (0)