A high-performance implementation of Conway's Game of Life using JAX, with support for both CPU and GPU. It relies on JAX's vectorized array operations and JIT compilation to speed up the simulation.
The example below compares three approaches: plain Python loops (Unoptimized), JAX NumPy arrays without JIT compilation (Optimized), and JIT-compiled JAX code (Compiled).
- 🚀 High-performance implementation using JAX's vectorized operations
- 🎮 Multiple implementation variants:
  - Unoptimized (Python loops)
  - Optimized (JAX vectorized operations, no JIT)
  - Compiled (JIT-compiled for maximum performance)
- 🖥️ GPU support through JAX
- 📊 Visualization capabilities
- 🧪 Comprehensive test suite
- 📈 Benchmarking tools
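To reproduce the kind of eager-versus-compiled comparison described above on your own machine, here is a minimal timing sketch. It is not the repository's benchmark script. The `step`, `run` and `run_jit` helpers are hypothetical, and the board is assumed to wrap around at the edges (toroidal). `block_until_ready()` is called because JAX dispatches work asynchronously; without it, `timeit` would mostly measure dispatch time.

```python
import timeit

import jax
import jax.numpy as jnp


def step(grid: jax.Array) -> jax.Array:
    """One Game of Life update on a 0/1 int32 grid that wraps around at the edges."""
    # Count the 8 neighbours of every cell by summing shifted copies of the grid.
    neighbors = sum(
        jnp.roll(grid, shift=(dy, dx), axis=(0, 1))
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
        if (dy, dx) != (0, 0)
    )
    # Birth on exactly 3 neighbours, survival on 2 or 3.
    new_alive = (neighbors == 3) | ((grid == 1) & (neighbors == 2))
    return new_alive.astype(grid.dtype)


jitted_step = jax.jit(step)


def run(grid: jax.Array, n_iter: int) -> jax.Array:
    # Eager: every jnp call inside `step` is dispatched as a separate operation.
    for _ in range(n_iter):
        grid = step(grid)
    return grid


def run_jit(grid: jax.Array, n_iter: int) -> jax.Array:
    # Compiled: each call to `jitted_step` runs as a single compiled XLA computation.
    for _ in range(n_iter):
        grid = jitted_step(grid)
    return grid


key = jax.random.PRNGKey(0)
grid = jax.random.bernoulli(key, 0.5, (256, 256)).astype(jnp.int32)

# Warm-up call so that compilation time is not included in the measurement.
run_jit(grid, 1).block_until_ready()

t_eager = timeit.timeit(lambda: run(grid, 20).block_until_ready(), number=3)
t_jit = timeit.timeit(lambda: run_jit(grid, 20).block_until_ready(), number=3)
print(f"eager: {t_eager:.3f} s, jit: {t_jit:.3f} s (3 runs of 20 steps each)")
```

The size of the gap depends on the hardware, the grid size and the number of iterations, so it is worth varying all three.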
- Python 3.12.7
- JAX and jaxtyping (follow their official installation guides)
- An NVIDIA GPU for the benchmarks
- NumPy, matplotlib and timeit for benchmarking
- pytest and pre-commit for development
- Clone the repository:
git clone https://github.com/yourusername/game-of-life-jax.git
cd game-of-life-jax
- Create and activate a virtual environment:
python -m venv .venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
- Install dependencies. The provided requirements.txt reflects my own setup, so you will most likely need to adjust library versions to match your hardware (especially JAX, whose GPU build depends on your CUDA installation).
pip install -r requirements.txt
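Because the GPU build of JAX depends on your CUDA setup, it can be worth checking that JAX actually sees the GPU before running the benchmarks. A quick check using JAX's device-query functions:

```python
import jax

# Devices JAX can use; with a working CUDA-enabled install this lists GPU devices.
print(jax.devices())

# The backend used by default, e.g. "cpu" or "gpu".
print(jax.default_backend())
```

If only CPU devices are listed, JAX has fallen back to its CPU backend and the GPU benchmarks will not use the GPU.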
The core functionality is implemented in src/game_of_life.py. Here's a basic example:
from src.game_of_life import generate_initial_grid, compiled_gameoflife
# Generate a random initial grid
grid = generate_initial_grid(size=100, proportion_alive=0.5)
# Run the simulation for 100 iterations using the compiled version
result = compiled_gameoflife(grid, n_iter=100)
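The behaviour of generate_initial_grid is only known from the call above. As a hedged illustration, a comparable random starting grid could be built with JAX's random module, assuming a square board of 0/1 integers where each cell is alive independently with probability proportion_alive. The helper name `random_grid` and its `seed` parameter are hypothetical; the repository's version may differ.

```python
import jax
import jax.numpy as jnp


def random_grid(size: int, proportion_alive: float, seed: int = 0) -> jax.Array:
    """Hypothetical helper: a size x size board of 0/1 int32 cells."""
    # JAX randomness is explicit: the same seed always yields the same board.
    key = jax.random.PRNGKey(seed)
    # Each cell is independently alive with probability `proportion_alive`.
    alive = jax.random.bernoulli(key, p=proportion_alive, shape=(size, size))
    return alive.astype(jnp.int32)


grid = random_grid(size=100, proportion_alive=0.5)
print(grid.shape, float(grid.mean()))
```

Deriving the key from an explicit seed makes runs reproducible, which helps when comparing variants on the same starting board.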
The implementation uses three main functions:
- augment_grid: Adds padding around the grid to handle edge cases
- compute_neighbors: Efficiently calculates neighbor sums using JAX's roll operations
- next_turn: Applies Conway's rules using boolean operations
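A minimal sketch of how these three building blocks can fit together follows. It is a hypothetical re-implementation, not the code in src/game_of_life.py: it assumes 0/1 integer grids and zero padding, so cells beyond the edge count as dead. Padding matters because jnp.roll wraps around; with a one-cell border of zeros, anything that wraps lands in the border, which is cropped away afterwards.

```python
import jax
import jax.numpy as jnp


def augment_grid(grid: jax.Array) -> jax.Array:
    # Hypothetical: pad with a one-cell border of dead (0) cells.
    return jnp.pad(grid, 1, constant_values=0)


def compute_neighbors(padded: jax.Array) -> jax.Array:
    # Sum the 8 shifted copies of the padded grid, then crop the border so that
    # each remaining cell holds the number of its live neighbours.
    total = jnp.zeros_like(padded)
    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            if (dy, dx) != (0, 0):
                total = total + jnp.roll(padded, shift=(dy, dx), axis=(0, 1))
    return total[1:-1, 1:-1]


def next_turn(grid: jax.Array) -> jax.Array:
    neighbors = compute_neighbors(augment_grid(grid))
    alive = grid.astype(bool)
    # Conway's rules: a cell with exactly 3 neighbours is alive next turn,
    # a live cell with 2 neighbours survives, and every other cell is dead.
    new_alive = (neighbors == 3) | (alive & (neighbors == 2))
    return new_alive.astype(grid.dtype)


blinker = jnp.zeros((5, 5), dtype=jnp.int32).at[2, 1:4].set(1)
print(next_turn(blinker))  # the horizontal bar of three cells turns vertical
```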
The implementation offers three variants:
- unoptimized_gameoflife: Traditional Python implementation using loops
- optimized_gameoflife: Vectorized implementation using JAX operations, without JIT compilation
- compiled_gameoflife: JIT-compiled version for maximum performance
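To illustrate how the three variants differ, here is a self-contained sketch. The helper names (`loop_step`, `vectorized_step`, `loop_run`, `eager_run`, `compiled_run`) are hypothetical and the board is assumed to be surrounded by dead cells; the real unoptimized_gameoflife, optimized_gameoflife and compiled_gameoflife may be structured differently. The compiled version here uses jax.lax.fori_loop so that the whole iteration loop, not just one step, is traced once and compiled into a single XLA computation.

```python
import jax
import jax.numpy as jnp
import numpy as np


def loop_step(grid: np.ndarray) -> np.ndarray:
    """Unoptimized: visit every cell and count its neighbours with Python loops."""
    n_rows, n_cols = grid.shape
    new = np.zeros_like(grid)
    for i in range(n_rows):
        for j in range(n_cols):
            count = 0
            for di in (-1, 0, 1):
                for dj in (-1, 0, 1):
                    r, c = i + di, j + dj
                    # Skip the cell itself and positions outside the board (dead).
                    if (di, dj) != (0, 0) and 0 <= r < n_rows and 0 <= c < n_cols:
                        count += int(grid[r, c])
            if count == 3 or (grid[i, j] == 1 and count == 2):
                new[i, j] = 1
    return new


def vectorized_step(grid: jax.Array) -> jax.Array:
    """Optimized: the same rules expressed as whole-array JAX operations."""
    padded = jnp.pad(grid, 1)  # border of dead cells
    neighbors = sum(
        jnp.roll(padded, shift=(di, dj), axis=(0, 1))
        for di in (-1, 0, 1)
        for dj in (-1, 0, 1)
        if (di, dj) != (0, 0)
    )[1:-1, 1:-1]
    return ((neighbors == 3) | ((grid == 1) & (neighbors == 2))).astype(grid.dtype)


def loop_run(grid: np.ndarray, n_iter: int) -> np.ndarray:
    for _ in range(n_iter):
        grid = loop_step(grid)
    return grid


def eager_run(grid: jax.Array, n_iter: int) -> jax.Array:
    # Each jnp call is dispatched as a separate operation; no whole-step compilation.
    for _ in range(n_iter):
        grid = vectorized_step(grid)
    return grid


@jax.jit
def compiled_run(grid: jax.Array, n_iter: int) -> jax.Array:
    # fori_loop keeps the whole simulation inside one compiled XLA computation.
    return jax.lax.fori_loop(0, n_iter, lambda _, g: vectorized_step(g), grid)


rng = np.random.default_rng(0)
start = (rng.random((32, 32)) < 0.5).astype(np.int32)
reference = loop_run(start, 10)
print(np.array_equal(reference, np.asarray(eager_run(jnp.asarray(start), 10))))
print(np.array_equal(reference, np.asarray(compiled_run(jnp.asarray(start), 10))))
```

Because n_iter is passed as a traced value rather than a static argument, calling compiled_run with a different iteration count reuses the same compiled program instead of triggering a recompilation.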
Run the test suite using pytest:
pytest -v
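The contents of the test suite are not reproduced here. As an illustration of the kind of check it could contain, the sketch below tests two well-known patterns, the 2x2 block (a still life) and the blinker (period 2), against a hypothetical `step` function standing in for next_turn, with cells outside the board treated as dead. pytest collects the `test_*` functions automatically.

```python
import jax.numpy as jnp


def step(grid):
    # Hypothetical stand-in for the project's next_turn (dead cells outside the board).
    padded = jnp.pad(grid, 1)
    neighbors = sum(
        jnp.roll(padded, shift=(dy, dx), axis=(0, 1))
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
        if (dy, dx) != (0, 0)
    )[1:-1, 1:-1]
    return ((neighbors == 3) | ((grid == 1) & (neighbors == 2))).astype(grid.dtype)


def test_block_is_still_life():
    # A 2x2 block never changes.
    grid = jnp.zeros((4, 4), dtype=jnp.int32).at[1:3, 1:3].set(1)
    assert jnp.array_equal(step(grid), grid)


def test_blinker_has_period_two():
    # A horizontal bar of three cells flips to vertical, then back.
    grid = jnp.zeros((5, 5), dtype=jnp.int32).at[2, 1:4].set(1)
    once = step(grid)
    assert jnp.array_equal(once, grid.T)
    assert jnp.array_equal(step(once), grid)
```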
This project is licensed under the MIT License - see the LICENSE file for details.
- rayanehmi@github