From 451a2a0379e43184b39a86b54f7d21543eeb54d0 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Thu, 13 Feb 2025 01:25:55 -0800 Subject: [PATCH 1/4] [cp] set up load balancing testbed [ghstack-poisoned] --- attn_gym/load_balance/__init__.py | 3 + attn_gym/load_balance/load_balancer.py | 16 +++ examples/distributed_benchmark.py | 191 +++++++++++++++++++++++++ 3 files changed, 210 insertions(+) create mode 100644 attn_gym/load_balance/__init__.py create mode 100644 attn_gym/load_balance/load_balancer.py create mode 100644 examples/distributed_benchmark.py diff --git a/attn_gym/load_balance/__init__.py b/attn_gym/load_balance/__init__.py new file mode 100644 index 0000000..6b353af --- /dev/null +++ b/attn_gym/load_balance/__init__.py @@ -0,0 +1,3 @@ +from attn_gym.load_balance.load_balancer import load_balance_algo + +__all__ = ["load_balance_algo"] diff --git a/attn_gym/load_balance/load_balancer.py b/attn_gym/load_balance/load_balancer.py new file mode 100644 index 0000000..c552269 --- /dev/null +++ b/attn_gym/load_balance/load_balancer.py @@ -0,0 +1,16 @@ +from typing import List + + +__all__ = ["load_balance_algo"] + + +def load_balance_algo(S: int, size: int, block_size: int) -> List[List[int]]: + assert S % (size * block_size) == 0 + num_local_blk = S // (size * block_size) + return [ + [ + local_blk_idx + rank * num_local_blk + for local_blk_idx in range(num_local_blk) + ] + for rank in range(size) + ] diff --git a/examples/distributed_benchmark.py b/examples/distributed_benchmark.py new file mode 100644 index 0000000..24470e4 --- /dev/null +++ b/examples/distributed_benchmark.py @@ -0,0 +1,191 @@ +from functools import lru_cache +from typing import Optional, List + +import os +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor import distribute_tensor, DTensor, DeviceMesh, Partial, Replicate, Shard + + +from torch.nn.attention.flex_attention import ( + _DEFAULT_SPARSE_BLOCK_SIZE, + create_block_mask, + flex_attention, + _mask_mod_signature, +) + +from attn_gym.masks.document_mask import length_to_offsets +from attn_gym.masks import ( + causal_mask, + generate_doc_mask_mod, +) +from attn_gym.load_balance import load_balance_algo + + +def get_device_type() -> str: + return "cuda" + + +@lru_cache +def create_block_mask_cached(score_mod, B, H, M, N, device="cuda"): + block_mask = create_block_mask(score_mod, B, H, M, N, device=device) + return block_mask + + +# TODO: re-write it into a wrapper??? 
+def rewrite_mask_mod_for_cp( + mask_mod: _mask_mod_signature, + rank: int, + block_size: int, + load_balancer_output: List[List[int]], +) -> _mask_mod_signature: + def local_q_idx_to_q_idx(local_q_idx) -> int: + # calculate local block_idx and block_offset + local_blk_idx, local_blk_offset = ( + local_q_idx // block_size, local_q_idx % block_size + ) + current_rank_blk_list = load_balancer_output[rank] + blk_idx = current_rank_blk_list[local_blk_idx] + return blk_idx * block_size + local_blk_offset + + return lambda b, h, q_idx, kv_idx: mask_mod( + b, h, local_q_idx_to_q_idx(q_idx), kv_idx + ) + + +def run_document_masking(device_mesh, max_seq_len, num_docs): + # initialize the document lengths + import random + + random.seed(0) + torch.cuda.manual_seed(0) + + def generate_random_lengths(total_length, num_documents): + # Initialize all lengths to 1 to ensure each document has at least one token + lengths = [1] * num_documents + remaining_length = total_length - num_documents + + # Randomly distribute the remaining length + for _ in range(remaining_length): + index = random.randint(0, num_documents - 1) + lengths[index] += 1 + + return lengths + + lengths = generate_random_lengths(max_seq_len, num_docs) + offsets = length_to_offsets(lengths, torch.device(f'cuda:{torch.cuda.current_device():d}')) # TODO: replace with a device mesh call + document_causal_mask = generate_doc_mask_mod(causal_mask, offsets) + test_mask_with_load_balance(device_mesh, mask_mod=document_causal_mask, S=max_seq_len) + + +def test_mask_with_load_balance( + device_mesh: DeviceMesh, + mask_mod: Optional[_mask_mod_signature] = None, + B: int = 16, + H: int = 16, + S: int = 8192, + D: int = 64, + skip_correctness: bool = False, + print_mask: bool = True, + device: str = "cuda", +): + data_type = torch.float16 + + # create block mask + block_mask = create_block_mask_cached(mask_mod, 1, 1, S, S, device=device) + block_size = _DEFAULT_SPARSE_BLOCK_SIZE # TODO: get block size from block mask + + # input initialization + qkv = [ + torch.rand( + (B, H, S, D), + device=device_mesh.device_type, + dtype=data_type, + requires_grad=True, + ) + for _ in range(3) + ] + + # TODO: input sharding with load-balancing + # sparsity_info = get_sparsity_info_from_block_mask(block_mask) + # load_balancer_output = load_balance_algo(sparsity_info) + cp_mesh_size = device_mesh.size() + load_balancer_output = load_balance_algo(S, cp_mesh_size, block_size) + + seq_dim = 2 + qkv_dist = [ + distribute_tensor( + t.detach().clone().requires_grad_(), device_mesh, [ + Shard(seq_dim) if i == 0 else Replicate() + ] + ) + for (i, t) in enumerate(qkv) + ] + + q_local, k_full, v_full = (dt.to_local() for dt in qkv_dist) + + # rewrite `block_mask` + mask_mod: _mask_mod_signature = block_mask.mask_mod + cp_rank = device_mesh.get_local_rank() + cp_mask_mod = rewrite_mask_mod_for_cp( + mask_mod, cp_rank, block_size, load_balancer_output + ) + cp_block_mask = create_block_mask_cached( + cp_mask_mod, B=1, H=1, M=S // cp_mesh_size, N=S, device=device + ) + + # Compile the flex_attention function + compiled_flex_attention = torch.compile(flex_attention, dynamic=False) + + # TODO: this doesn't address the return_lse=True case + cp_out = compiled_flex_attention( + q_local, + k_full, + v_full, + score_mod=None, + block_mask=cp_block_mask, + ) + assert isinstance(cp_out, torch.Tensor) + + # unshard + cp_out_dist = DTensor.from_local(cp_out, device_mesh, [Shard(seq_dim)]) + full_cp_out_dist = cp_out_dist.full_tensor() + blk_idx_to_origin = [idx for idx_list in 
load_balancer_output for idx in idx_list] + blk_list_rearranged = [None] * len(blk_idx_to_origin) + blk_list = torch.chunk(full_cp_out_dist, dim=seq_dim) + assert len(blk_idx_to_origin) == len(blk_list) + for blk, blk_idx_origin in zip(blk_list, blk_idx_to_origin): + blk_list_rearranged[blk_idx_origin] = blk + + full_cp_out_dist = torch.cat(blk_list_rearranged, dim=seq_dim) + + + + +def load_balancing_example(world_size: int, rank: int) -> None: + device_type = get_device_type() + device_handle = getattr(torch, device_type, None) + assert device_handle is not None, f"Unsupported device type: {device_type}" + num_devices_per_host = device_handle.device_count() + device_handle.set_device(rank % num_devices_per_host) + torch._dynamo.config.cache_size_limit = 1000 + + # init device mesh + device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,)) + + run_document_masking(device_mesh, max_seq_len=8192, num_docs=12) + + +if __name__ == "__main__": + # this script is launched via torchrun which automatically manages ProcessGroup + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + # assert world_size == 4 # our example uses 4 worker ranks + + try: + load_balancing_example(world_size, rank) + finally: + dist.barrier() + dist.destroy_process_group() From dc1a1472a349f916d10cd6d61c7b405500cd82c2 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Thu, 13 Feb 2025 22:45:14 -0800 Subject: [PATCH 2/4] Update on "[cp] set up load balancing testbed" [ghstack-poisoned] --- attn_gym/load_balance/load_balancer.py | 17 +++++++---------- examples/distributed_benchmark.py | 24 ++++++++++++++---------- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/attn_gym/load_balance/load_balancer.py b/attn_gym/load_balance/load_balancer.py index c552269..8702e7f 100644 --- a/attn_gym/load_balance/load_balancer.py +++ b/attn_gym/load_balance/load_balancer.py @@ -1,16 +1,13 @@ from typing import List +import torch + __all__ = ["load_balance_algo"] -def load_balance_algo(S: int, size: int, block_size: int) -> List[List[int]]: - assert S % (size * block_size) == 0 - num_local_blk = S // (size * block_size) - return [ - [ - local_blk_idx + rank * num_local_blk - for local_blk_idx in range(num_local_blk) - ] - for rank in range(size) - ] +def load_balance_algo(S: int, size: int, block_size: int) -> torch.Tensor: + total_num_blk = S // block_size + assert S % (size * total_num_blk) == 0 + local_num_blk = total_num_blk // size + return torch.arange(total_num_blk, device="cuda").view(size, local_num_blk) diff --git a/examples/distributed_benchmark.py b/examples/distributed_benchmark.py index 24470e4..65a407b 100644 --- a/examples/distributed_benchmark.py +++ b/examples/distributed_benchmark.py @@ -32,14 +32,14 @@ def get_device_type() -> str: def create_block_mask_cached(score_mod, B, H, M, N, device="cuda"): block_mask = create_block_mask(score_mod, B, H, M, N, device=device) return block_mask - + # TODO: re-write it into a wrapper??? 
def rewrite_mask_mod_for_cp( mask_mod: _mask_mod_signature, rank: int, block_size: int, - load_balancer_output: List[List[int]], + load_balancer_output: torch.Tensor, ) -> _mask_mod_signature: def local_q_idx_to_q_idx(local_q_idx) -> int: # calculate local block_idx and block_offset @@ -152,16 +152,20 @@ def test_mask_with_load_balance( # unshard cp_out_dist = DTensor.from_local(cp_out, device_mesh, [Shard(seq_dim)]) full_cp_out_dist = cp_out_dist.full_tensor() - blk_idx_to_origin = [idx for idx_list in load_balancer_output for idx in idx_list] - blk_list_rearranged = [None] * len(blk_idx_to_origin) - blk_list = torch.chunk(full_cp_out_dist, dim=seq_dim) - assert len(blk_idx_to_origin) == len(blk_list) - for blk, blk_idx_origin in zip(blk_list, blk_idx_to_origin): - blk_list_rearranged[blk_idx_origin] = blk + # rearrange + blk_idx_to_origin = load_balancer_output.view(-1) + num_chunks = blk_idx_to_origin.numel() + blk_list_rearranged = [None] * num_chunks + blk_list = torch.chunk(full_cp_out_dist, num_chunks, dim=seq_dim) + assert len(blk_list) == num_chunks + for blk_idx, blk in enumerate(blk_list): + blk_list_rearranged[blk_idx_to_origin[blk_idx].item()] = blk full_cp_out_dist = torch.cat(blk_list_rearranged, dim=seq_dim) - + # local flex attention + expect_out = flex_attention(*qkv, block_mask=block_mask) + torch.testing.assert_close(full_cp_out_dist, expect_out, atol=1e-1, rtol=1e-2) def load_balancing_example(world_size: int, rank: int) -> None: @@ -175,7 +179,7 @@ def load_balancing_example(world_size: int, rank: int) -> None: # init device mesh device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,)) - run_document_masking(device_mesh, max_seq_len=8192, num_docs=12) + run_document_masking(device_mesh, max_seq_len=4096, num_docs=12) if __name__ == "__main__": From ef00e3eea3f7a3c769fb3141453d6aa3e03bf366 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Thu, 13 Feb 2025 23:08:30 -0800 Subject: [PATCH 3/4] Update on "[cp] set up load balancing testbed" [ghstack-poisoned] --- examples/distributed_benchmark.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/distributed_benchmark.py b/examples/distributed_benchmark.py index 65a407b..19741ff 100644 --- a/examples/distributed_benchmark.py +++ b/examples/distributed_benchmark.py @@ -1,12 +1,11 @@ from functools import lru_cache -from typing import Optional, List +from typing import Optional import os import torch import torch.distributed as dist -import torch.nn.functional as F from torch.distributed.device_mesh import init_device_mesh -from torch.distributed.tensor import distribute_tensor, DTensor, DeviceMesh, Partial, Replicate, Shard +from torch.distributed.tensor import distribute_tensor, DTensor, DeviceMesh, Replicate, Shard from torch.nn.attention.flex_attention import ( From dd36aff07746f31a2ef73e7a023a9827748b46c3 Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Tue, 18 Feb 2025 02:26:31 -0800 Subject: [PATCH 4/4] Update on "[cp] set up load balancing testbed" [ghstack-poisoned] --- attn_gym/load_balance/load_balancer.py | 45 ++++++++++++++++++--- examples/distributed_benchmark.py | 55 +++++++++++++++++++------- 2 files changed, 81 insertions(+), 19 deletions(-) diff --git a/attn_gym/load_balance/load_balancer.py b/attn_gym/load_balance/load_balancer.py index 8702e7f..211c077 100644 --- a/attn_gym/load_balance/load_balancer.py +++ b/attn_gym/load_balance/load_balancer.py @@ -1,4 +1,4 @@ -from 
typing import List +from abc import ABC, abstractmethod import torch @@ -7,7 +7,42 @@ def load_balance_algo(S: int, size: int, block_size: int) -> torch.Tensor: - total_num_blk = S // block_size - assert S % (size * total_num_blk) == 0 - local_num_blk = total_num_blk // size - return torch.arange(total_num_blk, device="cuda").view(size, local_num_blk) + return HeadTail.gen_load_balance_plan(S, size, block_size) + + +class LoadAlgorithm(ABC): + @classmethod + @abstractmethod + def gen_load_balance_plan(cls, S: int, size: int, block_size: int) -> torch.Tensor: + pass + + +class Noop(LoadAlgorithm): + @classmethod + def gen_load_balance_plan(cls, S: int, size: int, block_size: int) -> torch.Tensor: + total_num_blk = S // block_size + assert S % (size * block_size) == 0 + local_num_blk = total_num_blk // size + return torch.arange(total_num_blk, device="cuda").view(size, local_num_blk) + + +class HeadTail(LoadAlgorithm): + @classmethod + def gen_load_balance_plan(cls, S: int, size: int, block_size: int) -> torch.Tensor: + total_num_blk = S // block_size + assert S % (size * 2 * block_size) == 0 + local_num_blk_pair = total_num_blk // (size * 2) + plan_tensor = torch.arange(total_num_blk, device="cuda").view( + -1, local_num_blk_pair + ) + return torch.stack( + ( + plan_tensor[:size], + plan_tensor[size:].flip(dims=(0,)), + ), + dim=1, + ).view(size, -1) + + +if __name__ == "__main__": + print(HeadTail.gen_load_balance_plan(32, 4, 1)) diff --git a/examples/distributed_benchmark.py b/examples/distributed_benchmark.py index 19741ff..3963121 100644 --- a/examples/distributed_benchmark.py +++ b/examples/distributed_benchmark.py @@ -107,16 +107,48 @@ def test_mask_with_load_balance( for _ in range(3) ] - # TODO: input sharding with load-balancing - # sparsity_info = get_sparsity_info_from_block_mask(block_mask) - # load_balancer_output = load_balance_algo(sparsity_info) + # NOTE: this shuffle op can be done in other ways + def shuffle_tensor_for_load_balancing( + x: torch.Tensor, shuffle_tensor: torch.Tensor, dim: int + ) -> torch.Tensor: + # shuffle the tensor + num_chunks = shuffle_tensor.numel() + x_chunk_list = torch.chunk(x, num_chunks, dim=dim) + assert len(x_chunk_list) == num_chunks + new_x_chunk_list = [None] * num_chunks + for blk_idx in range(num_chunks): + new_x_chunk_list[blk_idx] = x_chunk_list[shuffle_tensor[blk_idx].item()] + + return torch.cat(new_x_chunk_list, dim=dim) + + def interchange_index_value_2d(tensor: torch.Tensor) -> torch.Tensor: + """ + Interchange the index and value in a PyTorch tensor. The input tensor has + structure: rank -> [block_idx, ...] 
and the output tensor will be: + block_idx -> block_idx_in_shuffled_tensor + """ + flattened_tensor = tensor.view(-1) + indices = torch.arange( + flattened_tensor.numel(), device=flattened_tensor.device + ) + revert_tensor = torch.empty_like(flattened_tensor) + revert_tensor[flattened_tensor] = indices + + return revert_tensor + cp_mesh_size = device_mesh.size() load_balancer_output = load_balance_algo(S, cp_mesh_size, block_size) seq_dim = 2 + # copy QKV + qkv_copy = [t.detach().clone() for t in qkv] + # shuffle Q + qkv_copy[0] = shuffle_tensor_for_load_balancing( + qkv_copy[0], load_balancer_output.view(-1), dim=seq_dim + ) qkv_dist = [ distribute_tensor( - t.detach().clone().requires_grad_(), device_mesh, [ + t.requires_grad_(), device_mesh, [ Shard(seq_dim) if i == 0 else Replicate() ] ) @@ -152,15 +184,10 @@ def test_mask_with_load_balance( cp_out_dist = DTensor.from_local(cp_out, device_mesh, [Shard(seq_dim)]) full_cp_out_dist = cp_out_dist.full_tensor() # rearrange - blk_idx_to_origin = load_balancer_output.view(-1) - num_chunks = blk_idx_to_origin.numel() - blk_list_rearranged = [None] * num_chunks - blk_list = torch.chunk(full_cp_out_dist, num_chunks, dim=seq_dim) - assert len(blk_list) == num_chunks - for blk_idx, blk in enumerate(blk_list): - blk_list_rearranged[blk_idx_to_origin[blk_idx].item()] = blk - - full_cp_out_dist = torch.cat(blk_list_rearranged, dim=seq_dim) + blk_idx_shuffled = interchange_index_value_2d(load_balancer_output) + full_cp_out_dist = shuffle_tensor_for_load_balancing( + full_cp_out_dist, blk_idx_shuffled, dim=seq_dim + ) # local flex attention expect_out = flex_attention(*qkv, block_mask=block_mask) @@ -179,7 +206,7 @@ def load_balancing_example(world_size: int, rank: int) -> None: device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,)) run_document_masking(device_mesh, max_seq_len=4096, num_docs=12) - + if __name__ == "__main__": # this script is launched via torchrun which automatically manages ProcessGroup
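
For a quick sanity check of the head-tail plan introduced in PATCH 4/4, the short sketch below is not taken from the patches themselves; it assumes attn_gym is importable with these changes applied and that a CUDA device is available (the plan tensor is allocated on "cuda"). It calls the public load_balance_algo entry point on the same toy shape that load_balancer.py's __main__ block prints:

    # Inspect the head-tail load-balancing plan for 32 tokens, 4 context-parallel
    # ranks, and block size 1 (i.e. 32 blocks total, 8 blocks assigned per rank).
    from attn_gym.load_balance import load_balance_algo

    plan = load_balance_algo(S=32, size=4, block_size=1)
    print(plan)
    # Each rank receives a leading ("head") run of blocks plus the mirrored
    # trailing ("tail") run, so causal-attention work is roughly even per rank:
    # tensor([[ 0,  1,  2,  3, 28, 29, 30, 31],
    #         [ 4,  5,  6,  7, 24, 25, 26, 27],
    #         [ 8,  9, 10, 11, 20, 21, 22, 23],
    #         [12, 13, 14, 15, 16, 17, 18, 19]], device='cuda:0')

Row r of the plan lists the global query-block indices owned by rank r; the benchmark script uses this tensor both to shuffle Q before sharding and to invert the shuffle when reassembling the output for comparison against local flex_attention.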