
Build scientific simulators, treating them as a directed acyclic graph. Handles argument passing for complex nested simulators.
pip install caskade
More details on the docs page.
If you want to use caskade with jax, then run:
pip install caskade[jax]
Alternatively, just pip install jax/jaxlib separately, as they are the only extra requirements.
Make a Module object, which may have some Params. Define a forward method using the @forward decorator.
from caskade import Module, Param, forward

class MySim(Module):
    def __init__(self, a, b=None):
        super().__init__()
        self.a = a
        self.b = Param("b", b)

    @forward
    def myfun(self, x, b=None):
        return x + self.a + b
We may now create instances of the simulator and pass the dynamic parameters.
import torch
sim = MySim(1.0)
params = [torch.tensor(2.0)]
print(sim.myfun(3.0, params=params))
This will print 6 by automatically filling b with the value from params.
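To make the argument filling less mysterious, here is a toy sketch of how a decorator could fill a keyword argument from a params list. It is not caskade's implementation: ToyParam, toy_forward, and ToySim are hypothetical names invented for this illustration, and the sketch assumes that a parameter created with a value of None is the one to be filled at call time.

```python
import functools
import inspect


class ToyParam:
    """Hypothetical stand-in for a parameter slot; not caskade's Param."""

    def __init__(self, name, value=None):
        self.name = name
        self.value = value  # None marks the parameter as dynamic


def toy_forward(method):
    """Fill keyword arguments whose names match ToyParam attributes on self."""
    signature = inspect.signature(method)

    @functools.wraps(method)
    def wrapper(self, *args, params=None, **kwargs):
        dynamic_values = iter(params or [])
        for name in signature.parameters:
            slot = getattr(self, name, None)
            if not isinstance(slot, ToyParam) or name in kwargs:
                continue
            # Fixed parameters use their stored value; dynamic ones
            # consume the next entry of the params list.
            if slot.value is not None:
                kwargs[name] = slot.value
            else:
                kwargs[name] = next(dynamic_values)
        return method(self, *args, **kwargs)

    return wrapper


class ToySim:
    def __init__(self, a, b=None):
        self.a = a
        self.b = ToyParam("b", b)

    @toy_forward
    def myfun(self, x, b=None):
        return x + self.a + b


sim = ToySim(1.0)
result = sim.myfun(3.0, params=[2.0])
print(result)
```

The toy only handles a single, non-nested object. The real package additionally walks nested Module objects, which this sketch does not attempt.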
The above example is not very impressive; the real power comes from the fact that Module objects can be nested, making an arbitrarily complicated analysis graph. Some other features include:
- Unroll parameters into 1D vector to interface with other packages (emcee, scipy.optimize, dynesty, etc.)
- Link parameters by value or functional relationship
- Reparametrize (e.g. between polar and cartesian) without modifying underlying code
- Save and load sampling chains automatically in HDF5
- Track metadata alongside parameters
- And much more! See the Beginner tutorial and the Advanced tutorial.
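As a rough illustration of the first item in the list above, the following standalone sketch shows the general idea of unrolling named parameters into a flat 1D vector, which is the form that samplers and optimizers usually expect. It does not use caskade at all: pack, unpack, and log_prob are hypothetical helpers written for this example, and the parameter names are made up.

```python
def pack(named_values):
    """Unroll {name: list of floats} into one flat list plus a layout."""
    flat, layout = [], []
    for name, values in named_values.items():
        layout.append((name, len(values)))
        flat.extend(values)
    return flat, layout


def unpack(flat, layout):
    """Rebuild the {name: list of floats} mapping from a flat list."""
    named_values, start = {}, 0
    for name, size in layout:
        named_values[name] = list(flat[start:start + size])
        start += size
    return named_values


# Hypothetical parameters of a small model.
named = {"center": [0.5, -1.0], "scale": [2.0]}
flat, layout = pack(named)


def log_prob(flat_vector):
    """Objective with the flat-vector signature an external sampler would call."""
    p = unpack(flat_vector, layout)
    cx, cy = p["center"]
    (s,) = p["scale"]
    return -(cx**2 + cy**2) / s**2


print(flat)
print(log_prob(flat))
```

The point of the design is that the model code keeps working with named parameters, while the external package only ever sees a flat vector.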
caskade can be run with different backends for torch, numpy, and jax. See the Beginners Guide tutorial to learn more!
The caskade interface has lots of flexibility; check out the docs to learn more. For a quick start, jump right to the Jupyter notebook tutorial!
The caustics package can serve as a project template utilizing the many features of caskade.
The caskade package maintains 100% coverage for unit testing, ensuring reliability as the backbone of a research project.