JAX Flash Attention 2

JAX Flash Attention 2.0 - v0.0.1

This first release of jax-flash-attn2 provides a flexible and efficient implementation of Flash Attention 2.0 for JAX, with support for multiple backends and platforms.

🚀 Features

  • Multiple backend support (GPU/TPU/CPU)
  • Multiple platform implementations (Triton/Pallas/JAX)
  • Efficient caching of attention instances
  • Support for Grouped Query Attention (GQA)
  • Head dimensions up to 256
  • JAX sharding-friendly implementation
  • Automatic platform selection based on the backend
  • Compatible with existing JAX mesh patterns (see the sketch after this list)
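
Since the list above mentions sharding-friendliness and compatibility with existing JAX mesh patterns, here is a minimal sketch of how the cached attention callable might be combined with jax.jit and a standard device mesh. The (batch, seq_len, num_heads, head_dim) layout, the block sizes, and the assumption that the returned callable is jit-traceable are illustrative guesses, not details taken from these release notes.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax_flash_attn2 import get_cached_flash_attention

# Assumed input layout: (batch, seq_len, num_heads, head_dim).
batch = 2 * len(jax.devices())
seq_len, num_heads, head_dim = 2048, 16, 128
q = jnp.zeros((batch, seq_len, num_heads, head_dim), jnp.float16)
k = jnp.zeros_like(q)
v = jnp.zeros_like(q)

attention = get_cached_flash_attention(
    backend="gpu",
    platform="triton",
    blocksize_q=64,
    blocksize_k=128,
    softmax_scale=head_dim ** -0.5,
)

# Shard the batch axis across devices (hypothetical data-parallel mesh).
mesh = Mesh(jax.devices(), axis_names=("dp",))
sharding = NamedSharding(mesh, P("dp", None, None, None))
q, k, v = (jax.device_put(x, sharding) for x in (q, k, v))

# jit the call (assuming the callable is traceable); the attention instance
# itself is reused from the library's cache.
outputs = jax.jit(lambda q, k, v: attention(query=q, key=k, value=v))(q, k, v)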

💻 Installation

pip install jax-flash-attn2

✅ Supported Configurations

Backend-Platform Compatibility Matrix

Backend    Supported Platforms
GPU        Triton, Pallas, JAX
TPU        Pallas, JAX
CPU        JAX

📦 Requirements

  • Python >=3.10
  • JAX >=0.4.33
  • JAXlib >=0.4.33
  • Triton ~=3.0.0
  • scipy ==1.13.1
  • einops
  • chex

🔍 Known Limitations

  • The Triton platform is available only on NVIDIA GPUs
  • Some platform-backend combinations are not supported (see the compatibility matrix above)
  • Custom attention masks are not yet supported; pass an additive attention bias instead (see the sketch below)
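
As a workaround for the missing mask support, a boolean mask can be folded into an additive bias: allowed positions get 0 and masked positions get a large negative value, so softmax assigns them effectively zero weight. The causal mask below and the broadcastable (1, 1, q_len, kv_len) bias shape are illustrative assumptions; the resulting bias is passed as the bias argument of the attention call shown in the usage example that follows.

import jax.numpy as jnp

# Hypothetical sequence lengths; adapt to your model.
q_len = kv_len = 1024

# Lower-triangular causal mask: True where attention is allowed.
causal_mask = jnp.tril(jnp.ones((q_len, kv_len), dtype=bool))

# Convert the mask into an additive bias: 0 where allowed, a large negative
# value where masked, so those positions vanish after the softmax.
attention_bias = jnp.where(causal_mask, 0.0, jnp.finfo(jnp.float32).min)

# Add broadcast axes for batch and heads (assumed layout: batch, heads, q, kv).
attention_bias = attention_bias[None, None, :, :]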

📝 Usage Example

from jax_flash_attn2 import get_cached_flash_attention

# `headdim` is the per-head dimension of the query/key/value tensors and is
# used here for the standard 1/sqrt(d) softmax scaling.
attention = get_cached_flash_attention(
    backend="gpu",
    platform="triton",
    blocksize_q=64,
    blocksize_k=128,
    softmax_scale=headdim ** -0.5,
)

# `query_states`, `key_states`, and `value_states` are the attention inputs;
# `attention_bias` is an optional additive bias (e.g. for causal masking).
outputs = attention(
    query=query_states,
    key=key_states,
    value=value_states,
    bias=attention_bias,  # Optional
)
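
Because the feature list advertises Grouped Query Attention, here is a minimal GQA variant of the example above; the head counts, the (batch, seq_len, num_heads, head_dim) layout, and the dtype are assumptions for illustration rather than values taken from these notes.

import jax
from jax_flash_attn2 import get_cached_flash_attention

# Assumed layout: (batch, seq_len, num_heads, head_dim); GQA shares each
# key/value head across a group of query heads.
batch, seq_len, head_dim = 2, 1024, 128
num_q_heads, num_kv_heads = 32, 8  # hypothetical 4:1 grouping

rng = jax.random.PRNGKey(0)
query_states = jax.random.normal(rng, (batch, seq_len, num_q_heads, head_dim), dtype="float16")
key_states = jax.random.normal(rng, (batch, seq_len, num_kv_heads, head_dim), dtype="float16")
value_states = jax.random.normal(rng, (batch, seq_len, num_kv_heads, head_dim), dtype="float16")

attention = get_cached_flash_attention(
    backend="gpu",
    platform="triton",
    blocksize_q=64,
    blocksize_k=128,
    softmax_scale=head_dim ** -0.5,
)

# Query heads are grouped over the shared key/value heads by the kernel.
outputs = attention(query=query_states, key=key_states, value=value_states)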

🙏 Acknowledgments

📚 Documentation

Full documentation will soon be available at: https://erfanzar.github.io/jax-flash-attn2

🐛 Bug Reports

Please report any issues on the project's GitHub Issues page.