|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import os |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.distributed as dist |
| 7 | +import torch.distributed._symmetric_memory as symm_mem |
| 8 | + |
| 9 | +import helion |
| 10 | +import helion.language as hl |
| 11 | + |
| 12 | + |
| 13 | +@helion.jit( |
| 14 | + config=helion.Config( |
| 15 | + block_sizes=[4096], |
| 16 | + num_warps=32, |
| 17 | + ), |
| 18 | + static_shapes=True, |
| 19 | +) |
| 20 | +def one_shot_all_reduce_kernel_8( |
| 21 | + signal_pad_addrs: torch.Tensor, |
| 22 | + local_signal_pad: torch.Tensor, |
| 23 | + a_shared_tuple: tuple[torch.Tensor, ...], |
| 24 | + my_rank: hl.constexpr, |
| 25 | +): |
| 26 | + _, world_size = local_signal_pad.size() |
| 27 | + world_size = hl.specialize(world_size) |
| 28 | + out = torch.empty_like(a_shared_tuple[0]) |
| 29 | + N = out.size(0) |
| 30 | + |
| 31 | + for tile_n in hl.tile(N): |
| 32 | + peer_bar_offset = ( |
| 33 | + tile_n.id * world_size + my_rank |
| 34 | + ) * 4 # offset the barrier pointers in btyes. 4 bytes per torch.int32 barrier. |
| 35 | + for multicast_tile in hl.tile(world_size, block_size=world_size): |
| 36 | + hl.signal( |
| 37 | + signal_pad_addrs[multicast_tile] + peer_bar_offset, |
| 38 | + wait_for=0, |
| 39 | + signal=1, |
| 40 | + op="atomic_cas", |
| 41 | + sem="relaxed", |
| 42 | + scope="sys", |
| 43 | + skip_sync=True, |
| 44 | + as_ptrs=True, |
| 45 | + ) |
| 46 | + hl.wait( |
| 47 | + local_signal_pad, |
| 48 | + [tile_n.id, multicast_tile], |
| 49 | + signal=1, |
| 50 | + update=0, |
| 51 | + scope="sys", |
| 52 | + op="atomic_cas", |
| 53 | + ) |
| 54 | + |
| 55 | + acc = torch.zeros([tile_n], dtype=torch.float32, device=out.device) |
| 56 | + |
| 57 | + # TODO(joydddd): support indexing into a tuple with iterator from tl.static_range |
| 58 | + # For now, manually unroll the loop |
| 59 | + acc += a_shared_tuple[0][tile_n] |
| 60 | + acc += a_shared_tuple[1][tile_n] |
| 61 | + acc += a_shared_tuple[2][tile_n] |
| 62 | + acc += a_shared_tuple[3][tile_n] |
| 63 | + acc += a_shared_tuple[4][tile_n] |
| 64 | + acc += a_shared_tuple[5][tile_n] |
| 65 | + acc += a_shared_tuple[6][tile_n] |
| 66 | + acc += a_shared_tuple[7][tile_n] |
| 67 | + |
| 68 | + out[tile_n] = acc |
| 69 | + |
| 70 | + for multicast_tile in hl.tile(world_size, block_size=world_size): |
| 71 | + hl.signal( |
| 72 | + signal_pad_addrs[multicast_tile] + peer_bar_offset, |
| 73 | + wait_for=0, |
| 74 | + signal=1, |
| 75 | + op="atomic_cas", |
| 76 | + sem="relaxed", |
| 77 | + scope="sys", |
| 78 | + as_ptrs=True, |
| 79 | + ) |
| 80 | + hl.wait( |
| 81 | + local_signal_pad, |
| 82 | + [tile_n.id, multicast_tile], |
| 83 | + signal=1, |
| 84 | + update=0, |
| 85 | + scope="sys", |
| 86 | + op="atomic_cas", |
| 87 | + skip_sync=True, |
| 88 | + ) |
| 89 | + return out |
| 90 | + |
| 91 | + |
| 92 | +def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor: |
| 93 | + symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD) |
| 94 | + |
| 95 | + a_shared_tuple = tuple( |
| 96 | + [ |
| 97 | + symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype) |
| 98 | + for i in range(symm_mem_hdl.world_size) |
| 99 | + ] |
| 100 | + ) |
| 101 | + |
| 102 | + local_signal_pad = symm_mem_hdl.get_signal_pad( |
| 103 | + symm_mem_hdl.rank, dtype=torch.int32 |
| 104 | + ).view(-1, symm_mem_hdl.world_size) |
| 105 | + |
| 106 | + signal_pad_addrs = torch.as_tensor( |
| 107 | + symm_mem_hdl.signal_pad_ptrs, dtype=torch.uint64 |
| 108 | + ).to(a_shared.device) |
| 109 | + |
| 110 | + return one_shot_all_reduce_kernel_8( |
| 111 | + signal_pad_addrs, |
| 112 | + local_signal_pad, |
| 113 | + a_shared_tuple, |
| 114 | + my_rank=symm_mem_hdl.rank, |
| 115 | + ) |
| 116 | + |
| 117 | + |
| 118 | +def test(N: int, device: torch.device, dtype=torch.bfloat16) -> None: |
| 119 | + world_size = dist.get_world_size() |
| 120 | + a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_() |
| 121 | + |
| 122 | + golden_a = a_shared.clone() |
| 123 | + a_out = helion_one_shot_all_reduce(a_shared) |
| 124 | + |
| 125 | + dist_group = dist.group.WORLD |
| 126 | + if dist_group is None: |
| 127 | + raise RuntimeError("No distributed group available") |
| 128 | + |
| 129 | + print(a_out) |
| 130 | + gloden_o = torch.ops.symm_mem.one_shot_all_reduce( |
| 131 | + golden_a, "sum", dist_group.group_name |
| 132 | + ) |
| 133 | + |
| 134 | + torch.testing.assert_close(a_out, gloden_o, rtol=1e-1, atol=1e-1) |
| 135 | + |
| 136 | + |
| 137 | +def main() -> None: |
| 138 | + rank = int(os.environ["LOCAL_RANK"]) |
| 139 | + torch.manual_seed(42 + rank) |
| 140 | + device = torch.device(f"cuda:{rank}") |
| 141 | + torch.cuda.set_device(device) |
| 142 | + dist.init_process_group("nccl") |
| 143 | + test(16384, device) |
| 144 | + |
| 145 | + dist.destroy_process_group() |
| 146 | + |
| 147 | + |
| 148 | +if __name__ == "__main__": |
| 149 | + """ |
| 150 | + Run with: |
| 151 | + torchrun \ |
| 152 | + --nnodes 1 --nproc-per-node 8 \ |
| 153 | + --rdzv-backend c10d --rdzv-endpoint localhost:0 \ |
| 154 | + --no_python python3 examples/all_reduce.py |
| 155 | + """ |
| 156 | + main() |
0 commit comments