Skip to content

Commit 0ff43e8

Browse files
revert: Allow 0-dimensional shape tensors (#972)
* Allow for arbitrary shape tensors, including 0-dimensional ("empty") shape tensors (floats) - Reverts PR #413 * Enforce optimizers return objective function as a scalar to harmonize return type * Revert tests that enforce minimum shape * Add tests for return structure shape of pyhf.infer.mle.fit * Add docstring examples for astensor
1 parent eadcd11 commit 0ff43e8

14 files changed

+108
-61
lines changed

src/pyhf/infer/mle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def fit(data, pdf, init_pars=None, par_bounds=None, **kwargs):
4040
>>> bestfit_pars
4141
array([0. , 1.0030512 , 0.96266961])
4242
>>> twice_nll
43-
array([24.98393521])
43+
array(24.98393521)
4444
>>> -2 * model.logpdf(bestfit_pars, data) == twice_nll
4545
array([ True])
4646
@@ -86,7 +86,7 @@ def fixed_poi_fit(poi_val, data, pdf, init_pars=None, par_bounds=None, **kwargs)
8686
>>> bestfit_pars
8787
array([1. , 0.97224597, 0.87553894])
8888
>>> twice_nll
89-
array([28.92218013])
89+
array(28.92218013)
9090
>>> -2 * model.logpdf(bestfit_pars, data) == twice_nll
9191
array([ True])
9292

src/pyhf/infer/test_statistics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,5 @@ def qmu(mu, data, pdf, init_pars, par_bounds):
5858
qmu = fixed_poi_fit_lhood_val - unconstrained_fit_lhood_val
5959
qmu = tensorlib.where(
6060
muhatbhat[pdf.config.poi_index] > mu, tensorlib.astensor(0.0), qmu
61-
)[0]
61+
)
6262
return tensorlib.clip(qmu, 0, max_value=None)

src/pyhf/optimize/opt_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@ def wrap_objective(objective, data, pdf, stitch_pars, do_grad=False, jit_pieces=
2727
def func(pars):
2828
pars = tensorlib.astensor(pars)
2929
constrained_pars = stitch_pars(pars)
30-
return objective(constrained_pars, data, pdf)
30+
return objective(constrained_pars, data, pdf)[0]
3131

3232
return func

src/pyhf/optimize/opt_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@ def func(pars):
2929
constrained_pars = stitch_pars(pars)
3030
constr_nll = objective(constrained_pars, data, pdf)
3131
grad = torch.autograd.grad(constr_nll, pars)[0]
32-
return constr_nll.detach().numpy(), grad
32+
return constr_nll.detach().numpy()[0], grad
3333

3434
else:
3535

3636
def func(pars):
3737
pars = tensorlib.astensor(pars)
3838
constrained_pars = stitch_pars(pars)
3939
constr_nll = objective(constrained_pars, data, pdf)
40-
return constr_nll
40+
return constr_nll[0]
4141

4242
return func

src/pyhf/optimize/opt_tflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,13 @@ def func(pars):
3131
# when tf.gather is used and this needs to be converted back to a
3232
# tensor to be usable as a value
3333
grad = tape.gradient(constr_nll, pars)
34-
return constr_nll.numpy(), tf.convert_to_tensor(grad)
34+
return constr_nll.numpy()[0], tf.convert_to_tensor(grad)
3535

3636
else:
3737

3838
def func(pars):
3939
pars = tensorlib.astensor(pars)
4040
constrained_pars = stitch_pars(pars)
41-
return objective(constrained_pars, data, pdf)
41+
return objective(constrained_pars, data, pdf)[0]
4242

4343
return func

src/pyhf/tensor/jax_backend.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,17 @@ def astensor(self, tensor_in, dtype='float'):
152152
"""
153153
Convert to a JAX ndarray.
154154
155+
Example:
156+
157+
>>> import pyhf
158+
>>> pyhf.set_backend("jax")
159+
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
160+
>>> tensor
161+
DeviceArray([[1., 2., 3.],
162+
[4., 5., 6.]], dtype=float64)
163+
>>> type(tensor)
164+
<class 'jax.interpreters.xla.DeviceArray'>
165+
155166
Args:
156167
tensor_in (Number or Tensor): Tensor object
157168
@@ -163,13 +174,8 @@ def astensor(self, tensor_in, dtype='float'):
163174
except KeyError:
164175
log.error('Invalid dtype: dtype must be float, int, or bool.')
165176
raise
166-
tensor = np.asarray(tensor_in, dtype=dtype)
167-
# Ensure non-empty tensor shape for consistency
168-
try:
169-
tensor.shape[0]
170-
except IndexError:
171-
tensor = np.reshape(tensor, [1])
172-
return np.asarray(tensor, dtype=dtype)
177+
178+
return np.asarray(tensor_in, dtype=dtype)
173179

174180
def sum(self, tensor_in, axis=None):
175181
return np.sum(tensor_in, axis=axis)

src/pyhf/tensor/numpy_backend.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,17 @@ def astensor(self, tensor_in, dtype='float'):
145145
"""
146146
Convert to a NumPy array.
147147
148+
Example:
149+
150+
>>> import pyhf
151+
>>> pyhf.set_backend("numpy")
152+
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
153+
>>> tensor
154+
array([[1., 2., 3.],
155+
[4., 5., 6.]])
156+
>>> type(tensor)
157+
<class 'numpy.ndarray'>
158+
148159
Args:
149160
tensor_in (Number or Tensor): Tensor object
150161
@@ -157,13 +168,7 @@ def astensor(self, tensor_in, dtype='float'):
157168
log.error('Invalid dtype: dtype must be float, int, or bool.')
158169
raise
159170

160-
tensor = np.asarray(tensor_in, dtype=dtype)
161-
# Ensure non-empty tensor shape for consistency
162-
try:
163-
tensor.shape[0]
164-
except IndexError:
165-
tensor = tensor.reshape(1)
166-
return tensor
171+
return np.asarray(tensor_in, dtype=dtype)
167172

168173
def sum(self, tensor_in, axis=None):
169174
return np.sum(tensor_in, axis=axis)

src/pyhf/tensor/pytorch_backend.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,17 @@ def astensor(self, tensor_in, dtype='float'):
112112
"""
113113
Convert to a PyTorch Tensor.
114114
115+
Example:
116+
117+
>>> import pyhf
118+
>>> pyhf.set_backend("pytorch")
119+
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
120+
>>> tensor
121+
tensor([[1., 2., 3.],
122+
[4., 5., 6.]])
123+
>>> type(tensor)
124+
<class 'torch.Tensor'>
125+
115126
Args:
116127
tensor_in (Number or Tensor): Tensor object
117128
@@ -124,13 +135,7 @@ def astensor(self, tensor_in, dtype='float'):
124135
log.error('Invalid dtype: dtype must be float, int, or bool.')
125136
raise
126137

127-
tensor = torch.as_tensor(tensor_in, dtype=dtype)
128-
# Ensure non-empty tensor shape for consistency
129-
try:
130-
tensor.shape[0]
131-
except IndexError:
132-
tensor = tensor.expand(1)
133-
return tensor
138+
return torch.as_tensor(tensor_in, dtype=dtype)
134139

135140
def gather(self, tensor, indices):
136141
return tensor[indices.type(torch.LongTensor)]
@@ -222,6 +227,7 @@ def simple_broadcast(self, *args):
222227
list of Tensors: The sequence broadcast together.
223228
"""
224229

230+
args = [arg.view(1) if not self.shape(arg) else arg for arg in args]
225231
max_dim = max(map(len, args))
226232
try:
227233
assert not [arg for arg in args if 1 < len(arg) < max_dim]

src/pyhf/tensor/tensorflow_backend.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,18 @@ def astensor(self, tensor_in, dtype='float'):
136136
"""
137137
Convert to a TensorFlow Tensor.
138138
139+
Example:
140+
141+
>>> import pyhf
142+
>>> pyhf.set_backend("tensorflow")
143+
>>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
144+
>>> tensor
145+
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
146+
array([[1., 2., 3.],
147+
[4., 5., 6.]], dtype=float32)>
148+
>>> type(tensor)
149+
<class 'tensorflow.python.framework.ops.EagerTensor'>
150+
139151
Args:
140152
tensor_in (Number or Tensor): Tensor object
141153
@@ -156,11 +168,6 @@ def astensor(self, tensor_in, dtype='float'):
156168
tensor.device
157169
except AttributeError:
158170
tensor = tf.convert_to_tensor(tensor_in)
159-
# Ensure non-empty tensor shape for consistency
160-
try:
161-
tensor.shape[0]
162-
except IndexError:
163-
tensor = tf.reshape(tensor, [1])
164171
if tensor.dtype is not dtype:
165172
tensor = tf.cast(tensor, dtype)
166173
return tensor
@@ -276,22 +283,16 @@ def simple_broadcast(self, *args):
276283
list of Tensors: The sequence broadcast together.
277284
278285
"""
279-
max_dim = max(map(lambda arg: arg.shape[0], args))
286+
287+
max_dim = max(map(tf.size, args))
280288
try:
281-
assert not [arg for arg in args if 1 < arg.shape[0] < max_dim]
289+
assert not [arg for arg in args if 1 < tf.size(arg) < max_dim]
282290
except AssertionError as error:
283291
log.error(
284292
'ERROR: The arguments must be of compatible size: 1 or %i', max_dim
285293
)
286294
raise error
287-
288-
broadcast = [
289-
arg
290-
if arg.shape[0] > 1
291-
else tf.tile(tf.slice(arg, [0], [1]), tf.stack([max_dim]))
292-
for arg in args
293-
]
294-
return broadcast
295+
return [tf.broadcast_to(arg, (max_dim,)) for arg in args]
295296

296297
def einsum(self, subscripts, *operands):
297298
"""
@@ -452,7 +453,7 @@ def normal_cdf(self, x, mu=0.0, sigma=1):
452453
TensorFlow Tensor: The CDF
453454
"""
454455
normal = tfp.distributions.Normal(
455-
self.astensor(mu, dtype='float')[0], self.astensor(sigma, dtype='float')[0],
456+
self.astensor(mu, dtype='float'), self.astensor(sigma, dtype='float'),
456457
)
457458
return normal.cdf(x)
458459

tests/test_backend_consistency.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def test_hypotest_q_mu(
117117
q_mu = pyhf.infer.test_statistics.qmu(
118118
1.0, data, pdf, pdf.config.suggested_init(), pdf.config.suggested_bounds(),
119119
)
120-
test_statistic.append(pyhf.tensorlib.tolist(q_mu))
120+
test_statistic.append(q_mu)
121121

122122
# compare to NumPy/SciPy
123123
test_statistic = np.array(test_statistic)

tests/test_constraints.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,15 @@ def test_batched_constraints(backend):
187187
)
188188
)
189189
assert np.isclose(
190-
result[0],
190+
result,
191191
sum(
192192
[
193193
default_backend.poisson_logpdf(data, rate)
194194
for data, rate in zip([12, 13, 14], [12, 13, 14])
195195
]
196196
),
197197
)
198-
assert result.shape == (1,)
198+
assert result.shape == ()
199199

200200
suggested_pars = [1.1] * 3 + [0.0] * 5 # 2 pois 5 norm
201201
constraint = poisson_constraint_combined(config)
@@ -208,15 +208,15 @@ def test_batched_constraints(backend):
208208
)
209209
)
210210
assert np.isclose(
211-
result[0],
211+
result,
212212
sum(
213213
[
214214
default_backend.poisson_logpdf(data, rate)
215215
for data, rate in zip([12, 13, 14], [12 * 1.1, 13 * 1.1, 14 * 1.1])
216216
]
217217
),
218218
)
219-
assert result.shape == (1,)
219+
assert result.shape == ()
220220

221221
constraint = poisson_constraint_combined(config, batch_size=10)
222222
result = constraint.logpdf(

tests/test_infer.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,36 @@ def check_uniform_type(in_list):
2020
)
2121

2222

23+
def test_mle_fit_default(tmpdir, hypotest_args):
24+
"""
25+
Check that the default return structure of pyhf.infer.mle.fit is as expected
26+
"""
27+
tb = pyhf.tensorlib
28+
29+
_, data, model = hypotest_args
30+
kwargs = {}
31+
result = pyhf.infer.mle.fit(data, model, **kwargs)
32+
# bestfit_pars
33+
assert isinstance(result, type(tb.astensor(result)))
34+
assert pyhf.tensorlib.shape(result) == (model.config.npars,)
35+
36+
37+
def test_mle_fit_return_fitted_val(tmpdir, hypotest_args):
38+
"""
39+
Check that the return structure of pyhf.infer.mle.fit with the
40+
return_fitted_val keyword arg is as expected
41+
"""
42+
tb = pyhf.tensorlib
43+
44+
_, data, model = hypotest_args
45+
kwargs = {"return_fitted_val": True}
46+
result = pyhf.infer.mle.fit(data, model, **kwargs)
47+
# bestfit_pars, twice_nll
48+
assert pyhf.tensorlib.shape(result[0]) == (model.config.npars,)
49+
assert isinstance(result[0], type(tb.astensor(result[0])))
50+
assert pyhf.tensorlib.shape(result[1]) == ()
51+
52+
2353
def test_hypotest_default(tmpdir, hypotest_args):
2454
"""
2555
Check that the default return structure of pyhf.infer.hypotest is as expected

tests/test_optim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def test_optim_with_value(backend, source, spec, mu):
318318
return_fitted_val=True,
319319
)
320320
assert pyhf.tensorlib.tolist(result)
321-
assert pyhf.tensorlib.shape(fitted_val) == (1,)
321+
assert pyhf.tensorlib.shape(fitted_val) == ()
322322

323323

324324
@pytest.mark.parametrize('mu', [1.0], ids=['mu=1'])

tests/test_tensor.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,16 @@ def test_simple_tensor_ops(backend):
4141
assert tb.tolist(tb.abs(tb.astensor([-1, -2]))) == [1, 2]
4242
a = tb.astensor(1)
4343
b = tb.astensor(2)
44-
assert tb.tolist(a < b)[0] is True
45-
assert tb.tolist(b < a)[0] is False
46-
assert tb.tolist(a < a)[0] is False
47-
assert tb.tolist(a > b)[0] is False
48-
assert tb.tolist(b > a)[0] is True
49-
assert tb.tolist(a > a)[0] is False
44+
assert tb.tolist(a < b) is True
45+
assert tb.tolist(b < a) is False
46+
assert tb.tolist(a < a) is False
47+
assert tb.tolist(a > b) is False
48+
assert tb.tolist(b > a) is True
49+
assert tb.tolist(a > a) is False
5050
a = tb.astensor(4)
5151
b = tb.astensor(5)
52-
assert tb.tolist(tb.conditional((a < b)[0], lambda: a + b, lambda: a - b)) == [9]
53-
assert tb.tolist(tb.conditional((a > b)[0], lambda: a + b, lambda: a - b)) == [-1]
52+
assert tb.tolist(tb.conditional((a < b), lambda: a + b, lambda: a - b)) == 9.0
53+
assert tb.tolist(tb.conditional((a > b), lambda: a + b, lambda: a - b)) == -1.0
5454

5555

5656
def test_complex_tensor_ops(backend):
@@ -145,10 +145,9 @@ def test_shape(backend):
145145
tb = pyhf.tensorlib
146146
assert tb.shape(tb.ones((1, 2, 3, 4, 5))) == (1, 2, 3, 4, 5)
147147
assert tb.shape(tb.ones((0, 0))) == (0, 0)
148+
assert tb.shape(tb.astensor(1.0)) == ()
148149
assert tb.shape(tb.astensor([])) == (0,)
149150
assert tb.shape(tb.astensor([1.0])) == (1,)
150-
assert tb.shape(tb.astensor(1.0)) == tb.shape(tb.astensor([1.0]))
151-
assert tb.shape(tb.astensor(0.0)) == tb.shape(tb.astensor([0.0]))
152151
assert tb.shape(tb.astensor((1.0, 1.0))) == tb.shape(tb.astensor([1.0, 1.0]))
153152
assert tb.shape(tb.astensor((0.0, 0.0))) == tb.shape(tb.astensor([0.0, 0.0]))
154153
with pytest.raises(

0 commit comments

Comments
 (0)