Description
There are several places in the cryoJAX image simulation pipeline that do repeated computation on 2D arrays, or do computation at runtime that should instead be in pre-compute.
Examples include:
- Radial coordinate grids are created repeatedly, all at runtime
- The boolean mask here should be in pre-compute: https://github.com/mjo22/cryojax/blob/main/src/cryojax/inference/distributions/_gaussian_distributions.py#L83
- Multiplying coordinate grids by the pixel size at runtime, when the pixel size is typically fixed
These examples point to three more general themes for solutions: 1) runtime code optimizations, 2) ways to absorb current runtime compute into pre-compute, and 3) flexibly absorbing pre-compute into compile time by treating certain parameters as static.
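To make theme 2) concrete, here is a minimal sketch of moving a radial grid and a boolean mask into pre-compute. It is not cryoJAX code: `make_radial_frequency_grid`, `masked_power`, the Gaussian weighting, and the shape and pixel size values are all hypothetical stand-ins. The grid and mask are built once with NumPy and then passed to the jitted function as ordinary arrays.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_radial_frequency_grid(shape, pixel_size):
    """Radial spatial-frequency grid in inverse angstroms, built with NumPy."""
    ky = np.fft.fftfreq(shape[0], d=pixel_size)
    kx = np.fft.rfftfreq(shape[1], d=pixel_size)
    return np.sqrt(ky[:, None] ** 2 + kx[None, :] ** 2).astype(np.float32)


# Pre-compute: runs once, outside any jitted function.
shape, pixel_size = (64, 64), 1.5  # hypothetical values
radial_grid = jnp.asarray(make_radial_frequency_grid(shape, pixel_size))
nyquist_mask = radial_grid <= 0.5 / pixel_size  # boolean mask, reused by every call


@jax.jit
def masked_power(image_ft, radial_grid, mask):
    """Runtime code only consumes the precomputed grid and mask."""
    weights = jnp.exp(-0.5 * (radial_grid * 10.0) ** 2)  # hypothetical weighting
    return jnp.sum(jnp.where(mask, weights * jnp.abs(image_ft) ** 2, 0.0))


image_ft = jnp.fft.rfftn(jnp.ones(shape))  # stand-in for a simulated image
value = masked_power(image_ft, radial_grid, nyquist_mask)
```

With this layout, neither the grid, the pixel size scaling, nor the mask is rebuilt inside the jitted function.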
Related to 3), some extraneous computation will always exist in a naive usage of cryoJAX, because we want every parameter to be transformable with a JAX transformation. For many tasks in cryo-EM, a user will want to transform only a subset of the parameters, and will therefore want fine control over what is computed at compile time versus at runtime, by choosing which parameters are treated as static and which are traced. An elegant solution is to inject `equinox.internal.Static` wrappers into pytree leaves before JIT compilation. To be concrete, setting
    import equinox as eqx

    shape, pixel_size = ...
    instrument_config = InstrumentConfig(shape, pixel_size)
    # Wrap the pixel_size leaf so that it is treated as static under JIT
    instrument_config = eqx.tree_at(lambda config: config.pixel_size, instrument_config, replace_fn=eqx.internal.Static)
should treat the pixel size as static in the current code, so the coordinate grids in angstroms should be computed at compile time rather than at runtime. I think it would be good to write a tutorial that shows how to do this, with benchmarks.
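As a rough illustration of the static-versus-traced distinction, here is a sketch in plain JAX using `static_argnames` rather than the Equinox wrapper. It is not cryoJAX code: `radial_lowpass`, `trace_count`, and the shape and cutoff values are hypothetical. Because `shape` and `pixel_size` are ordinary Python values during tracing, the NumPy grid computation runs at trace time and the result is embedded in the compiled program as a constant.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

trace_count = 0  # incremented each time the function body is traced


@functools.partial(jax.jit, static_argnames=("shape", "pixel_size"))
def radial_lowpass(image_ft, cutoff, *, shape, pixel_size):
    global trace_count
    trace_count += 1
    # shape and pixel_size are static, so this NumPy code runs during tracing.
    ky = np.fft.fftfreq(shape[0], d=pixel_size)
    kx = np.fft.rfftfreq(shape[1], d=pixel_size)
    radial = np.sqrt(ky[:, None] ** 2 + kx[None, :] ** 2).astype(np.float32)
    # cutoff is traced, so it can change between calls without recompiling.
    return jnp.where(jnp.asarray(radial) <= cutoff, image_ft, 0.0)


shape = (32, 32)  # hypothetical image shape
image_ft = jnp.fft.rfftn(jnp.ones(shape))
out_a = radial_lowpass(image_ft, 0.1, shape=shape, pixel_size=2.0)
out_b = radial_lowpass(image_ft, 0.2, shape=shape, pixel_size=2.0)  # reuses the compiled program
out_c = radial_lowpass(image_ft, 0.2, shape=shape, pixel_size=1.0)  # new static value, retraces
```

The trade-off is that every distinct static value triggers a new trace and compilation, so this only pays off for parameters that are fixed for most of a workload, as the pixel size typically is.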
Last, we will also need to profile the code to see what the compiler optimizations are already doing; there are likely places where the compiler optimizes away repeated computation.
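One possible starting point, before reaching for a full profiler, is to compare the traced program with the optimized program that XLA produces. The sketch below uses a hypothetical function `redundant_filter` that deliberately builds the same radial grid twice, then counts square-root operations in the jaxpr and in the compiled HLO text. Whether the compiler merges or constant-folds the duplicate depends on the backend and JAX version, so the counts have to be read off the output rather than assumed.

```python
import jax
import jax.numpy as jnp


def redundant_filter(image_ft):
    """Builds the same radial grid twice, mimicking repeated runtime work."""
    ny, nx = image_ft.shape
    ky = jnp.fft.fftfreq(ny)[:, None]
    kx = jnp.fft.fftfreq(nx)[None, :]
    radial_a = jnp.sqrt(kx**2 + ky**2)
    radial_b = jnp.sqrt(kx**2 + ky**2)  # deliberate duplicate
    return image_ft * jnp.exp(-radial_a) * (radial_b <= 0.25)


image_ft = jnp.ones((16, 16))  # stand-in input

# Program as traced by JAX, before any XLA optimization.
jaxpr_text = str(jax.make_jaxpr(redundant_filter)(image_ft))

# Program after XLA has optimized and compiled it for the current backend.
compiled = jax.jit(redundant_filter).lower(image_ft).compile()
hlo_text = compiled.as_text()

sqrt_in_jaxpr = jaxpr_text.count("sqrt")
sqrt_in_hlo = hlo_text.count("sqrt(")
print(f"sqrt ops traced: {sqrt_in_jaxpr}, sqrt ops after compilation: {sqrt_in_hlo}")
```

A lower count after compilation would suggest the duplicate was removed or folded into a constant; an unchanged count would point to a place where a manual refactor is worth it.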