A Distributed Attention Towards Linear Scalability for Ultra-Long Context, Heterogeneous Mask Training
- [2025/7] 🚀 We release MagiAttention-v1.0.3 with improvements including documentation, support for all four mask types with arbitrary overlapping, deterministic mode, API updates, FFA performance enhancements with bug fixes, optimized dispatch solvers, hierarchical-comm support, and example code to train the Llama-3 1B model with MagiAttention + FSDP / Transformers.
- [2025/6] 📌 We release MagiAttention-v1.0.2, which provides example code for integrating Megatron with MagiAttention along with several training convergence experiments (see here for more details), plus some bug fixes and a simple roadmap.
- [2025/5] 📌 We release MagiAttention-v1.0.1 to support overlapped q_ranges when all mask types are FULL, with some code cleanup and bug fixes.
- [2025/4] 🎉 We release MagiAttention-v1.0.0 with its blog: a distributed attention towards linear scalability for ultra-long context, heterogeneous mask training.
MagiAttention is a distributed attention mechanism, or context-parallel (CP) strategy, that aims to support a wide variety of attention mask types with kernel-level flexibility while achieving linear scalability with respect to the CP size across a broad range of scenarios. It is particularly suitable for ultra-long-context, heterogeneous-mask training tasks such as video generation for Magi-1.
Additionally, it can be easily integrated into prevalent training frameworks such as Megatron-LM, PyTorch's native FSDP and Transformers, as illustrated in QuickStart.
We are committed to continually improving the performance and generality of MagiAttention for the broader research community. Stay tuned for exciting enhancements and new features on the horizon!
To realize linear scalability for distributed attention, we introduce the following key designs.
For implementation details, more experimental results and future works, please visit our blog.
- Flexible Flash Attention Kernel. We introduce a generalized formulation for irregular attention mask patterns and implement a flexible flash attention kernel (FFA). It is natively designed for distributed scenarios and provides greater flexibility in handling diverse attention mask types, with performance comparable to Flash-Attention 3 on Hopper GPUs.
- Computation Load-Balance. With a fine-grained sharding strategy, we design an efficient dispatch solver that ensures balanced attention computation loads across all CP ranks in every training iteration.
- Zero-Redundant Communication. Instead of adopting the common Ring-style P2P communication pattern in CP, we propose two novel communication primitives, GroupCast and GroupReduce, prototyped on top of All-to-All-v, which enable zero-redundant communication volume in both forward and backward passes.
- Adaptive Multi-Stage Overlap. Leveraging the above enhancements, we further implement a multi-stage compute-communication overlap strategy that effectively hides communication latency and adaptively optimizes overlap through manual or automatic tuning.
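The load-balancing idea can be illustrated with a generic greedy heuristic. The sketch below is not MagiAttention's dispatch solver; it assumes that the attention cost of a chunk is proportional to the number of unmasked query-key pairs whose query falls in that chunk, and that every CP rank must hold the same number of chunks. `chunk_areas` and `greedy_dispatch` are hypothetical helpers written only for this illustration.

```python
import heapq


def chunk_areas(seqlen, chunk_size, slices):
    """Count the unmasked (q, k) pairs whose query token falls in each chunk.

    `slices` holds (q_start, q_end, k_start, k_end, causal) tuples; causal
    slices are treated as bottom-right aligned (an assumption of this sketch).
    """
    if seqlen % chunk_size != 0:
        raise ValueError("seqlen must be a multiple of chunk_size (pad first)")
    areas = [0] * (seqlen // chunk_size)
    for q_start, q_end, k_start, k_end, causal in slices:
        q_len, k_len = q_end - q_start, k_end - k_start
        for i in range(q_len):
            visible = max(0, k_len - q_len + i + 1) if causal else k_len
            areas[(q_start + i) // chunk_size] += visible
    return areas


def greedy_dispatch(areas, cp_size):
    """Give each rank the same number of chunks, balancing total area greedily.

    Chunks are visited from largest to smallest area, and each one goes to the
    currently least-loaded rank that still has room (a capacity-constrained
    longest-processing-time-first heuristic).
    """
    if len(areas) % cp_size != 0:
        raise ValueError("number of chunks must be a multiple of cp_size")
    per_rank = len(areas) // cp_size
    heap = [(0, rank) for rank in range(cp_size)]  # (current load, rank)
    assignment = {rank: [] for rank in range(cp_size)}
    for chunk in sorted(range(len(areas)), key=lambda c: -areas[c]):
        load, rank = heapq.heappop(heap)
        assignment[rank].append(chunk)
        if len(assignment[rank]) < per_rank:  # full ranks leave the heap
            heapq.heappush(heap, (load + areas[chunk], rank))
    return assignment


areas = chunk_areas(16, 2, [(0, 16, 0, 16, True)])  # one causal sample, 8 chunks
assignment = greedy_dispatch(areas, cp_size=4)
loads = [sum(areas[c] for c in chunks) for chunks in assignment.values()]
print(areas)       # [3, 7, 11, 15, 19, 23, 27, 31]
print(assignment)  # {0: [7, 0], 1: [6, 1], 2: [5, 2], 3: [4, 3]}
print(loads)       # [34, 34, 34, 34]
```

For this single causal sample, contiguous sharding (chunks 0-1 to rank 0, chunks 2-3 to rank 1, and so on) would give loads of 10, 26, 42 and 58, whereas the greedy assignment pairs each early chunk with a late one and reaches 34 on every rank, similar in spirit to the zigzag sharding often used for causal masks.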
Please check here.
- release note: here
- docker image version: nvcr.io/nvidia/pytorch:25.05-py3
- docker run command:
  docker run --name {container_name} -v {host_mnt_root}:{container_mnt_root} -it -d --privileged --gpus all --network host --ipc host --ulimit memlock=-1 --ulimit stack=67108864 nvcr.io/nvidia/pytorch:25.05-py3 /bin/bash
- docker exec command:
  docker exec -it {container_name} /bin/bash
- command:
  pip install -r requirements.txt
- command:
  git clone https://github.com/SandAI-org/MagiAttention.git
  cd MagiAttention
  git submodule update --init --recursive
  # NOTE: the first build may take around 20-30 minutes and use a lot of CPU resources.
  pip install --no-build-isolation .
Warning: MagiAttention currently supports only Hopper GPUs. We intend to broaden this support in upcoming updates.
We provide basic example code below showing how to use flex_flash_attention (a non-distributed attention function) and magi_attention (a distributed attention mechanism), respectively.
For more usage instructions, refer to magi_attention/functional/flex_flash_attn.py and magi_attention/api/magi_attn_interface.py, respectively.
Basic Usage
- flex_flash_attention:

    import torch
    from magi_attention.api import flex_flash_attn_func

    # --- Define attention config --- #

    total_seqlen = 2048     # 2k tokens
    num_heads_q = 8         # number of attention (query) heads
    num_heads_kv = 2        # number of key/value heads (GQA)
    head_dim = 128          # dimension of each attention head
    dtype = torch.bfloat16  # attention activation / computation dtype (while the reduction dtype is always fp32 for ffa right now)
    device = "cuda"

    # --- Initialize QKV tensor --- #

    q = torch.randn(total_seqlen, num_heads_q, head_dim, dtype=dtype, device=device)
    k = torch.randn(total_seqlen, num_heads_kv, head_dim, dtype=dtype, device=device)
    v = torch.randn(total_seqlen, num_heads_kv, head_dim, dtype=dtype, device=device)

    # --- Initialize FFA meta args for customized attention mask --- #

    # the following customized attention mask looks like (`*` for unmasked, `0` for masked):
    # - - - - - - - - -> (k)
    # | * * * * 0 0 0 0
    # | * * * * 0 0 0 0
    # | * * * * 0 0 0 0
    # | * * * * 0 0 0 0
    # | * * * * * 0 0 0
    # | * * * * * * 0 0
    # | * * * * * * * 0
    # | * * * * * * * *
    # V
    # (q)
    q_ranges_tensor = torch.tensor([[0, 1024], [1024, 2048]], dtype=torch.int32, device=device)
    k_ranges_tensor = torch.tensor([[0, 1024], [0, 2048]], dtype=torch.int32, device=device)
    attn_type_map_tensor = torch.tensor([0, 1], dtype=torch.int32, device=device)  # full mask for 1st slice, causal mask for 2nd

    max_seqlen_q = 1024  # Max length of all q_ranges (2048 - 1024 = 1024)
    max_seqlen_k = 2048  # Max length of all k_ranges (2048 - 0 = 2048)

    # --- Attention computation --- #

    out, _ = flex_flash_attn_func(  # the second return value is `lse` (log-sum-exp), known as the online-softmax correction factor
        q,
        k,
        v,
        q_ranges=q_ranges_tensor,
        k_ranges=k_ranges_tensor,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        attn_type_map=attn_type_map_tensor,
        softmax_scale=None,  # defaults to 1/sqrt(head_dim)
    )
- magi_attention: (NOTE: You should run the following example in a distributed environment, e.g. launched with the common torchrun script.)

    import torch
    import torch.nn as nn
    from magi_attention.api import (
        magi_attn_flex_dispatch, calc_attn, undispatch,  # interface functions
        compute_pad_size,  # helper functions
    )
    from magi_attention.common import AttnRanges
    from magi_attention.common.enum import AttnMaskType
    from magi_attention.utils import setup_dist_env, clearup_dist_env

    # --- Set up distributed environment --- #

    rank, local_rank, world_size, world_group, device, seed = setup_dist_env()

    # --- Define attention config --- #

    total_seqlen = 32 * 1024  # 32k tokens; if we dispatch it to 8 GPUs, each GPU holds 4k tokens
    num_heads_q = 48          # number of attention (query) heads
    num_heads_kv = 8          # number of key/value heads (GQA)
    head_dim = 128            # dimension of each attention head
    dtype = torch.bfloat16    # attention activation / computation dtype (while the reduction dtype for partial attention outputs is always fp32 for magi_attention right now)
    chunk_size = 512          # chunk size used to split the input tensor x along the seqlen dim for dispatch, controlling the granularity of computation load-balance

    # --- Initialize token embedding tensor --- #

    embed_dim = 4096
    x = torch.randn(total_seqlen, embed_dim, device=device, dtype=dtype, requires_grad=True)

    # --- Initialize MagiAttention meta configs for customized attention mask --- #

    # the following customized attention mask is known as `block-causal` mask where `block_size` = 4096 (4k),
    # which looks like (`*` for unmasked, `0` for masked; drawn schematically, not to scale):
    # - - - - - - - - -> (k)
    # | * * 0 0 0 0 0 0
    # | * * 0 0 0 0 0 0
    # | * * * * 0 0 0 0
    # | * * * * 0 0 0 0
    # | * * * * * * 0 0
    # | * * * * * * 0 0
    # | * * * * * * * *
    # | * * * * * * * *
    # V
    # (q)
    q_ranges = AttnRanges.from_ranges(
        [
            [0, 4096],  # 0~4k
            [4096, 8192],  # 4k~8k
            [8192, 12288],  # 8k~12k
            [12288, 16384],  # 12k~16k
            [16384, 20480],  # 16k~20k
            [20480, 24576],  # 20k~24k
            [24576, 28672],  # 24k~28k
            [28672, 32768],  # 28k~32k
        ]
    )
    k_ranges = AttnRanges.from_ranges(
        [
            [0, 4096],  # 0~4k
            [0, 8192],  # 0~8k
            [0, 12288],  # 0~12k
            [0, 16384],  # 0~16k
            [0, 20480],  # 0~20k
            [0, 24576],  # 0~24k
            [0, 28672],  # 0~28k
            [0, 32768],  # 0~32k
        ]
    )
    attn_mask_type = [AttnMaskType.FULL] * len(q_ranges)
    total_seqlen_q = total_seqlen_k = total_seqlen
    pad_size = compute_pad_size(  # pad embeds along seqlen dim for better performance
        total_seqlen_q=total_seqlen_q,
        cp_size=world_size,  # assuming we only have 1-dim context parallelism (cp)
        chunk_size=chunk_size,
    )

    # --- Dispatch token embedding tensor along seqlen dim to multiple ranks --- #

    # NOTE:
    # 1. the dispatched local token embedding may be shuffled along the seqlen dim,
    #    so it's safe for token-wise operations such as matmul, layer-norm, etc.,
    #    while for sample-wise operations like RoPE, you might need to be more careful
    # 2. the `magi_attn_runtime_key` holds some inner meta data and is passed as an argument to many other magi_attention APIs;
    #    users generally do not need to inspect it
    local_x, magi_attn_runtime_key = magi_attn_flex_dispatch(
        x,
        q_ranges=q_ranges,
        k_ranges=k_ranges,
        attn_mask_type=attn_mask_type,
        total_seqlen_q=total_seqlen_q,
        total_seqlen_k=total_seqlen_k,
        pad_size=pad_size,
        chunk_size=chunk_size,
        cp_group_or_mesh=world_group,  # assuming we only have 1-dim context parallelism (cp)
    )

    # --- Simulate QKV projection --- #

    q_proj = nn.Linear(embed_dim, num_heads_q * head_dim, dtype=dtype, device=device)
    k_proj = nn.Linear(embed_dim, num_heads_kv * head_dim, dtype=dtype, device=device)
    v_proj = nn.Linear(embed_dim, num_heads_kv * head_dim, dtype=dtype, device=device)

    local_q = q_proj(local_x).view(-1, num_heads_q, head_dim)
    local_k = k_proj(local_x).view(-1, num_heads_kv, head_dim)
    local_v = v_proj(local_x).view(-1, num_heads_kv, head_dim)

    # --- Distributed attention computation --- #

    local_out, _ = calc_attn(  # the second return value is `local_lse` (log-sum-exp), known as the online-softmax correction factor
        q=local_q,
        k=local_k,
        v=local_v,
        key=magi_attn_runtime_key,
    )

    # --- Undispatch the output tensor along seqlen dim from multiple ranks and unpad --- #

    # NOTE: you only need to call undispatch when the seqlen dimension must be complete and ordered,
    # e.g. for the aforementioned sample-wise operations or for loss computation
    total_out = undispatch(
        x=local_out,
        key=magi_attn_runtime_key,
    )

    # --- Clear up distributed environment --- #

    clearup_dist_env()
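Both examples describe a mask as a list of slices, each pairing a query range with a key range and a mask type. To check what a given set of slices means, and to avoid hand-writing long range lists, the pure-Python sketch below (no GPU or magi_attention install needed) materializes slices as a dense boolean mask and generates block-causal ranges. `dense_mask`, `block_causal_ranges` and `render` are hypothetical helpers, not magi_attention APIs; the 0/1 type ids come from the attn_type_map comment in the flex_flash_attn_func example, and treating causal slices as bottom-right aligned is an assumption inferred from that example's diagram.

```python
FULL, CAUSAL = 0, 1  # type ids as used in the flex_flash_attn_func example's attn_type_map


def dense_mask(q_ranges, k_ranges, mask_types, seqlen_q, seqlen_k):
    """Materialize (q_range, k_range, mask_type) slices as a dense boolean mask.

    CAUSAL slices are treated as bottom-right aligned (assumption); overlapping
    slices are OR-ed together.
    """
    mask = [[False] * seqlen_k for _ in range(seqlen_q)]
    for (q_start, q_end), (k_start, k_end), mask_type in zip(q_ranges, k_ranges, mask_types):
        if mask_type not in (FULL, CAUSAL):
            raise ValueError(f"unsupported mask type {mask_type} in this sketch")
        q_len, k_len = q_end - q_start, k_end - k_start
        for i in range(q_len):
            # FULL: the whole key range; CAUSAL: a key prefix that grows by one per query row
            visible = k_len if mask_type == FULL else max(0, k_len - q_len + i + 1)
            for j in range(k_start, k_start + visible):
                mask[q_start + i][j] = True
    return mask


def block_causal_ranges(total_seqlen, block_size):
    """Return (q_ranges, k_ranges) for a block-causal mask.

    Query block i attends to all keys from 0 up to the end of block i, the
    pattern written out by hand in the magi_attention example.
    """
    if total_seqlen % block_size != 0:
        raise ValueError("total_seqlen must be a multiple of block_size")
    q_ranges, k_ranges = [], []
    for start in range(0, total_seqlen, block_size):
        end = start + block_size
        q_ranges.append([start, end])
        k_ranges.append([0, end])
    return q_ranges, k_ranges


def render(mask):
    """Draw a mask with the README's convention: `*` unmasked, `0` masked."""
    return [" ".join("*" if allowed else "0" for allowed in row) for row in mask]


# 1) the flex_flash_attn_func example, scaled down from 2048 to 8 tokens
ffa_mask = dense_mask([[0, 4], [4, 8]], [[0, 4], [0, 8]], [FULL, CAUSAL], 8, 8)
print("\n".join(render(ffa_mask)))

# 2) the block-causal ranges of the magi_attention example (32k tokens, 4k blocks)
q_ranges, k_ranges = block_causal_ranges(32 * 1024, 4096)
print(q_ranges[1], k_ranges[1])  # [4096, 8192] [0, 8192]

# 3) the same pattern at toy scale: 8 tokens in blocks of 2
toy_q, toy_k = block_causal_ranges(8, 2)
toy_mask = dense_mask(toy_q, toy_k, [FULL] * len(toy_q), 8, 8)
print("\n".join(render(toy_mask)))
```

Printing the toy block-causal mask reproduces the magi_attention diagram row for row, which suggests that diagram is drawn with two cells per block rather than at the 4k scale of the real ranges. In actual code, the generated lists would be passed to AttnRanges.from_ranges just as the hand-written lists are in the example.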
We provide an example of how to integrate magi_attention with FSDP2 in example/torch_native. You can run it with bash run.sh.
In this example, we build a Llama-1B model and apply FSDP2 with magi_attention as the parallelism strategy.
- example/torch_native/modeling_llama.py: builds the Llama model and integrates it with magi_attention.
- example/torch_native/main.py: the main training loop.
We created a new repository, Megatron-LM-MagiAttention, forked from Megatron-LM v0.11.0, to provide an example of training the Llama-1B model with Megatron-LM + MagiAttention. Furthermore, we conducted an experiment training the Llama-3-1B model from scratch to verify the convergence of MagiAttention.
For more information, refer to example/megatron/README.md.
We provide an example of how to integrate magi_attention with Transformers in example/transformers. Furthermore, we conducted a continued-training experiment on the Llama-3-1B model to verify the convergence of MagiAttention.
For more information, refer to example/transformers/README.md.
- Optimize Flex-Flash-Attention kernels to improve performance and better support sparse attention (such as NSA).
- Support native GroupCast and GroupReduce kernels and hierarchical communication optimization (similar to DeepEP).
- Optimize DistAttnSolver to reduce CPU overhead for meta info calculation and to support better computation/communication overlapping.
- Support Dynamic DistAttnSolver, for hybrid attention models or dynamic mask scenarios such as sparse attention, and a query/output communication pattern, to reduce communication overhead in the many cases where communicating only key/value is not the best choice.
- Support other attention patterns, including cross-attention, sink tokens (as in StreamingLLM) and inference scenarios involving KV cache (as in Paged Attention).
- Support Blackwell as well as other GPU architectures.
- Provide a more comprehensive documentation with tutorials, and a more detailed technical blog.
- Provide more example codes and recipes for various training scenarios.
- Upgrade MagiAttention to a distributed-native Flex-Flash-Attention kernel (as a major version update).
- Refactor Distributed Attention Solver to support all mask types with all kinds of overlap.
- Improve Dispatch Solver to reduce the necessary communication volume while maintaining computational balance (especially for varlen mask patterns).
- Build a comprehensive CP Benchmark to better compare the performance of different context-parallel strategies under various mask patterns and other training configurations.
- Provide Documentation, including Installation, QuickStart and API reference.
To demonstrate the state-of-the-art performance and flexibility of the FFA kernels in handling ultra-long, heterogeneous mask training, we measure their computing power (in TFLOPs/s) under the following settings:
settings | value |
---|---|
batch size (b) | 1 |
number of heads (nh) | nhq:nhk:nhv = 64:8:8 (GQA) |
head dimension (hd) | 128 |
dtype | torch.bfloat16 |
dropout probability | 0.0 |
window size | 1024 (for sliding window masks only) |
Benchmark settings: for each mask pattern, we vary the sequence length seqlen (with seqlen_q = seqlen_k = seqlen) while measuring the computing power (in TFLOPs/s) at each seqlen.
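Attention computing power is usually reported in TFLOPs/s, which requires a FLOP count per attention call. The section does not spell out its counting rule, so the sketch below assumes a common convention rather than the one necessarily used for the figures: the forward pass costs 4 * (unmasked q-k pairs) * num_heads_q * head_dim FLOPs, since Q @ K^T and P @ V each take 2 * head_dim FLOPs per unmasked pair and head, and the backward pass is counted as 2.5x the forward (five matmuls instead of two). `attention_flops` and `tflops_per_second` are hypothetical helpers.

```python
def attention_flops(mask_area, num_heads_q, head_dim, mode="fwd"):
    """FLOPs of one attention call, given the number of unmasked (q, k) pairs.

    Forward: two matmuls (Q @ K^T and P @ V), each 2 * head_dim FLOPs per
    unmasked pair and query head. Backward: 2.5x the forward (assumed convention).
    """
    fwd = 4 * mask_area * num_heads_q * head_dim
    if mode == "fwd":
        return fwd
    if mode == "bwd":
        return 2.5 * fwd
    raise ValueError(f"unknown mode: {mode}")


def tflops_per_second(flops, seconds):
    """Convert a FLOP count and a measured runtime into TFLOPs/s."""
    return flops / seconds / 1e12


seqlen = 8192
full_area = seqlen * seqlen               # full mask: every (q, k) pair
causal_area = seqlen * (seqlen + 1) // 2  # causal mask, diagonal included
# 64 query heads and head_dim 128 follow the settings table above
print(attention_flops(full_area, 64, 128))         # 2199023255552
print(attention_flops(causal_area, 64, 128, "bwd"))
```

Counting the exact mask area, rather than assuming a full or half-full square, is what makes this convention usable for irregular masks such as varlen or sliding-window patterns: only the area term changes.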
Some results are reported in the following figures; see more in our blog.






To validate the scalability of MagiAttention, we assess its per-GPU computing power (in TFLOPs/s) as the context length and CP size grow.
The experiments are conducted on a large-scale production GPU cluster (for business and confidentiality reasons, specific details about the cluster, such as the number and type of GPUs, are withheld). We scale the total sequence length seqlen, the context-parallel size cp_size, and the number of nodes nnodes together, from seqlen:64k, cp_size:1, nnodes:1 and seqlen:128k, cp_size:2, nnodes:2, ..., up to seqlen:3072k (3M), cp_size:48, nnodes:48.
The tensor-parallel size tp_size is fixed at 8, with sequence parallelism (SP) enabled. Other data and model configurations for the different mask types are the same as in the table in Kernel-Level Experiments.
Therefore, in every training setting, each rank is consistently assigned seqlen=64k, num_heads_q = 8 and num_heads_k = 1 for attention propagation, while the remaining activations stay at seqlen=8k, num_heads_q = 64 and num_heads_k = 8 with SP enabled. This setup simulates a common training configuration.
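These per-rank numbers follow from simple division: seqlen grows in lockstep with cp_size, so seqlen / cp_size stays at 64k; TP=8 splits the 64 query heads and 8 key/value heads into 8 and 1 per rank for attention; and SP splits the CP-local 64k tokens eight ways, giving 8k tokens with all heads for the remaining activations. The sketch below reproduces this arithmetic; `per_rank_shapes` is a hypothetical helper, and the 64:8 head counts are taken from the kernel-level settings table.

```python
def per_rank_shapes(seqlen, cp_size, tp_size, num_heads_q, num_heads_k):
    """Per-rank shapes for the CP x TP (+SP) scaling setup described above.

    Attention: the sequence is split across cp_size ranks and the heads across
    tp_size ranks. Other activations: with SP, the CP-local sequence is split
    across the tp_size ranks instead, while all heads are kept.
    """
    checks = [(seqlen, cp_size), (seqlen // cp_size, tp_size),
              (num_heads_q, tp_size), (num_heads_k, tp_size)]
    for total, parts in checks:
        if total % parts != 0:
            raise ValueError(f"{total} is not divisible by {parts}")
    cp_seqlen = seqlen // cp_size
    attention = {"seqlen": cp_seqlen,
                 "num_heads_q": num_heads_q // tp_size,
                 "num_heads_k": num_heads_k // tp_size}
    other = {"seqlen": cp_seqlen // tp_size,
             "num_heads_q": num_heads_q,
             "num_heads_k": num_heads_k}
    return attention, other


K = 1024
for n in (1, 2, 48):  # the first two points and the last point of the scaling schedule
    attention, other = per_rank_shapes(64 * K * n, cp_size=n, tp_size=8,
                                       num_heads_q=64, num_heads_k=8)
    print(f"seqlen={64 * n}k cp_size={n}: attention={attention} other={other}")
```

All three points print the same shapes: 65536 (64k) tokens with 8:1 heads for attention and 8192 (8k) tokens with 64:8 heads for the other activations, which is why per-GPU TFLOPs/s should stay flat under ideal linear scaling.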
Some of the results are presented in the following figures; see more in our blog.
As demonstrated, MagiAttention exhibits linear scalability as the context length and CP size increase, in both full mask and varlen full mask configurations, for both forward and backward passes. In contrast, baseline methods either face strict limitations in scaling up or experience performance degradation with ultra-long contexts, which worsens with varlen mask patterns.




We welcome and value any contributions and collaborations. Please check out CONTRIBUTING.md for how to get involved.
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
If you use MagiAttention in your research, please cite:
@misc{magiattention2025,
title={MagiAttention: A Distributed Attention Towards Linear Scalability for Ultra-Long Context, Heterogeneous Mask Training},
author={Tao, Zewei and Huang, Yunpeng},
year={2025},
howpublished={\url{https://github.com/SandAI-org/MagiAttention/}},
}
We are grateful to the contributors listed below for their valuable contributions during the early stages of MagiAttention.
Member | Affiliations | Email | GitHub Account |
---|---|---|---|
Zewei Tao | SandAI | zeweitao@sand.ai | littsk |
Yunpeng Huang | SandAI | yunpenghuang@sand.ai | Strivin0311 |
Qiangang Wang | SandAI, Nanjing University | 522024330081@smail.nju.edu.cn | WT1W |
Hanwen Sun | SandAI, Peking University | sunhanwen@stu.pku.edu.cn | hanwen-sun |
Jin Li | SandAI, Tsinghua University | 2609835176@qq.com | lijinnn |
Tao Bu | Nanjing University | 502024330002@smail.nju.edu.cn | Big-TRex |
WenYang Fang | Nanjing University | fwy@smail.nju.edu.cn | kagami4243 |
Siyuang Yan | Nanjing University | siyuanyan@smail.nju.edu.cn | FibonaccciYan |
Zixu Jiang | Nanjing University | 522023330040@smail.nju.edu.cn | 191220042 |
Dingkun Xu | Nanjing University | 211220090@smail.nju.edu.cn | PureDimension |
Mingyu Liang | Nanjing University | mingyuliang518@gmail.com | gaomusiki |
Jingwei Xu | Nanjing University | jingweix@nju.edu.cn | paragonlight |