A JAX-friendly, auto-differentiable, Python-only implementation of correctionlib correction evaluations.
Install the package with pip:

pip install correctionlib-gradients

Using it takes one step:
- construct a CorrectionWithGradient object from a correctionlib.schemav2.Correction
- there is no point 2: you can use CorrectionWithGradient.evaluate as a normal JAX-friendly, auto-differentiable function

For example:
import jax
import jax.numpy as jnp
from correctionlib import schemav2
from correctionlib_gradients import CorrectionWithGradient
# given a correctionlib schema:
formula_schema = schemav2.Correction(
    name="x squared",
    version=2,
    inputs=[schemav2.Variable(name="x", type="real")],
    output=schemav2.Variable(name="a scale", type="real"),
    data=schemav2.Formula(
        nodetype="formula",
        expression="x * x",
        parser="TFormula",
        variables=["x"],
    ),
)
# construct a CorrectionWithGradient
c = CorrectionWithGradient(formula_schema)
# use c.evaluate as a JAX-friendly, auto-differentiable function
value, grad = jax.value_and_grad(c.evaluate)(3.0)
assert jnp.isclose(value, 9.0)
assert jnp.isclose(grad, 6.0)
# for Formula corrections, jax.jit and jax.vmap work too
xs = jnp.array([3.0, 4.0])
values, grads = jax.vmap(jax.jit(jax.value_and_grad(c.evaluate)))(xs)
assert jnp.allclose(values, jnp.array([9.0, 16.0]))
assert jnp.allclose(grads, jnp.array([6.0, 8.0]))

Currently the following corrections from correctionlib.schemav2 are supported:
- Formula, including parametric formulas
- Binning with uniform or non-uniform bin edges and flow="clamp"; bin contents can be either:
  - all scalar values
  - all Formula or FormulaRef
- scalar constants
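To make the flow="clamp" option for Binning concrete, here is a minimal pure-JAX sketch of a binned lookup over explicit edges with scalar contents. It is not how correctionlib-gradients implements Binning: binned_lookup_clamp is a hypothetical helper, and the sketch assumes half-open bins [edges[i], edges[i+1]) and that inputs below the first edge or at/above the last edge are assigned to the first or last bin, respectively.

```python
import jax.numpy as jnp


def binned_lookup_clamp(x, edges, contents):
    """Hypothetical helper (not part of correctionlib-gradients): return the
    content of the bin containing x, with flow="clamp" semantics."""
    edges = jnp.asarray(edges)
    contents = jnp.asarray(contents)
    # side="right" counts the edges <= x, so subtracting 1 gives the index i
    # with edges[i] <= x < edges[i + 1]
    idx = jnp.searchsorted(edges, x, side="right") - 1
    # clamp: underflow goes to the first bin, overflow (x >= last edge) to the last bin
    idx = jnp.clip(idx, 0, contents.shape[0] - 1)
    return contents[idx]


edges = [0.0, 1.0, 2.0, 4.0]  # non-uniform edges -> 3 bins
contents = [10.0, 20.0, 30.0]  # one scalar value per bin
xs = jnp.array([-5.0, 0.5, 3.0, 100.0])
print(binned_lookup_clamp(xs, edges, contents))  # values: 10, 10, 30, 30
```

For uniform binning, the same helper can be fed edges built with jnp.linspace(low, high, n + 1).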
 
Only the evaluation of Formula corrections is fully JAX-traceable.
For other corrections, e.g. Binning, gradients can still be computed (jax.grad
works), but because JAX cannot trace the computation, transformations such as
jax.jit and jax.vmap will not work. In these cases, np.vectorize can be used as
an alternative to jax.vmap.
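The np.vectorize fallback can be illustrated without correctionlib-gradients at all. The sketch below assumes only jax and numpy are installed and uses a hypothetical function, step_like, whose Python if statement makes it differentiable with jax.grad but not traceable by jax.jit or jax.vmap, which mirrors the situation described above for Binning corrections.

```python
import jax
import numpy as np


def step_like(x):
    # Hypothetical stand-in for a non-traceable correction: the Python `if`
    # needs a concrete value of x, which jax.grad provides but jax.jit and
    # jax.vmap do not.
    if x < 3.0:
        return 3.0 * x**2
    return -4.0 * x


grad_step = jax.grad(step_like)

# jax.grad on its own works because x keeps a concrete value while differentiating
assert float(grad_step(2.0)) == 12.0

# under jax.jit, x is an abstract tracer, so the Python `if` raises an error
try:
    jax.jit(grad_step)(2.0)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

# np.vectorize loops over the elements in Python, calling grad_step on each
# concrete scalar, so it works where jax.vmap would fail
xs = np.array([2.0, 4.0], dtype=np.float32)
grads = np.vectorize(grad_step)(xs)
print(grads)  # 6 * x for x < 3, otherwise -4: values 12 and -4
```

Note that np.vectorize is essentially a Python for loop, so it trades the speed of a compiled jax.vmap for the ability to handle non-traceable functions. With a Binning-based CorrectionWithGradient c, the analogous call would presumably be np.vectorize(jax.grad(c.evaluate))(xs).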
correctionlib-gradients is distributed under the terms of the
BSD 3-Clause license.