High-performance optimization toolkit for Iterated Function Systems (IFS) using JAX.
This repository provides a complete framework for optimizing IFS parameters to match target distributions. It includes multiple distance metrics (Sinkhorn/Wasserstein and MMD), efficient fixed measure computation, surrogate gradient methods, and a comprehensive optimization orchestrator.
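For readers new to IFS, the sketch below illustrates what a "fixed measure" is: the distribution that a set of affine maps leaves invariant, approximated here by chaos-game sampling. It is a minimal illustration, not the method used in `optimizer/fixed_measure/`; the function `sample_ifs`, its arguments, and the three-map example are hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_ifs(key, A, b, probs, n_points=2000, burn_in=50):
    """Chaos game: repeatedly apply a randomly chosen affine map x -> A[i] @ x + b[i].

    A: (n_maps, d, d) linear parts, b: (n_maps, d) offsets, probs: (n_maps,) map weights.
    Returns n_points samples that approximate the IFS fixed (invariant) measure.
    """
    steps = n_points + burn_in
    # Draw the whole sequence of map indices up front.
    idx = jax.random.choice(key, A.shape[0], shape=(steps,), p=probs)

    def step(x, i):
        x_next = A[i] @ x + b[i]
        return x_next, x_next

    x0 = jnp.zeros(A.shape[1])
    _, trajectory = jax.lax.scan(step, x0, idx)
    return trajectory[burn_in:]  # drop the first iterates of the chain


# Hypothetical example: three contractions toward the corners of a right triangle.
vertices = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
A = jnp.stack([0.5 * jnp.eye(2)] * 3)
b = 0.5 * vertices
probs = jnp.ones(3) / 3.0

points = sample_ifs(jax.random.PRNGKey(0), A, b, probs)
print(points.shape)
```

Optimizing an IFS then amounts to adjusting `A`, `b`, and `probs` so that samples like `points` match a target distribution under one of the distance metrics listed below.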
- 🚀 High Performance: JAX-optimized, with a 68x speedup over the old implementation on cached (already compiled) runs
- 📊 Multiple Distance Metrics: Sinkhorn (Wasserstein) and MMD support
- 🔧 Optimization-Ready: Warm-starting and minimal recompilation overhead
- 📦 Modular Design: Separate packages for different components
- ✅ Well-Tested: Comprehensive test suite with benchmarks
- 🎯 Complete Pipeline: From IFS parameters to optimized distributions
IFSOpt/
├── optimizer/ # Main optimizer package (finished product)
│ ├── fixed_measure/ # Fixed measure computation
│ ├── parameter_update/ # Parameter gradient computation
│ ├── surrogate_gradients/ # Surrogate gradient methods
│ └── ifs_opt/ # Main orchestrator
│
├── Sinkhorn/ # Sinkhorn distance package
│ └── src/ifst/ # Custom Sinkhorn implementation
│
├── MMD/ # MMD distance package
│ └── src/mmd/ # FFT-accelerated MMD
│
├── examples/ # Usage examples
├── tests/ # Test suite
│
└── Content/
├── Testing/ # Notebooks and experiments
└── Code-Notes/ # Development documentation
See examples/ for detailed usage examples:
- examples/full_pipeline_example.py - Complete optimization pipeline
- examples/mmd_loss_example.py - Using MMD distance
- examples/compare_sinkhorn_vs_mmd.py - Comparing distance metrics
Measured on GPU (NVIDIA):
| Scenario | Time | Speedup |
|---|---|---|
| Old implementation | 1.428s | 1x (baseline) |
| New implementation | 0.181s | 7.9x |
| Cached execution | 0.017s | 68x |
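The gap between a first call and a cached call is typical of JAX: a jitted function is traced and compiled the first time it sees a given combination of input shapes and dtypes, and later calls with the same combination reuse the compiled code. The sketch below shows how this can be observed. It does not reproduce the benchmark above; `pairwise_cost`, `timed`, and `trace_log` are hypothetical names used only for illustration.

```python
import time

import jax
import jax.numpy as jnp

trace_log = []  # filled only while JAX traces the function, not on cached calls


@jax.jit
def pairwise_cost(x, y):
    trace_log.append(x.shape)  # Python side effect: runs at trace time only
    diff = x[:, None, :] - y[None, :, :]
    return jnp.sum(diff ** 2, axis=-1)


def timed(fn, *args):
    start = time.perf_counter()
    out = fn(*args).block_until_ready()  # wait for asynchronous dispatch to finish
    return out, time.perf_counter() - start


x = jnp.ones((512, 2))
y = jnp.zeros((512, 2))

_, t_first = timed(pairwise_cost, x, y)           # tracing + compilation + execution
_, t_cached = timed(pairwise_cost, x, y + 1.0)    # same shapes and dtypes: cached
_, t_reshaped = timed(pairwise_cost, x[:256], y)  # new input shape: compiled again

print(f"first call:  {t_first:.4f}s")
print(f"cached call: {t_cached:.4f}s")
print(f"new shape:   {t_reshaped:.4f}s")
print("traced shapes:", trace_log)
```

Keeping array shapes fixed across optimization steps is the usual way to stay on the cached path.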
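No installation step is required for the package itself: it is self-contained and runs directly from a checkout of the repository once the requirements below are available.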
Requirements:
- Python 3.8+
- JAX with GPU support (recommended)
- NumPy, Matplotlib
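Before running the examples, it can help to confirm which backend JAX actually picked up, since a CPU-only install will run the code but will not match the GPU timings above. A small check using only standard JAX calls:

```python
import jax

devices = jax.devices()
backend = jax.default_backend()

print(f"JAX {jax.__version__}, backend '{backend}', {len(devices)} device(s): {devices}")
if backend == "cpu":
    print("No accelerator detected; the code should still run, only more slowly.")
```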
Verify setup:
python examples/test_correctness.py

The optimizer/ package is the main optimization framework, combining all components. See optimizer/README.md for details.
The Sinkhorn/ package is a custom Sinkhorn implementation for computing optimal transport distances, with O(d² log d) complexity instead of O(d³).
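As background, the sketch below shows the textbook log-domain Sinkhorn iteration for entropy-regularized optimal transport between two discrete measures. It uses a dense cost matrix, so it does not have the reduced complexity of the custom implementation in `Sinkhorn/src/ifst/`; the function name `sinkhorn_log`, its parameters, and the example data are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn_log(a, b, C, eps=0.25, n_iters=200):
    """Entropic OT between weights a (n,) and b (m,) with cost matrix C (n, m).

    Returns the transport cost <P, C> and the plan P.
    """
    log_a, log_b = jnp.log(a), jnp.log(b)

    def body(_, potentials):
        f, g = potentials
        # Alternating dual updates, each enforcing one marginal constraint.
        f = -eps * logsumexp((g[None, :] - C) / eps + log_b[None, :], axis=1)
        g = -eps * logsumexp((f[:, None] - C) / eps + log_a[:, None], axis=0)
        return f, g

    init = (jnp.zeros_like(a), jnp.zeros_like(b))
    f, g = jax.lax.fori_loop(0, n_iters, body, init)

    log_P = (f[:, None] + g[None, :] - C) / eps + log_a[:, None] + log_b[None, :]
    P = jnp.exp(log_P)
    return jnp.sum(P * C), P


# Hypothetical example: uniform source weights, bump-shaped target weights on [0, 1].
x = jnp.linspace(0.0, 1.0, 40)
y = jnp.linspace(0.0, 1.0, 60)
a = jnp.ones(40) / 40.0
b = jnp.exp(-((y - 0.7) ** 2) / 0.05)
b = b / b.sum()
C = (x[:, None] - y[None, :]) ** 2

cost, P = sinkhorn_log(a, b, C)
print("regularized transport cost:", float(cost))
```

Working with log-potentials rather than scaling vectors keeps the iteration numerically stable when `eps` is small relative to the cost values.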
The MMD/ package provides FFT-accelerated Maximum Mean Discrepancy (MMD) computation as an efficient distance metric. See MMD/README.md for details.
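One common way an FFT can speed up MMD, sketched here under assumptions, is to treat both measures as histograms on the same uniform grid: with a translation-invariant kernel the Gram matrix is Toeplitz, so its product with a vector is a convolution. Whether the `mmd` package uses this exact scheme is not stated here; `mmd2_fft`, `mmd2_dense`, the 1D grid, and the Gaussian kernel are illustrative choices.

```python
import jax.numpy as jnp


def mmd2_fft(p, q, spacing, sigma):
    """Squared MMD between histograms p and q on a uniform 1D grid (Gaussian kernel)."""
    n = p.shape[0]
    v = p - q
    # Kernel sampled at every possible grid offset, from -(n-1) to +(n-1).
    offsets = jnp.arange(-(n - 1), n) * spacing
    kernel = jnp.exp(-(offsets ** 2) / (2.0 * sigma ** 2))
    # Zero-padded FFTs give the full linear convolution (length 3n - 2).
    size = 3 * n - 2
    conv = jnp.fft.irfft(jnp.fft.rfft(kernel, n=size) * jnp.fft.rfft(v, n=size), n=size)
    Kv = conv[n - 1 : 2 * n - 1]  # entries that correspond to the product K @ v
    return jnp.dot(v, Kv)


def mmd2_dense(p, q, grid, sigma):
    """Reference O(n^2) computation with the explicit Gram matrix."""
    K = jnp.exp(-((grid[:, None] - grid[None, :]) ** 2) / (2.0 * sigma ** 2))
    v = p - q
    return v @ K @ v


n = 256
grid = jnp.linspace(0.0, 1.0, n)
spacing = 1.0 / (n - 1)

p = jnp.exp(-((grid - 0.3) ** 2) / 0.01)
p = p / p.sum()
q = jnp.exp(-((grid - 0.6) ** 2) / 0.02)
q = q / q.sum()

fast = mmd2_fft(p, q, spacing, sigma=0.1)
dense = mmd2_dense(p, q, grid, sigma=0.1)
print("FFT:", float(fast), "dense:", float(dense))
```

The FFT version costs O(n log n) per evaluation instead of O(n²), and both functions should agree up to floating-point error.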
- optimizer/README.md - Main optimizer package documentation
- MMD/README.md - MMD package documentation
- Content/Code-Notes/ - Development notes: implementation guides, integration notes, and design documents
See the examples/ directory for complete usage examples:
- full_pipeline_example.py - End-to-end optimization pipeline
- mmd_loss_example.py - Using MMD distance metric
- compare_sinkhorn_vs_mmd.py - Comparing distance metrics
- parameter_update_example.py - Parameter gradient computation
- surrogate_gradient_example.py - Surrogate gradient methods
Run the tests in the tests/ directory from the repository root:
# Run specific tests
python tests/test_F_updates.py
python tests/test_mmd_module.py
python tests/benchmark_surrogate_gradients.py

All test files have been consolidated into the tests/ directory for easy access.
This repository is organized into three main areas:
- Production Code: optimizer/, Sinkhorn/, MMD/ - Finished, optimized packages
- Examples & Tests: examples/, tests/ - Usage examples and test suite
- Development: Content/ - Notebooks, experiments, and documentation
This is research code. For modifications:
- Review design documents in Content/Code-Notes/
- Maintain test coverage in tests/
- Add examples to examples/ for new features
- JAX team for the excellent automatic differentiation framework
- Research group for feedback and testing
Status: Active Development | Core components complete
Last Updated: 2025-10-09