WIP - This repository attempts to implement the FlashAttention 2 algorithm as described in the research paper. The project provides two implementations:
- Numba-based Implementation: A Python version using Numba's CUDA JIT to gain insight into the algorithm before the full CUDA implementation.
- CUDA C++ Implementation: A work-in-progress CUDA implementation. It is currently numerically stable and accurate, and it reproduces the tiled attention computation described in the FlashAttention 2 paper.
🔍 FlashAttention 2 efficiently computes scaled dot-product attention by splitting input matrices into blocks and leveraging on-chip SRAM, custom reductions, and careful kernel synchronization. This project is intended to:
- 🔬 Experiment with advanced CUDA programming techniques
- ⚡ Provide a reference Numba implementation for rapid prototyping and testing
- 📊 Offer profiling and benchmarking tools to evaluate performance
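To make the block-wise idea concrete, here is a minimal sketch in plain Python (standard library only, so it runs without a GPU). It is not the code from this repository: the function names `naive_attention` and `tiled_attention` and the `block_size` parameter are illustrative assumptions. It also simplifies the algorithm by tiling only K and V and handling one query row at a time, whereas the real kernels tile Q as well and keep the tiles in on-chip memory. The part it does show is the running max and running sum that let each K/V block be folded in without ever materialising the full score matrix.

```python
import math

def naive_attention(Q, K, V):
    """Reference: softmax(Q K^T / sqrt(d)) V, computed one full row at a time."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[j] for w, v in zip(weights, V)) / total
                    for j in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, block_size=2):
    """Same result, but K/V are visited one block at a time (online softmax)."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        m = -math.inf              # running max of the scores seen so far
        l = 0.0                    # running sum of exp(score - m)
        acc = [0.0] * len(V[0])    # unnormalised output accumulator
        for start in range(0, len(K), block_size):
            k_blk = K[start:start + block_size]
            v_blk = V[start:start + block_size]
            scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
            m_new = max(m, max(scores))
            correction = math.exp(m - m_new)  # rescale old statistics to the new max
            p = [math.exp(s - m_new) for s in scores]
            l = l * correction + sum(p)
            acc = [a * correction + sum(w * v[j] for w, v in zip(p, v_blk))
                   for j, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])  # normalise once, after the last block
    return out

# Five keys with block_size=2 also exercises a ragged final block.
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0], [-2.0, 1.0], [0.0, 0.5]]
V = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.5], [2.0, 2.0, 0.0], [-1.0, 0.5, 1.0], [0.5, 0.5, 0.5]]

ref = naive_attention(Q, K, V)
tiled = tiled_attention(Q, K, V, block_size=2)
max_diff = max(abs(a - b) for r, t in zip(ref, tiled) for a, b in zip(r, t))
print("max abs difference:", max_diff)
```

Comparing a tiled result against a straightforward reference like this is also a convenient way to check numerical accuracy when changing block sizes.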
- CUDA-enabled GPU
Ensure you have the CUDA Toolkit installed (tested with CUDA 12.x) along with compatible NVIDIA drivers.
Use Python 3.7 or later
- Numba
- NumPy
- PyTorch
Simply install these with:
pip install numba numpy torch
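After installing, it can help to confirm that Numba can actually reach the GPU before running anything else. The snippet below is a small sketch of such a check; the helper name `numba_cuda_status` is made up for illustration and is not part of this repository. It relies only on `numba.cuda.is_available()`.

```python
def numba_cuda_status():
    """Return a short, human-readable description of the Numba/CUDA setup."""
    try:
        # Imported lazily so the check still runs when Numba is missing.
        from numba import cuda
    except ImportError:
        return "numba is not installed"
    if not cuda.is_available():
        return "numba is installed, but no usable CUDA device was found"
    return "numba can see a CUDA device"

print(numba_cuda_status())
```

If the second message appears, the usual suspects are a missing or mismatched NVIDIA driver or CUDA Toolkit.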
To run the Numba-based version, run the main script:
python numba_flash2/src/implementations/flash_attention2_numba.py
Or run the tests for the Numba implementation:
python -m unittest discover -s numba_flash2/src/tests -p '*_test.py'
Navigate to the repository root and compile with nvcc, setting -arch to match your GPU's compute capability (sm_70 targets Volta). For example:
nvcc -arch=sm_70 cuda_flash2/src/flash_attention_2.cu -o flash_attention_2
Execute the compiled binary:
./flash_attention_2
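The `-arch=sm_70` value in the compile command above corresponds to compute capability 7.0. If you are unsure what your GPU supports, the following sketch derives the flag from the compute capability that Numba reports. The helper names `arch_flag` and `detect_arch_flag` are hypothetical and not part of this repository; `detect_arch_flag` returns `None` when Numba or a CUDA device is unavailable.

```python
def arch_flag(compute_capability):
    """Turn a (major, minor) compute capability into an nvcc arch name."""
    major, minor = compute_capability
    return f"sm_{major}{minor}"

def detect_arch_flag():
    """Ask Numba for the current device's capability, or return None."""
    try:
        from numba import cuda
    except ImportError:
        return None
    if not cuda.is_available():
        return None
    return arch_flag(cuda.get_current_device().compute_capability)

print(arch_flag((7, 0)))   # sm_70
print(detect_arch_flag())  # depends on the machine; None without a GPU
```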
CUDA tests are located in the cuda_flash2/src/tests directory. Compile and run them similarly. For example, to compile a test:
nvcc -arch=sm_70 cuda_flash2/src/tests/test_flash_attention2.cu -o test_flash_attention2
Then execute:
./test_flash_attention2
For performance analysis, additional profiling tools and scripts are provided. Check the cuda_flash2/profile_flash_attention directory and refer to cuda_flash2/profiling_report.md for details on interpreting the profiling results.