This repository has been replaced by mosaix-pde.
Please use the new repository for all code, documentation, and issues.
The old code here is preserved for reference only.
pat-pde-opt is a package, written in JAX, for optimizing pattern-forming PDEs that appear in different areas of physics.
It has code for PDE optimization and control with gradient-based methods and reinforcement learning.
We use diffrax for time stepping and implement system-specific solvers, such as semi-implicit Fourier methods and Strang splitting.
The full documentation is available on Read the Docs.
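Strang splitting is a common way to integrate Gross-Pitaevskii-type equations. The potential and nonlinear terms are advanced exactly in real space for half a step, the kinetic term is advanced exactly in Fourier space for a full step, and then another real-space half step follows. The sketch below shows the idea in plain JAX for a 1D Gross-Pitaevskii equation in dimensionless units (hbar = m = 1) with a harmonic trap. The function name `make_strang_step`, the 1D setting, and all parameter values are illustrative assumptions. This is not the package's own solver.

```python
import jax
import jax.numpy as jnp


def make_strang_step(n, length, V, g, dt):
    """One Strang-split step for i psi_t = -0.5 psi_xx + (V + g|psi|^2) psi (periodic grid)."""
    dx = length / n
    k = 2.0 * jnp.pi * jnp.fft.fftfreq(n, d=dx)
    kinetic_phase = jnp.exp(-0.5j * dt * k**2)  # exact kinetic propagator over a full step

    def half_potential(psi):
        # Exact for this sub-step: multiplying by a phase leaves |psi| unchanged
        return psi * jnp.exp(-0.5j * dt * (V + g * jnp.abs(psi) ** 2))

    def step(psi):
        psi = half_potential(psi)                             # half step in real space
        psi = jnp.fft.ifft(kinetic_phase * jnp.fft.fft(psi))  # full step in Fourier space
        return half_potential(psi)                            # second half step

    return jax.jit(step)


n, length = 256, 20.0
dx = length / n
x = jnp.linspace(-length / 2, length / 2, n, endpoint=False)
V = 0.5 * x**2  # harmonic trap
psi0 = jnp.exp(-(x**2)).astype(jnp.complex64)
psi0 = psi0 / jnp.sqrt(jnp.sum(jnp.abs(psi0) ** 2) * dx)  # normalize to unit norm

step = make_strang_step(n, length, V, g=1.0, dt=1e-3)
psi = jax.lax.fori_loop(0, 1000, lambda i, p: step(p), psi0)
norm = jnp.sum(jnp.abs(psi) ** 2) * dx
print(float(norm))  # stays close to 1 because every sub-step is unitary
```

Because each sub-step is solved exactly, the only splitting error comes from the kinetic and potential terms not commuting. The symmetric half/full/half ordering makes that error second order in `dt`.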
To install the package, we recommend cloning the GitHub repository and installing it locally:
```bash
git clone https://github.com/acoh64/pde-opt.git
cd pde-opt
conda create -y -n pde-opt-env python=3.12
conda activate pde-opt-env
pip install -e .
```
By default, this installs the CPU version of JAX. To run on an NVIDIA GPU with CUDA 12, run:
```bash
pip install -U "jax[cuda12]"
```
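A quick way to confirm which backend JAX picked up after installation is to list its devices. This snippet uses only standard JAX calls and does not depend on pde_opt.

```python
import jax

# Devices visible to JAX; with a working CUDA 12 install this should list a GPU device
devices = jax.devices()
print(devices)

# Name of the default backend, e.g. "cpu" or "gpu"
print("Default backend:", jax.default_backend())
```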
Here is an example of solving the Cahn-Hilliard equation in 2D with periodic boundary conditions using a semi-implicit Fourier method:
```python
import jax
import jax.numpy as jnp
from pde_opt import PDEModel, CahnHilliard2DPeriodic, SemiImplicitFourierSpectral, Domain
from pde_opt import PeriodicCNN  # used in the training example below

# 128 x 128 periodic grid with spacing 0.01, centered at the origin
Nx = Ny = 128
Lx = Ly = 0.01 * Nx
domain = Domain((Nx, Ny), ((-Lx / 2, Lx / 2), (-Ly / 2, Ly / 2)), "dimensionless")
opt_model = PDEModel(equation_type=CahnHilliard2DPeriodic, domain=domain, solver_type=SemiImplicitFourierSpectral)

params = {
    "kappa": 0.002,  # gradient-energy coefficient
    # bulk chemical potential: ideal-mixing entropy plus a regular-solution interaction term
    "mu": lambda c: jnp.log(c / (1.0 - c)) + 3.0 * (1.0 - 2.0 * c),
    "D": lambda c: c * (1.0 - c),  # concentration-dependent mobility
}
solver_params = {"A": 0.5}

# Initial condition: small random fluctuations around c = 0.5, clipped to [0, 1]
key = jax.random.PRNGKey(0)
y0 = jnp.clip(0.01 * jax.random.normal(key, (Nx, Ny)) + 0.5, 0.0, 1.0)

# Times at which the solution is returned (100 points from t = 0 to t = 0.02)
ts = jnp.linspace(0.0, 0.02, 100)
sol = opt_model.solve(params, y0, ts, solver_params, dt0=0.000001, max_steps=1000000)
```
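For intuition about what a semi-implicit Fourier solver does, the sketch below implements one common variant. It is a stabilized semi-implicit spectral step for the variable-mobility Cahn-Hilliard equation c_t = div(D(c) grad mu), with mu = mu_bulk(c) - kappa * laplacian(c), and it reuses the same `mu`, `D` and `kappa` as above on a smaller grid. The code is written from scratch in plain JAX and is not the code behind `SemiImplicitFourierSpectral`. Two choices are our assumptions: reading `A` as the coefficient of a stabilizing kappa*k^4 term treated implicitly, and using a fixed, hand-chosen time step. The helper name `make_ch_step` is hypothetical.

```python
import jax
import jax.numpy as jnp


def make_ch_step(n, length, kappa, dt, A, mu_bulk, mobility):
    """Stabilized semi-implicit Fourier step for c_t = div(M(c) grad(mu)),
    mu = mu_bulk(c) - kappa * laplacian(c), on an n x n periodic grid."""
    dx = length / n
    k1 = 2.0 * jnp.pi * jnp.fft.fftfreq(n, d=dx)
    kx, ky = jnp.meshgrid(k1, k1, indexing="ij")
    k2 = kx**2 + ky**2
    stab = A * dt * kappa * k2**2  # stabilizing k^4 term, added to both sides

    def step(c):
        c_hat = jnp.fft.fft2(c)
        # Chemical potential, explicit in time; -laplacian becomes +k^2 in Fourier space
        mu = mu_bulk(c) + kappa * jnp.real(jnp.fft.ifft2(k2 * c_hat))
        mu_hat = jnp.fft.fft2(mu)
        # Flux M(c) grad(mu): gradient taken spectrally, mobility applied in real space
        m = mobility(c)
        jx = m * jnp.real(jnp.fft.ifft2(1j * kx * mu_hat))
        jy = m * jnp.real(jnp.fft.ifft2(1j * ky * mu_hat))
        div_j_hat = 1j * kx * jnp.fft.fft2(jx) + 1j * ky * jnp.fft.fft2(jy)
        # Solve (1 + stab) c_hat_new = c_hat + dt * div_j_hat + stab * c_hat
        c_hat_new = (c_hat + dt * div_j_hat + stab * c_hat) / (1.0 + stab)
        return jnp.real(jnp.fft.ifft2(c_hat_new))

    return jax.jit(step)


n, length = 64, 0.64  # grid spacing 0.01, as in the example above
step = make_ch_step(
    n, length, kappa=0.002, dt=1e-6, A=0.5,
    mu_bulk=lambda c: jnp.log(c / (1.0 - c)) + 3.0 * (1.0 - 2.0 * c),
    mobility=lambda c: c * (1.0 - c),
)
c0 = jnp.clip(0.01 * jax.random.normal(jax.random.PRNGKey(0), (n, n)) + 0.5, 0.0, 1.0)
c = jax.lax.fori_loop(0, 200, lambda i, c: step(c), c0)
print(float(jnp.mean(c0)), float(jnp.mean(c)))  # the mean concentration is conserved
```

The update never changes the k = 0 Fourier mode, so the mean concentration is conserved up to round-off. This makes a quick sanity check for any Cahn-Hilliard solver.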
Next, here is an example of using the previous solution as a dataset to fit a neural network for the chemical potential term:
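```python
# Package the simulated trajectory (solution snapshots at times ts) as training data
data = {"ys": sol, "ts": ts}

# Neural network that replaces the analytic mu(c) from the previous example
model = PeriodicCNN(
    in_channels=1,
    hidden_channels=(32, 64, 64),
    out_channels=1,
    kernel_size=3,
    key=jax.random.PRNGKey(0),
)

init_params = {"mu": model}  # term to be learned
static_params = {"kappa": 0.002, "D": lambda c: c * (1.0 - c)}  # terms held fixed
solver_parameters = {"A": 0.5}
weights = {"mu": None}
lambda_reg = 0.0  # regularization coefficient, set to zero here
# Index groups into ts/sol selecting the snapshots used in training
inds = [[30, 40, 50], [50, 60, 70], [70, 80, 90]]

res = opt_model.train(
    data, inds, init_params, static_params, solver_parameters, weights, lambda_reg,
    method="mse", max_steps=100,
)
```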
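The `method="mse"` option suggests fitting by minimizing the mean squared error between simulated and observed snapshots. The toy sketch below shows that idea end to end in plain JAX. It recovers the diffusion coefficient of a 1D periodic heat equation (solved exactly in Fourier space) by gradient descent on an MSE loss over a few selected time indices. It does not call `opt_model.train`. The equation, the names `solve_heat`, `mse_loss` and `train_inds`, and the learning rate are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

n = 64
x = jnp.linspace(0.0, 2.0 * jnp.pi, n, endpoint=False)
k = jnp.fft.fftfreq(n, d=1.0 / n)  # integer wavenumbers on a 2*pi-periodic grid


def solve_heat(D, c0, ts):
    """Exact spectral solution of c_t = D c_xx on a periodic domain, returned at times ts."""
    c0_hat = jnp.fft.fft(c0)
    decay = jnp.exp(-D * k[None, :] ** 2 * ts[:, None])
    return jnp.real(jnp.fft.ifft(c0_hat[None, :] * decay, axis=-1))


c0 = jnp.sin(x) + 0.5 * jnp.sin(3.0 * x)
ts = jnp.linspace(0.0, 1.0, 11)
data_ys = solve_heat(0.5, c0, ts)   # synthetic observations with D_true = 0.5
train_inds = jnp.array([2, 5, 8])   # hypothetical choice of time indices used in the loss


def mse_loss(log_D):
    # Simulate from the first observed snapshot and compare at the selected times
    pred = solve_heat(jnp.exp(log_D), data_ys[0], ts)
    return jnp.mean((pred[train_inds] - data_ys[train_inds]) ** 2)


loss_and_grad = jax.jit(jax.value_and_grad(mse_loss))

log_D = jnp.array(0.0)  # initial guess D = 1.0 (optimizing log D keeps D positive)
initial_loss = float(mse_loss(log_D))
for _ in range(100):
    _, g = loss_and_grad(log_D)
    log_D = log_D - 10.0 * g  # plain gradient descent with learning rate 10
D_fit = float(jnp.exp(log_D))
final_loss = float(mse_loss(log_D))
print(D_fit)
```

Because the whole forward solve is written in JAX, `jax.grad` differentiates straight through the FFTs. The same principle lets gradient-based methods fit PDE parameters or neural-network terms to trajectory data.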
This package is designed to support pattern-forming PDEs across a wide range of physical systems. We have currently implemented variants of the following equations:
- Cahn-Hilliard equation
  - 2D with periodic boundary conditions
  - 3D with periodic boundary conditions
  - 2D with smoothed boundary method
- Allen-Cahn equation
  - 2D with periodic boundary conditions
  - 2D with constant current conditions + Butler-Volmer kinetics (for battery applications)
  - 2D with smoothed boundary method
  - 2D with smoothed boundary method and constant current conditions + Butler-Volmer kinetics (for battery applications)
- Gross-Pitaevskii equation
  - Reduced 2D with periodic boundary conditions
  - Rotating, reduced 2D with periodic boundary conditions
Planned features:
- Arbitrary boundary conditions
- Implicit time stepping
- Multi-GPU support
- Extension to non-Cartesian domains
- WandB logging and checkpointing
This code has been published under the MIT licence.