
SandAI-org/MagiAttention


MagiAttention


A Distributed Attention Towards Linear Scalability for Ultra-Long Context, Heterogeneous Mask Training


Latest News 🔥

  • [2025/7] 🚀 We release MagiAttention-v1.0.3 with improvements including documentation, support for all four mask types with arbitrary overlapping, deterministic mode, API updates, FFA performance enhancements with bug fixes, optimized dispatch solvers, hierarchical-comm support, and example code for training a Llama-3 1B model with MagiAttention + FSDP / Transformers.
  • [2025/6] 📌 We release MagiAttention-v1.0.2 to provide the example code to integrate Megatron with MagiAttention with several training convergence experiments (see here for more details), with some bug fixes and a simple roadmap.
  • [2025/5] 📌 We release MagiAttention-v1.0.1 to support overlapped q_ranges when all mask types are FULL, with some code cleanup and bug fixes.
  • [2025/4] 🎉 We release MagiAttention-v1.0.0 with its blog: a distributed attention towards linear scalability for ultra-long context, heterogeneous mask training.

About

MagiAttention is a distributed attention mechanism, or context-parallel (CP) strategy, that aims to support a wide variety of attention mask types with kernel-level flexibility while achieving linear scalability with respect to CP size across a broad range of scenarios. It is particularly suitable for training tasks with ultra-long contexts and heterogeneous masks, such as video generation for Magi-1.

Additionally, it can be easily integrated into prevalent training frameworks such as Megatron-LM, PyTorch's native FSDP and transformers, as illustrated in QuickStart.

We are committed to continually improving the performance and generality of MagiAttention for the broader research community. Stay tuned for exciting enhancements and new features on the horizon!

Key Features ✨

To realize linear scalability for distributed attention, we implement and introduce key designs as follows.

For implementation details, more experimental results and future work, please visit our blog.

  • Flexible Flash Attention Kernel. We introduce a generalized formulation for irregular attention mask patterns and implement a flexible flash attention kernel (FFA). It is natively designed for distribution scenarios and provides greater flexibility in handling diverse attention mask types, with performance comparable to Flash-Attention 3 on Hopper GPUs.
  • Computation Load-Balance. With a fine-grained sharding strategy, we design an efficient dispatch solver that balances the attention computational load across all CP ranks in every training iteration.
  • Zero-Redundant Communication. Instead of adopting the common Ring-style P2P communication pattern in CP, we propose two novel communication primitives, GroupCast and GroupReduce, built upon All-to-All-v as a prototypal implementation, enabling zero-redundant communication volume for both forward and backward passes.
  • Adaptive Multi-Stage Overlap. Leveraging the above enhancements, we further implement a multi-stage compute-communication overlap strategy that effectively hides communication latency and adaptively optimizes overlap through manual or automatic tuning.
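To make the load-balance idea concrete, here is a minimal sketch. It is not the dispatch solver shipped with MagiAttention; it is a simple greedy heuristic under stated assumptions: the sequence is cut into equal chunks, each chunk's workload is the number of unmasked (q, k) pairs in its rows, every rank receives the same number of chunks, and only FULL slices are modelled. The names `chunk_workloads`, `greedy_dispatch` and `rank_totals` are hypothetical.

```python
from typing import List, Tuple

# One mask slice: ((q_start, q_end), (k_start, k_end)); only FULL slices are modelled here.
Slice = Tuple[Tuple[int, int], Tuple[int, int]]


def chunk_workloads(slices: List[Slice], total_seqlen: int, chunk_size: int) -> List[int]:
    """Count the unmasked (q, k) pairs that fall into each query chunk."""
    num_chunks = total_seqlen // chunk_size
    loads = [0] * num_chunks
    for (q_start, q_end), (k_start, k_end) in slices:
        for c in range(num_chunks):
            lo = max(q_start, c * chunk_size)
            hi = min(q_end, (c + 1) * chunk_size)
            if lo < hi:
                loads[c] += (hi - lo) * (k_end - k_start)
    return loads


def greedy_dispatch(loads: List[int], cp_size: int) -> List[List[int]]:
    """Heaviest chunk first, always to the least-loaded rank that still has room."""
    per_rank = len(loads) // cp_size  # assumption: every rank holds the same number of chunks
    buckets: List[List[int]] = [[] for _ in range(cp_size)]
    totals = [0] * cp_size
    for c in sorted(range(len(loads)), key=lambda i: -loads[i]):
        candidates = [r for r in range(cp_size) if len(buckets[r]) < per_rank]
        r = min(candidates, key=lambda rank: totals[rank])
        buckets[r].append(c)
        totals[r] += loads[c]
    return buckets


def rank_totals(buckets: List[List[int]], loads: List[int]) -> List[int]:
    return [sum(loads[c] for c in bucket) for bucket in buckets]


# Scaled-down block-causal mask: 8 blocks of 4 tokens, dispatched to 4 ranks.
block, num_blocks, cp_size = 4, 8, 4
slices = [((i * block, (i + 1) * block), (0, (i + 1) * block)) for i in range(num_blocks)]
loads = chunk_workloads(slices, total_seqlen=block * num_blocks, chunk_size=block)

contiguous = [[2 * r, 2 * r + 1] for r in range(cp_size)]  # naive split: neighbouring chunks
balanced = greedy_dispatch(loads, cp_size)

print(rank_totals(contiguous, loads))  # [48, 112, 176, 240]
print(rank_totals(balanced, loads))    # [144, 144, 144, 144]
```

With a contiguous split the last rank does five times the work of the first, while the greedy assignment equalizes the per-rank workload at the cost of shuffling chunks along the sequence dimension.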

Documentation 📚

Please check here.

Installation ⚙️

Step1: Activate an NGC pytorch docker container

  • release note: here

  • docker image version: nvcr.io/nvidia/pytorch:25.05-py3

  • docker run command:

    docker run --name {container_name} -v {host_mnt_root}:{container_mnt_root} -it -d --privileged --gpus all --network host --ipc host --ulimit memlock=-1 --ulimit stack=67108864 nvcr.io/nvidia/pytorch:25.05-py3 /bin/bash
  • docker exec command:

    docker exec -it {container_name} /bin/bash

Step2: Install required packages

  • command:

    pip install -r requirements.txt

Step3: Install MagiAttention from source

  • command:

    git clone https://github.com/SandAI-org/MagiAttention.git
    
    cd MagiAttention
    
    git submodule update --init --recursive
    
    # NOTE: this process may take around 20~30 minutes and occupy a lot of CPU resources the first time.
    pip install --no-build-isolation .

Quick Start 🚀

Warning

MagiAttention currently only supports Hopper GPUs. We intend to broaden this support in upcoming updates.
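Before running the examples, it can help to confirm that the environment matches this requirement. The sketch below is an illustrative pre-flight check, not part of MagiAttention: `is_hopper` and `check_environment` are hypothetical helper names, and the check relies on the fact that Hopper GPUs report CUDA compute capability 9.x.

```python
import importlib.util


def is_hopper(capability):
    """Hopper GPUs (e.g. H100) report CUDA compute capability 9.x."""
    major, _minor = capability
    return major == 9


def check_environment():
    """Return a list of human-readable problems; an empty list means every check passed."""
    problems = []
    if importlib.util.find_spec("magi_attention") is None:
        problems.append("magi_attention is not importable; was it installed from source?")
    if importlib.util.find_spec("torch") is None:
        problems.append("torch is not importable")
        return problems

    import torch  # imported lazily so the helper also runs on machines without PyTorch

    if not torch.cuda.is_available():
        problems.append("no CUDA device is visible")
    elif not is_hopper(torch.cuda.get_device_capability()):
        problems.append("the current GPU is not a Hopper GPU")
    return problems
```

Calling `check_environment()` inside the container and printing the returned list gives a quick summary of what is missing.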

Basic Usage

We provide basic example code below of how to use flex_flash_attention (non-distributed attention function) and magi_attention (distributed attention mechanism), respectively.

For more usage instructions, you can refer to magi_attention/functional/flex_flash_attn.py and magi_attention/api/magi_attn_interface.py, respectively.

  • flex_flash_attention:

    import torch
    from magi_attention.api import flex_flash_attn_func
    
    # --- Define attention config --- #
    
    total_seqlen = 2048    # 2k tokens
    num_heads_q = 8        # number of attention (query) heads
    num_heads_kv = 2       # number of key/value heads (GQA)
    head_dim = 128         # dimension of each attention head
    dtype = torch.bfloat16 # attention activation / computation dtype (while the reduction dtype is always fp32 for ffa right now)
    device = "cuda"
    
    # --- Initialize QKV tensor --- #
    
    q = torch.randn(total_seqlen, num_heads_q, head_dim, dtype=dtype, device=device)
    k = torch.randn(total_seqlen, num_heads_kv, head_dim, dtype=dtype, device=device)
    v = torch.randn(total_seqlen, num_heads_kv, head_dim, dtype=dtype, device=device)
    
    # --- Initialize FFA meta args for customized attention mask --- #
    
    # the following customized attention mask looks like (`*` for unmasked, `0` for masked):
    #     - - - - - - - - -> (k)
    #   | * * * * 0 0 0 0
    #   | * * * * 0 0 0 0
    #   | * * * * 0 0 0 0
    #   | * * * * 0 0 0 0
    #   | * * * * * 0 0 0
    #   | * * * * * * 0 0
    #   | * * * * * * * 0
    #   | * * * * * * * *
    #   V
    #  (q)
    q_ranges_tensor = torch.tensor([[0, 1024], [1024, 2048]], dtype=torch.int32, device=device)
    k_ranges_tensor = torch.tensor([[0, 1024], [0, 2048]], dtype=torch.int32, device=device)
    attn_type_map_tensor = torch.tensor([0, 1], dtype=torch.int32, device=device) # full mask for 1st slice, causal mask for 2nd
    
    max_seqlen_q = 1024 # Max length of all q_ranges (2048 - 1024 = 1024)
    max_seqlen_k = 2048 # Max length of all k_ranges (2048 - 0 = 2048)
    
    # --- Attention computation --- #
    
    out, _ = flex_flash_attn_func( # the second return value is `lse` (log-sum-exp), known as the online-softmax correction factor
        q, k, v,
        q_ranges=q_ranges_tensor,
        k_ranges=k_ranges_tensor,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        attn_type_map=attn_type_map_tensor,
        softmax_scale=None, # defaults to 1/sqrt(head_dim)
    )
  • magi_attention: (NOTE: You should run the following examples in a distributed environment, e.g. using the common torchrun script)

    import torch
    import torch.nn as nn
    from magi_attention.api import (
        magi_attn_flex_dispatch, calc_attn, undispatch, # interface functions
        compute_pad_size, # helper functions
    )
    from magi_attention.common import AttnRanges
    from magi_attention.common.enum import AttnMaskType
    from magi_attention.utils import setup_dist_env, clearup_dist_env
    
    # --- Set up distributed environment --- #
    
    rank, local_rank, world_size, world_group, device, seed = setup_dist_env()
    
    # --- Define attention config --- #
    
    total_seqlen = 32 * 1024   # 32k tokens, if we dispatch it to 8 GPUs, then each GPU holds 4k tokens
    num_heads_q = 48           # number of attention (query) heads
    num_heads_kv = 8           # number of key/value heads (GQA)
    head_dim = 128             # dimension of each attention head
    dtype = torch.bfloat16     # attention activation / computation dtype (while the reduction dtype for partial attention outputs is always fp32 for magi_attention right now)
    chunk_size = 512           # chunk size to chunk the input tensor x along the seqlen dim for dispatch to control the granularity of computation load-balance.
    
    # --- Initialize token embedding tensor --- #
    
    embed_dim = 4096
    x = torch.randn(total_seqlen, embed_dim, device=device, dtype=dtype, requires_grad=True)
    
    # --- Initialize MagiAttention meta configs for customized attention mask --- #
    
    # the following customized attention mask is known as `block-causal` mask where `block_size` = 4096 (4k),
    # which looks like (`*` for unmasked, `0` for masked):
    #     - - - - - - - - -> (k)
    #   | * * 0 0 0 0 0 0
    #   | * * 0 0 0 0 0 0
    #   | * * * * 0 0 0 0
    #   | * * * * 0 0 0 0
    #   | * * * * * * 0 0
    #   | * * * * * * 0 0
    #   | * * * * * * * *
    #   | * * * * * * * *
    #   V
    #  (q)
    q_ranges = AttnRanges.from_ranges(
        [
            [0, 4096], # 0~4k
            [4096, 8192], # 4k~8k
            [8192, 12288], # 8k~12k
            [12288, 16384], # 12k~16k
            [16384, 20480], # 16k~20k
            [20480, 24576], # 20k~24k
            [24576, 28672], # 24k~28k
            [28672, 32768], # 28k~32k
        ]
    )
    k_ranges = AttnRanges.from_ranges(
        [
            [0, 4096], # 0~4k
            [0, 8192], # 0~8k
            [0, 12288], # 0~12k
            [0, 16384], # 0~16k
            [0, 20480], # 0~20k
            [0, 24576], # 0~24k
            [0, 28672], # 0~28k
            [0, 32768], # 0~32k
        ]
    )
    attn_mask_type = [AttnMaskType.FULL] * len(q_ranges)
    total_seqlen_q = total_seqlen_k = total_seqlen
    pad_size = compute_pad_size( # pad embeds along seqlen dim for better performance
        total_seqlen_q=total_seqlen_q,
        cp_size=world_size, # assuming we only have 1-dim context parallelism (cp)
        chunk_size=chunk_size,
    )
    
    # --- Dispatch token embedding tensor along seqlen dim to multiple ranks --- #
    
    # NOTE:
    # 1. the dispatched local token embedding may be shuffled along the seqlen dim,
    #    so it is safe for token-wise operations such as matmul, layer-norm, etc.,
    #    while sample-wise operations like RoPE need extra care
    # 2. `magi_attn_runtime_key` holds internal metadata and is passed as an argument to many other
    #    magi_attention APIs; users do not need to inspect it
    local_x, magi_attn_runtime_key = magi_attn_flex_dispatch(
        x,
        q_ranges=q_ranges,
        k_ranges=k_ranges,
        attn_mask_type=attn_mask_type,
        total_seqlen_q=total_seqlen_q,
        total_seqlen_k=total_seqlen_k,
        pad_size=pad_size,
        chunk_size=chunk_size,
        cp_group_or_mesh=world_group, # assuming we only have 1-dim context parallelism (cp)
    )
    
    # --- Simulate QKV projection --- #
    
    q_proj = nn.Linear(embed_dim, num_heads_q * head_dim, dtype=dtype, device=device)
    k_proj = nn.Linear(embed_dim, num_heads_kv * head_dim, dtype=dtype, device=device)
    v_proj = nn.Linear(embed_dim, num_heads_kv * head_dim, dtype=dtype, device=device)
    
    local_q = q_proj(local_x).view(-1, num_heads_q, head_dim)
    local_k = k_proj(local_x).view(-1, num_heads_kv, head_dim)
    local_v = v_proj(local_x).view(-1, num_heads_kv, head_dim)
    
    # --- Distributed attention computation --- #
    
    local_out, _ = calc_attn( # the second return value is `local_lse` (log-sum-exp), known as the online-softmax correction factor
        q=local_q,
        k=local_k,
        v=local_v,
        key=magi_attn_runtime_key,
    )
    
    # --- Undispatch the output tensor along seqlen dim from multiple ranks and unpad --- #
    
    # NOTE: you do not need to call the undispatch API until the seqlen dimension must be complete and ordered,
    # e.g. for the aforementioned sample-wise operations, or for loss computation
    total_out = undispatch(
        x=local_out,
        key=magi_attn_runtime_key,
    )
    
    # --- Clear up distributed environment --- #
    
    clearup_dist_env()
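Both examples above describe the mask as a list of slices, each a (q_range, k_range, mask type) triple. The following pure-Python sketch expands such slices into a dense boolean mask at a small scale, which is a convenient way to sanity-check a mask definition before handing it to the kernel. It is an illustration only: `dense_mask` and `render` are hypothetical helpers, the integer codes follow the `attn_type_map` comment in the first example (0 for full, 1 for causal), and the causal triangle is assumed to be aligned to the bottom-right corner of each slice, as the ASCII diagram in the first example implies.

```python
FULL, CAUSAL = 0, 1  # same integer codes as `attn_type_map` in the first example


def dense_mask(q_ranges, k_ranges, attn_types, seqlen_q, seqlen_k):
    """Expand (q_range, k_range, type) slices into a dense boolean mask (True = unmasked)."""
    mask = [[False] * seqlen_k for _ in range(seqlen_q)]
    for (q_start, q_end), (k_start, k_end), attn_type in zip(q_ranges, k_ranges, attn_types):
        q_len, k_len = q_end - q_start, k_end - k_start
        for i in range(q_len):
            # FULL: the whole k range is visible;
            # CAUSAL: assumed bottom-right aligned, so the last q row sees the whole k range
            visible = k_len if attn_type == FULL else k_len - q_len + i + 1
            for j in range(max(visible, 0)):
                mask[q_start + i][k_start + j] = True
    return mask


def render(mask):
    """Draw the mask in the same style as the diagrams: `*` unmasked, `0` masked."""
    return "\n".join(" ".join("*" if m else "0" for m in row) for row in mask)


# First example scaled down by 256x: 2048 tokens become 8 tokens.
first = dense_mask([(0, 4), (4, 8)], [(0, 4), (0, 8)], [FULL, CAUSAL], 8, 8)
print(render(first))

# A four-block version of the block-causal pattern (block size 2), all slices FULL.
q_ranges = [(0, 2), (2, 4), (4, 6), (6, 8)]
k_ranges = [(0, 2), (0, 4), (0, 6), (0, 8)]
block_causal = dense_mask(q_ranges, k_ranges, [FULL] * 4, 8, 8)
print(render(block_causal))
```

The two printed grids reproduce the diagrams drawn in the code comments of the two examples.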

Example to integrate with FSDP2

We provide an example of how to integrate magi_attention with fsdp2 in example/torch_native. You can use bash run.sh to run the example.

In this example, we build a llama-1b model and apply fsdp2 with magi_attention as the parallelism strategy.

  • example/torch_native/modeling_llama.py: build llama model and integrate with magi_attention.
  • example/torch_native/main.py: main training loop.

Example to integrate with Megatron-LM

We create a new repository Megatron-LM-MagiAttention, forked from Megatron-LM v0.11.0, to provide an example of training the llama-1B model with Megatron-LM + MagiAttention. Furthermore, we conducted an experiment training the llama-3-1B model from scratch to verify the convergence of MagiAttention.

For more information, you can refer to example/megatron/README.md.

Example to integrate with transformers

We provide an example of how to integrate magi_attention with transformers in example/transformers. Furthermore, we conducted a continued-training experiment on the llama-3-1B model to verify the convergence of MagiAttention.

For more information, you can refer to example/transformers/README.md.

Roadmap ⛏️

  • Optimize Flex-Flash-Attention kernels to improve performance and better support sparse attention (such as NSA)
  • Support native GroupCast and GroupReduce kernels and hierarchical communication optimization (similar to DeepEP)
  • Optimize DistAttnSolver to reduce CPU overhead for meta info calculation and support better comp-/comm- overlapping.
  • Support a Dynamic DistAttnSolver with a query/output communication pattern, for two purposes: supporting hybrid attention models or dynamic mask scenarios such as sparse attention, and reducing communication overhead in the many cases where communicating only key/value is not the best choice.
  • Support other attention patterns including cross-attention, sink tokens (w.r.t. StreamingLLM) and inference scenarios involving KV cache (w.r.t. Paged Attention).
  • Support Blackwell as well as other GPU architectures.
  • Provide more comprehensive documentation with tutorials, and a more detailed technical blog.
  • Provide more example code and recipes for various training scenarios.
  • Upgrade MagiAttention to a distributed native Flex-Flash-Attention kernel (as a major version update).
  • Refactor Distributed Attention Solver to support all mask types with all kinds of overlap.
  • Improve Dispatch Solver to reduce the necessary communication volume while keeping computation balanced (especially for varlen mask patterns).
  • Build a comprehensive CP Benchmark to better compare the performance of different context parallel strategies under various mask patterns and other training configurations.
  • Provide Documentation including Installation, QuickStart and API reference.

Performance Benchmarks 📊

Kernel-Level Performance and Flexibility

To demonstrate FFA kernels' state-of-the-art performance and flexibility in handling ultra-long, heterogeneous mask training, we measure the computing power (in $\texttt{TFLOPs/s}$) on Hopper GPUs for both forward and backward passes of prevalent attention kernels across standard and irregular mask patterns.

| settings | value |
| --- | --- |
| batch size (b) | 1 |
| number of heads (nh) | nhq:nhk:nhv = 64:8:8 (GQA) |
| head dimension (hd) | 128 |
| dtype | torch.bfloat16 |
| dropout probability | 0.0 |
| window size | 1024 (for sliding window masks only) |

Benchmark settings: for each mask pattern, we vary the sequence length seqlen from $4k,8k,16k,...,$ up to $128k$ (seqlen_q = seqlen_k = seqlen) while measuring computation power (in $\texttt{TFLOPs/s}$) for forward and backward passes of different attention kernels. Other configurations are fixed using common training settings (see the table above) to focus on the impact of sequence length and mask pattern. For the varlen packed data, we simply follow the variable sequence length distribution in the open-sourced dataset ChatQA2-Long-SFT-data, from which we sample to pack and pad to the required seqlen.
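The conversion from a measured kernel time to TFLOPs/s can be sketched as follows. The README does not state its FLOP-counting formula, so this sketch assumes the convention commonly used in FlashAttention-style benchmarks: the forward pass costs 4 * area * num_heads_q * head_dim FLOPs, where area is the number of unmasked (q, k) pairs, and the backward pass costs 2.5 times the forward pass. The helper names and the 1 ms timing are hypothetical and are not measured results.

```python
def mask_area(seqlen, mask="full"):
    """Number of unmasked (q, k) pairs for a square mask with seqlen_q = seqlen_k = seqlen."""
    if mask == "full":
        return seqlen * seqlen
    if mask == "causal":
        return seqlen * (seqlen + 1) // 2  # lower triangle including the diagonal
    raise ValueError(f"unsupported mask: {mask}")


def attn_flops(area, num_heads_q, head_dim, backward=False):
    """Assumed convention: two matmuls (QK^T and PV) forward, five matmuls backward."""
    forward = 4 * area * num_heads_q * head_dim
    return forward * 5 // 2 if backward else forward


def tflops_per_sec(flops, elapsed_ms):
    """Convert a FLOP count and a wall-clock time in milliseconds to TFLOPs/s."""
    return flops / (elapsed_ms * 1e-3) / 1e12


# Settings from the table above: 64 query heads, head dimension 128, seqlen = 4k, full mask.
fwd = attn_flops(mask_area(4096, "full"), num_heads_q=64, head_dim=128)
bwd = attn_flops(mask_area(4096, "full"), num_heads_q=64, head_dim=128, backward=True)

hypothetical_ms = 1.0  # placeholder timing, NOT a measurement
print(f"{tflops_per_sec(fwd, hypothetical_ms):.1f} TFLOPs/s forward at {hypothetical_ms} ms")
```

Counting FLOPs from the mask area rather than from seqlen squared keeps the metric comparable across mask patterns, because masked-out pairs are not credited as useful work.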

Some results are reported in the following figures; see our blog for more.

Figure: Benchmarking FFA's performance and flexibility against other leading attention kernels for full mask scenarios.

Figure: Benchmarking FFA's performance and flexibility against other leading attention kernels for causal mask scenarios.

Figure: Benchmarking FFA's performance and flexibility against other leading attention kernels for varlen full mask scenarios.
Note: the E symbol indicates that the corresponding implementation raises a CUDA out-of-memory error in that specific configuration.

Figure: Benchmarking FFA's performance and flexibility against other leading attention kernels for varlen causal mask scenarios.
Note: the E symbol indicates that the corresponding implementation raises a CUDA out-of-memory error in that specific configuration.

Figure: Benchmarking FFA's performance and flexibility against other leading attention kernels for sliding-window causal mask scenarios.
Note: the E symbol indicates that the corresponding implementation raises a CUDA out-of-memory error in that specific configuration.

Figure: Benchmarking FFA's performance and flexibility against other leading attention kernels for varlen block causal mask scenarios.
Note: the E symbol indicates that the corresponding implementation raises a CUDA out-of-memory error in that specific configuration, while the X symbol indicates that the corresponding implementation is not supported in that specific configuration.

Module-Level Scalability

To validate the scalability of MagiAttention, we assess the per-GPU computing power (in $\texttt{TFLOPs/s/GPU}$) of the attention module during both forward and backward propagation, as the sequence length and parallel size increase. This assessment is compared against common CP strategies including Ring-Attention and Ulysses. Because supporting irregular masks in the baselines is complex, our experiments are limited to the full mask and varlen full mask scenarios. The distribution of variable sequence lengths still follows the one in Kernel-Level Experiments.

The experiments are conducted on a large-scale production GPU cluster (for business and confidentiality reasons, specific details about the cluster, such as the number and type of GPUs, are withheld). We scale the total sequence length seqlen, the context-parallel size cp_size, and the node count nnodes together: from seqlen:64k, cp_size:1, nnodes:1, through seqlen:128k, cp_size:2, nnodes:2, ..., up to seqlen:3072k (3M), cp_size:48, nnodes:48.

The tensor-parallel size tp_size is fixed at 8, with sequence-parallel enabled. Other data and model configurations for different mask types are the same as in the table in Kernel-Level Experiments.

Therefore, in every training setting, each rank is constantly assigned seqlen=64k, num_heads_q = 8 and num_heads_k = 1 for attention propagation, while the remaining activations stay at seqlen=8k, num_heads_q = 64 and num_heads_k = 8 with SP enabled. This setup simulates a common training configuration.
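The per-rank shapes quoted above follow from simple division. The sketch below spells out that arithmetic under our reading of the setup: CP shards the sequence, TP shards the attention heads, and SP further shards the sequence of the non-attention activations across the TP group. The helper `per_rank_shapes` is hypothetical and not part of the benchmark code.

```python
def per_rank_shapes(seqlen, cp_size, tp_size, num_heads_q, num_heads_kv):
    """Return (attention shapes, non-attention activation shapes) seen by a single rank."""
    attn = {
        "seqlen": seqlen // cp_size,              # CP shards the sequence
        "num_heads_q": num_heads_q // tp_size,    # TP shards the heads
        "num_heads_kv": num_heads_kv // tp_size,
    }
    other = {
        "seqlen": seqlen // cp_size // tp_size,   # SP shards the sequence again across TP ranks
        "num_heads_q": num_heads_q,
        "num_heads_kv": num_heads_kv,
    }
    return attn, other


K = 1024
# seqlen and cp_size are scaled together, from 64k / 1 up to 3072k / 48.
for cp_size in (1, 2, 48):
    attn, other = per_rank_shapes(64 * K * cp_size, cp_size, tp_size=8,
                                  num_heads_q=64, num_heads_kv=8)
    print(cp_size, attn, other)
```

Because seqlen grows in proportion to cp_size, the per-rank attention workload stays fixed at 64k tokens, which is what makes per-GPU throughput a meaningful scalability metric here.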

Some of the results are presented in the following figures; see our blog for more.

As demonstrated, MagiAttention exhibits linear scalability as the context length and CP size increase, in both full mask and varlen full mask configurations, for both forward and backward passes. In contrast, baseline methods either face strict limitations in scaling up or experience performance degradation with ultra-long contexts, which worsens with varlen mask patterns.

Figure (forward and backward panels): Benchmarking MagiAttention's scalability against other leading CP strategies for full mask scenarios.
Note: the X symbol indicates that the corresponding distributed attention implementation is not supported in that specific configuration.

Figure (forward and backward panels): Benchmarking MagiAttention's scalability against other leading CP strategies for varlen full mask scenarios.
Note: the X symbol indicates that the corresponding distributed attention implementation is not supported in that specific configuration.

Contributing 🤝

We welcome and value any contributions and collaborations. Please check out CONTRIBUTING.md for how to get involved.

License ⚖️

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.

Citation 📝

If you use MagiAttention in your research, please cite:

@misc{magiattention2025,
  title={MagiAttention: A Distributed Attention Towards Linear Scalability for Ultra-Long Context, Heterogeneous Mask Training},
  author={Zewei, Tao and Yunpeng, Huang},
  year={2025},
  howpublished={\url{https://github.com/SandAI-org/MagiAttention/}},
}

Acknowledgement

We are grateful to the contributors listed below for their valuable contributions during the early stages of MagiAttention.

| Member | Affiliations | Email | GitHub Account |
| --- | --- | --- | --- |
| Zewei Tao | SandAI | zeweitao@sand.ai | littsk |
| Yunpeng Huang | SandAI | yunpenghuang@sand.ai | Strivin0311 |
| Qiangang Wang | SandAI, Nanjing University | 522024330081@smail.nju.edu.cn | WT1W |
| Hanwen Sun | SandAI, Peking University | sunhanwen@stu.pku.edu.cn | hanwen-sun |
| Jin Li | SandAI, Tsinghua University | 2609835176@qq.com | lijinnn |
| Tao Bu | Nanjing University | 502024330002@smail.nju.edu.cn | Big-TRex |
| WenYang Fang | Nanjing University | fwy@smail.nju.edu.cn | kagami4243 |
| Siyuang Yan | Nanjing University | siyuanyan@smail.nju.edu.cn | FibonaccciYan |
| Zixu Jiang | Nanjing University | 522023330040@smail.nju.edu.cn | 191220042 |
| Dingkun Xu | Nanjing University | 211220090@smail.nju.edu.cn | PureDimension |
| Mingyu Liang | Nanjing University | mingyuliang518@gmail.com | gaomusiki |
| Jingwei Xu | Nanjing University | jingweix@nju.edu.cn | paragonlight |
