Skip to content

Commit eb9944c

Browse files
feat: Add jax backend (#377)
* add JAX backend
1 parent 98b2cb0 commit eb9944c

14 files changed

+554
-8
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ Computational Backends:
5656
- [x] NumPy
5757
- [x] PyTorch
5858
- [x] TensorFlow
59+
- [x] JAX
5960

6061
Available Optimizers
6162

docs/installation.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ Install latest stable release from `PyPI <https://pypi.org/project/pyhf/>`__...
4545
4646
pip install pyhf[torch]
4747
48+
... with JAX backend
49+
++++++++++++++++++++
50+
51+
.. code-block:: console
52+
53+
pip install pyhf[jax]
54+
4855
... with all backends
4956
+++++++++++++++++++++
5057

@@ -85,6 +92,13 @@ Install latest development version from `GitHub <https://github.com/scikit-hep/p
8592
8693
pip install --ignore-installed -U "git+https://github.com/scikit-hep/pyhf.git#egg=pyhf[torch]"
8794
95+
... with JAX backend
96+
++++++++++++++++++++++
97+
98+
.. code-block:: console
99+
100+
pip install --ignore-installed -U "git+https://github.com/scikit-hep/pyhf.git#egg=pyhf[jax]"
101+
88102
... with all backends
89103
+++++++++++++++++++++
90104

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
extras_require = {
1212
'tensorflow': ['tensorflow~=2.0', 'tensorflow-probability~=0.8'],
1313
'torch': ['torch~=1.2'],
14+
'jax': ['jax~=0.1,>0.1.51', 'jaxlib~=0.1,>0.1.33'],
1415
'xmlio': ['uproot'],
1516
'minuit': ['iminuit'],
1617
}
1718
extras_require['backends'] = sorted(
1819
set(
1920
extras_require['tensorflow']
2021
+ extras_require['torch']
22+
+ extras_require['jax']
2123
+ extras_require['minuit']
2224
)
2325
)

src/pyhf/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ def set_backend(backend, custom_optimizer=None):
9090
if custom_optimizer
9191
else optimize.pytorch_optimizer(tensorlib=backend)
9292
)
93+
elif backend.name == 'jax':
94+
new_optimizer = (
95+
custom_optimizer if custom_optimizer else optimize.jax_optimizer()
96+
)
9397
else:
9498
new_optimizer = (
9599
custom_optimizer if custom_optimizer else optimize.scipy_optimizer()

src/pyhf/cli/infer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def cli():
3030
@click.option('--teststat', type=click.Choice(['q', 'qtilde']), default='qtilde')
3131
@click.option(
3232
'--backend',
33-
type=click.Choice(['numpy', 'pytorch', 'tensorflow', 'np', 'torch', 'tf']),
33+
type=click.Choice(['numpy', 'pytorch', 'tensorflow', 'jax', 'np', 'torch', 'tf']),
3434
help='The tensor backend used for the calculation.',
3535
default='numpy',
3636
)
@@ -69,6 +69,8 @@ def cls(
6969
set_backend(tensor.pytorch_backend(float='float64'))
7070
elif backend in ['tensorflow', 'tf']:
7171
set_backend(tensor.tensorflow_backend(float='float64'))
72+
elif backend in ['jax']:
73+
set_backend(tensor.jax_backend())
7274
tensorlib, _ = get_backend()
7375

7476
optconf = {k: v for item in optconf for k, v in item.items()}

src/pyhf/optimize/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@ def __getattr__(self, name):
1212
# for autocomplete and dir() calls
1313
self.scipy_optimizer = scipy_optimizer
1414
return scipy_optimizer
15+
elif name == 'jax_optimizer':
16+
try:
17+
from .opt_jax import jax_optimizer
18+
19+
assert jax_optimizer
20+
self.jax_optimizer = jax_optimizer
21+
return jax_optimizer
22+
except ImportError as e:
23+
raise exceptions.ImportBackendError(
24+
"There was a problem importing jax. The jax optimizer cannot be used.",
25+
e,
26+
)
1527
elif name == 'pytorch_optimizer':
1628
try:
1729
from .opt_pytorch import pytorch_optimizer

src/pyhf/optimize/opt_jax.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""JAX Optimizer Backend."""
2+
3+
from .. import get_backend, default_backend
4+
from ..tensor.common import _TensorViewer
5+
from .autodiff import AutoDiffOptimizerMixin
6+
import jax
7+
8+
9+
def _final_objective(pars, data, fixed_vals, model, objective, fixed_idx, variable_idx):
10+
tensorlib, _ = get_backend()
11+
tv = _TensorViewer([fixed_idx, variable_idx])
12+
pars = tensorlib.astensor(pars)
13+
constrained_pars = tv.stitch([fixed_vals, pars])
14+
return objective(constrained_pars, data, model)[0]
15+
16+
17+
_jitted_objective_and_grad = jax.jit(
18+
jax.value_and_grad(_final_objective), static_argnums=(3, 4, 5, 6)
19+
)
20+
21+
22+
class jax_optimizer(AutoDiffOptimizerMixin):
23+
"""JAX Optimizer Backend."""
24+
25+
def setup_minimize(
26+
self, objective, data, pdf, init_pars, par_bounds, fixed_vals=None
27+
):
28+
"""
29+
Prepare Minimization for AutoDiff-Optimizer.
30+
31+
Args:
32+
objective: objective function
33+
data: observed data
34+
pdf: model
35+
init_pars: initial parameters
36+
par_bounds: parameter boundaries
37+
fixed_vals: fixed parameter values
38+
39+
"""
40+
41+
tensorlib, _ = get_backend()
42+
all_idx = default_backend.astensor(range(pdf.config.npars), dtype='int')
43+
all_init = default_backend.astensor(init_pars)
44+
45+
fixed_vals = fixed_vals or []
46+
fixed_values = [x[1] for x in fixed_vals]
47+
fixed_idx = [x[0] for x in fixed_vals]
48+
49+
variable_idx = [x for x in all_idx if x not in fixed_idx]
50+
variable_init = all_init[variable_idx]
51+
variable_bounds = [par_bounds[i] for i in variable_idx]
52+
53+
tv = _TensorViewer([fixed_idx, variable_idx])
54+
55+
data = tensorlib.astensor(data)
56+
fixed_values_tensor = tensorlib.astensor(fixed_values, dtype='float')
57+
58+
def func(pars):
59+
# need to conver to tuple to make args hashable
60+
return _jitted_objective_and_grad(
61+
pars,
62+
data,
63+
fixed_values_tensor,
64+
pdf,
65+
objective,
66+
tuple(fixed_idx),
67+
tuple(variable_idx),
68+
)
69+
70+
return tv, fixed_values_tensor, func, variable_init, variable_bounds

src/pyhf/tensor/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ def __getattr__(self, name):
1010
# for autocomplete and dir() calls
1111
self.numpy_backend = numpy_backend
1212
return numpy_backend
13+
elif name == 'jax_backend':
14+
from .jax_backend import jax_backend
15+
16+
assert jax_backend
17+
# for autocomplete and dir() calls
18+
self.jax_backend = jax_backend
19+
return jax_backend
1320
elif name == 'pytorch_backend':
1421
try:
1522
from .pytorch_backend import pytorch_backend

0 commit comments

Comments
 (0)