Skip to content

Commit ebf7ece

Browse files
feat: Migrate to TensorFlow v2.0 (#541)
* Migrate the TensorFlow backend and optimizer to TensorFlow v2.0 / deprecate TensorFlow v1.0 behavior - Changes evaluation from lazy (v1.0) to eager (v2.0 default) * Require TensorFlow and TensorFlow Probability releases compatible with TensorFlow v2.0 * Simplify test configuration with deprecation of TensorFlow v1.0 Sessions
1 parent 3dc4d36 commit ebf7ece

File tree

9 files changed

+121
-212
lines changed

9 files changed

+121
-212
lines changed

docs/governance/ROADMAP.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ The roadmap will be executed over mostly Quarter 3 of 2019 through Quarter 1 of
6565
- [ ] Add better Model creation [2019-Q4 → 2020-Q1]
6666
- [ ] Add background model support (Issue #514) [2019-Q4 → 2020-Q1]
6767
- [ ] Develop interface for the optimizers similar to tensor/backend [2019-Q4 → 2020-Q1]
68-
- [ ] Migrate to TensorFlow v2.0 (PR #541) [2019-Q4]
68+
- [x] Migrate to TensorFlow v2.0 (PR #541) [2019-Q4]
6969
- [ ] Drop Python 2.7 support at end of 2019 (Issue #469) [2019-Q4 (last week of December 2019)]
7070
- [ ] Finalize public API [2020-Q1]
7171
- [ ] Integrate pyfitcore/Statisfactory API [2020-Q1]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
long_description = readme_md.read()
1010

1111
extras_require = {
12-
'tensorflow': ['tensorflow~=1.15', 'tensorflow-probability~=0.8', 'numpy~=1.16'],
12+
'tensorflow': ['tensorflow~=2.0', 'tensorflow-probability~=0.8'],
1313
'torch': ['torch~=1.2'],
1414
'xmlio': ['uproot'],
1515
'minuit': ['iminuit'],

src/pyhf/__init__.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,13 @@ def get_backend():
2828

2929

3030
@events.register('change_backend')
31-
def set_backend(backend, custom_optimizer=None, _session=None):
31+
def set_backend(backend, custom_optimizer=None):
3232
"""
3333
Set the backend and the associated optimizer
3434
3535
Example:
3636
>>> import pyhf
37-
>>> import tensorflow as tf
38-
>>> sess = tf.compat.v1.Session()
39-
>>> pyhf.set_backend("tensorflow", _session=sess)
37+
>>> pyhf.set_backend("tensorflow")
4038
>>> pyhf.tensorlib.name
4139
'tensorflow'
4240
>>> pyhf.set_backend(b"pytorch")
@@ -49,10 +47,6 @@ def set_backend(backend, custom_optimizer=None, _session=None):
4947
Args:
5048
backend (`str` or `pyhf.tensor` backend): One of the supported pyhf backends: NumPy, TensorFlow, and PyTorch
5149
custom_optimizer (`pyhf.optimize` optimizer): Optional custom optimizer defined by the user
52-
_session (|tf.compat.v1.Session|_): TensorFlow v1 compatible Session to use when the :code:`"tensorflow"` backend API is used
53-
54-
.. |tf.compat.v1.Session| replace:: ``tf.compat.v1.Session``
55-
.. _tf.compat.v1.Session: https://www.tensorflow.org/api_docs/python/tf/compat/v1/Session
5650
5751
Returns:
5852
None
@@ -64,18 +58,14 @@ def set_backend(backend, custom_optimizer=None, _session=None):
6458
if isinstance(backend, bytes):
6559
backend = backend.decode("utf-8")
6660
backend = backend.lower()
67-
# Needed while still using TF v1.0 API
68-
if backend == "tensorflow":
69-
backend = tensor.tensorflow_backend(session=_session)
70-
else:
71-
try:
72-
backend = getattr(tensor, "{0:s}_backend".format(backend))()
73-
except TypeError:
74-
raise InvalidBackend(
75-
"The backend provided is not supported: {0:s}. Select from one of the supported backends: numpy, tensorflow, pytorch".format(
76-
backend
77-
)
61+
try:
62+
backend = getattr(tensor, "{0:s}_backend".format(backend))()
63+
except TypeError:
64+
raise InvalidBackend(
65+
"The backend provided is not supported: {0:s}. Select from one of the supported backends: numpy, tensorflow, pytorch".format(
66+
backend
7867
)
68+
)
7969

8070
_name_supported = getattr(tensor, "{0:s}_backend".format(backend.name))
8171
if _name_supported:
@@ -94,8 +84,6 @@ def set_backend(backend, custom_optimizer=None, _session=None):
9484
new_optimizer = (
9585
custom_optimizer if custom_optimizer else optimize.tflow_optimizer(backend)
9686
)
97-
if tensorlib.name == 'tensorflow':
98-
tensorlib_changed |= bool(backend.session != tensorlib.session)
9987
elif backend.name == 'pytorch':
10088
new_optimizer = (
10189
custom_optimizer

src/pyhf/cli/infer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,7 @@ def cls(
6868
if backend in ['pytorch', 'torch']:
6969
set_backend(tensor.pytorch_backend(float='float64'))
7070
elif backend in ['tensorflow', 'tf']:
71-
from tensorflow.compat.v1 import Session
72-
73-
set_backend(tensor.tensorflow_backend(session=Session(), float='float64'))
71+
set_backend(tensor.tensorflow_backend(float='float64'))
7472
tensorlib, _ = get_backend()
7573

7674
optconf = {k: v for item in optconf for k, v in item.items()}

src/pyhf/optimize/opt_tflow.py

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,6 @@
55
import tensorflow as tf
66

77

8-
def _eval_func(op, argop, dataop, data):
9-
def func(pars):
10-
tensorlib, _ = get_backend()
11-
pars_as_list = tensorlib.tolist(pars) if isinstance(pars, tf.Tensor) else pars
12-
data_as_list = tensorlib.tolist(data) if isinstance(data, tf.Tensor) else data
13-
value = tensorlib.session.run(
14-
op, feed_dict={argop: pars_as_list, dataop: data_as_list}
15-
)
16-
return value
17-
18-
return func
19-
20-
218
class tflow_optimizer(AutoDiffOptimizerMixin):
229
"""Tensorflow Optimizer Backend."""
2310

@@ -49,24 +36,18 @@ def setup_minimize(
4936
variable_init = all_init[variable_idx]
5037
variable_bounds = [par_bounds[i] for i in variable_idx]
5138

52-
data_placeholder = tf.placeholder(
53-
tensorlib.dtypemap['float'], (pdf.config.nmaindata + pdf.config.nauxdata,)
54-
)
55-
variable_pars_placeholder = tf.placeholder(
56-
tensorlib.dtypemap['float'], (pdf.config.npars - len(fixed_vals),)
57-
)
58-
5939
tv = _TensorViewer([fixed_idx, variable_idx])
6040

41+
data = tensorlib.astensor(data)
6142
fixed_values_tensor = tensorlib.astensor(fixed_values, dtype='float')
6243

63-
full_pars = tv.stitch([fixed_values_tensor, variable_pars_placeholder])
64-
65-
nll = objective(full_pars, data_placeholder, pdf)
66-
nllgrad = tf.identity(tf.gradients(nll, variable_pars_placeholder)[0])
67-
68-
func = _eval_func(
69-
[nll, nllgrad], variable_pars_placeholder, data_placeholder, data,
70-
)
44+
def func(pars):
45+
pars = tensorlib.astensor(pars)
46+
with tf.GradientTape() as tape:
47+
tape.watch(pars)
48+
constrained_pars = tv.stitch([fixed_values_tensor, pars])
49+
constr_nll = objective(constrained_pars, data, pdf)
50+
grad = tape.gradient(constr_nll, pars).values
51+
return constr_nll.numpy(), grad
7152

7253
return tv, fixed_values_tensor, func, variable_init, variable_bounds

0 commit comments

Comments
 (0)