This repository provides a comprehensive exploration of GPU optimization techniques for PyTorch models, focusing on improving training efficiency and performance. By implementing and comparing various optimization strategies, the project offers practical insights into enhancing deep learning training workflows.
- Hardware Specifications
- Project Structure
- Installation
- Usage
- Optimization Techniques
- Performance Considerations
- Experimental Results
- References
- License
- Contributing
Experimental Environment:
- GPU: NVIDIA RTX 3050 Ti (4GB VRAM)
- CPU: Intel Core i5-11400H
- RAM: 16GB
- `Optimizing_GPU_Utilization.ipynb`: Contains a full explanation for each optimization implemented
- `no_optimization.py`: Baseline implementation without optimizations
- `tensorFloat32.py`: TensorFloat-32 (TF32) precision optimization
- `brainFloat16.py`: BFloat16 precision optimization
- `torch_compile.py`: Torch JIT compilation optimization
- `flash_attention.py`: FlashAttention implementation
- `fused_optimizer.py`: Fused optimizer optimization
- `8-bit_optimizer.py`: 8-bit Adam optimizer for reduced memory usage
- `Utils/`: Contains model and data setup utilities
- `Makefile`: Automation script for running experiments
- `requirements.txt`: Project dependencies
- Python 3.12+
- CUDA-enabled GPU
- pip package manager
pip install -r requirements.txt
This project includes a Makefile that simplifies running experiments and generating comparisons.
When running experiments, you must specify three mandatory parameters:
- `STEPS=n`: Number of training steps to perform
- `BATCH_SIZE=b`: Size of each training batch
- `PREFIX=path`: Output directory for results and plots
make baseline STEPS=50 BATCH_SIZE=256 PREFIX=./out # No optimization
make tf32 STEPS=50 BATCH_SIZE=256 PREFIX=./out # TensorFloat32
make bf16 STEPS=50 BATCH_SIZE=256 PREFIX=./out # BrainFloat16
make torch_compile STEPS=50 BATCH_SIZE=256 PREFIX=./out # Torch Compile
make flash STEPS=50 BATCH_SIZE=256 PREFIX=./out # FlashAttention
make fused STEPS=50 BATCH_SIZE=256 PREFIX=./out # Fused Optimizer
make 8bit STEPS=50 BATCH_SIZE=256 PREFIX=./out # 8-bit Optimizer
After running one or more experiments:
make plots STEPS=50 BATCH_SIZE=256 PREFIX=./out
make all STEPS=50 BATCH_SIZE=256 PREFIX=./out
make help
make reset # Reset results file and plots
make clean # Remove generated files
make init_results # Initialize the results.csv file at the path given by `RESULTS_FILE`
- No Optimization: Baseline implementation
- TensorFloat-32 (TF32):
  - Faster matrix multiplications on Ampere and newer GPUs
  - Small, usually negligible precision trade-off
- BrainFloat16 (BF16):
  - Reduced memory usage
  - Faster training on supported hardware
- Torch Compile:
  - Just-in-time (JIT) compilation
  - Reduced Python and kernel-launch overhead
- FlashAttention:
  - Efficient attention mechanism
  - Improved performance for transformer models
- Fused Optimizer:
  - Fewer GPU kernel launches
  - Enhanced training efficiency
- 8-bit Optimizer:
  - Reduced memory footprint for optimizer state
  - Potential training speed improvement (see the sketch after this list for how each technique is typically enabled)
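The snippet below is a minimal, illustrative sketch of how these techniques are usually switched on in plain PyTorch; it is not the code from this repository's scripts. The toy model, dimensions, and learning rate are placeholders, and the 8-bit optimizer line assumes the `bitsandbytes` package is installed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

device = "cuda"

# TensorFloat-32: allow TF32 tensor cores for float32 matmuls (Ampere+ GPUs).
torch.set_float32_matmul_precision("high")

class ToyAttentionModel(nn.Module):
    """Toy causal self-attention model, only here to illustrate the optimizations."""
    def __init__(self, vocab=1000, dim=256, heads=4):
        super().__init__()
        self.heads = heads
        self.embed = nn.Embedding(vocab, dim)
        self.qkv = nn.Linear(dim, 3 * dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, idx):
        x = self.embed(idx)                      # (B, T, C)
        b, t, c = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (z.view(b, t, self.heads, c // self.heads).transpose(1, 2)
                   for z in (q, k, v))
        # FlashAttention: the fused kernel is selected automatically when available.
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.head(y.transpose(1, 2).reshape(b, t, c))

model = ToyAttentionModel().to(device)

# Torch Compile: JIT-compile the model to cut Python and kernel-launch overhead.
model = torch.compile(model)

# Fused optimizer: performs the parameter update with fewer, larger CUDA kernels.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=True)

# 8-bit optimizer (alternative; requires the bitsandbytes package):
# import bitsandbytes as bnb
# optimizer = bnb.optim.Adam8bit(model.parameters(), lr=3e-4)

# One training step under BFloat16 autocast.
idx = torch.randint(0, 1000, (8, 128), device=device)
targets = torch.randint(0, 1000, (8, 128), device=device)

optimizer.zero_grad(set_to_none=True)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = model(idx)
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
loss.backward()
optimizer.step()
```

These switches are largely independent of one another, which is why the scripts and Makefile targets in this repository exercise them one at a time before comparing the results.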
- Choose optimization techniques based on your specific hardware and model architecture
- Some techniques (such as Torch Compile) add compilation overhead on the first iterations
- Performance gains vary depending on model complexity and hardware
The following plot shows the mean relative speedup comparison for different optimization techniques compared to the baseline (no optimization). These results were generated using a batch size of 256 and 150 training steps. This plot helps in visualizing the performance gains achieved by each optimization method.
By combining BF16, Torch Compile, FlashAttention, and the Fused Optimizer, I was able to reduce the average iteration time from 472.88 ms (no optimization) to 159.66 ms, making it roughly 3× faster (excluding compilation steps).
- Andrej Karpathy: Let's reproduce GPT-2 (124M)
- NVIDIA Ampere Architecture Whitepaper
- PyTorch Documentation on set_float32_matmul_precision
- PyTorch Documentation on Automatic Mixed Precision
- PyTorch Documentation on torch.compile
- PyTorch Documentation on scaled_dot_product_attention
- FlashAttention paper
- Online Softmax paper
This project is licensed under the MIT License. See `LICENSE` for details.
Contributions are welcome! Please submit pull requests or open issues to discuss potential improvements.