JAX Flash Attention 2
JAX Flash Attention 2.0 - v0.0.1
This first release of jax-flash-attn2 provides a flexible and efficient implementation of Flash Attention 2.0 for JAX, with support for multiple backends and platforms.
🚀 Features
- Multiple backend support (GPU/TPU/CPU)
- Multiple platform implementations (Triton/Pallas/JAX)
- Efficient caching of attention instances (see the sketch after this list)
- Support for Grouped Query Attention (GQA)
- Head dimensions up to 256
- JAX sharding-friendly implementation
- Automatic platform selection based on the backend
- Compatible with existing JAX mesh patterns
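As a quick illustration of the instance caching, requesting the same configuration twice is expected to hand back the same cached attention object, so the call is cheap to repeat inside a training step. A minimal sketch, assuming the cache is keyed on the configuration arguments shown in the usage example below:

```python
from jax_flash_attn2 import get_cached_flash_attention

head_dim = 128  # illustrative per-head dimension

# Same configuration requested twice; the cache is assumed to key on these arguments.
attn_a = get_cached_flash_attention(
    backend="gpu", platform="triton",
    blocksize_q=64, blocksize_k=128, softmax_scale=head_dim ** -0.5,
)
attn_b = get_cached_flash_attention(
    backend="gpu", platform="triton",
    blocksize_q=64, blocksize_k=128, softmax_scale=head_dim ** -0.5,
)
assert attn_a is attn_b  # expected to hold if caching behaves as described
```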
💻 Installation
```bash
pip install jax-flash-attn2
```
✅ Supported Configurations
Backend-Platform Compatibility Matrix
| Backend | Supported Platforms | 
|---|---|
| GPU | Triton, Pallas, JAX | 
| TPU | Pallas, JAX | 
| CPU | JAX | 
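For illustration, each supported row of the matrix corresponds to a backend/platform pair passed to `get_cached_flash_attention`. A hedged sketch reusing the argument names from the usage example below (availability still depends on your hardware and installed dependencies):

```python
from jax_flash_attn2 import get_cached_flash_attention

head_dim = 128  # illustrative per-head dimension
common = dict(blocksize_q=64, blocksize_k=128, softmax_scale=head_dim ** -0.5)

# GPU with the Triton kernels (NVIDIA GPUs only, see Known Limitations)
gpu_attn = get_cached_flash_attention(backend="gpu", platform="triton", **common)

# TPU with the Pallas kernels
tpu_attn = get_cached_flash_attention(backend="tpu", platform="pallas", **common)

# CPU with the pure-JAX implementation
cpu_attn = get_cached_flash_attention(backend="cpu", platform="jax", **common)
```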
📦 Requirements
- Python >=3.10
- JAX >=0.4.33
- JAXlib >=0.4.33
- Triton ~=3.0.0
- scipy ==1.13.1
- einops
- chex
🔍 Known Limitations
- Triton platform is only available on NVIDIA GPUs
- Some platform-backend combinations are not supported
- Custom attention masks are not yet supported; pass an additive bias instead (see the sketch below)
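Because arbitrary masks are not accepted, a boolean mask can be expressed as an additive bias and passed through the `bias` argument of the attention call. A minimal sketch; the helper name, the broadcastable `[1, 1, q_len, kv_len]` shape, and the fill value are illustrative assumptions, not part of the library API:

```python
import jax.numpy as jnp

def mask_to_bias(mask, dtype=jnp.float32):
    """Convert a boolean mask (True = attend) into an additive attention bias.

    Disallowed positions get a large negative value so they vanish after the
    softmax; allowed positions get 0.
    """
    return jnp.where(mask, 0.0, jnp.finfo(dtype).min).astype(dtype)

# Example: a causal mask broadcast over batch and heads.
seq_len = 128
causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
attention_bias = mask_to_bias(causal)[None, None, :, :]  # [1, 1, q_len, kv_len]
```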
📝 Usage Example
```python
from jax_flash_attn2 import get_cached_flash_attention

# `headdim` is the per-head dimension, used for the standard 1/sqrt(d) scaling.
attention = get_cached_flash_attention(
    backend="gpu",
    platform="triton",
    blocksize_q=64,
    blocksize_k=128,
    softmax_scale=headdim ** -0.5,
)

outputs = attention(
    query=query_states,
    key=key_states,
    value=value_states,
    bias=attention_bias,  # optional additive attention bias
)
```
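For Grouped Query Attention, the key/value tensors simply carry fewer heads than the query. A hedged sketch reusing the `attention` instance from the example above; the `[batch, seq_len, num_heads, head_dim]` layout and the shapes are illustrative assumptions, so check the documentation for the exact convention:

```python
import jax
import jax.numpy as jnp

batch, q_len, kv_len, head_dim = 2, 1024, 1024, 128
num_q_heads, num_kv_heads = 32, 8  # four query heads share each KV head (GQA)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
query_states = jax.random.normal(k1, (batch, q_len, num_q_heads, head_dim), dtype=jnp.float16)
key_states = jax.random.normal(k2, (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.float16)
value_states = jax.random.normal(k3, (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.float16)

# The kernel is expected to broadcast the KV heads across the query-head groups.
outputs = attention(
    query=query_states,
    key=key_states,
    value=value_states,
)
```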
🙏 Acknowledgments
- Based on the Flash Attention 2.0 paper
- Uses JAX-Triton
- Kernels adapted from EasyDeL
📚 Documentation
Full documentation will soon be available at: https://erfanzar.github.io/jax-flash-attn2
🐛 Bug Reports
Please report any issues on the project's GitHub Issues page.