Skip to content

Commit 0e1bdba

Browse files
dylanhmorristest
and
test
authored
Fix behavior of logdiffexp(inf,inf), add test that would have caught. Fix typos (#2007)
Co-authored-by: test <test@example.com>
1 parent 0a47e72 commit 0e1bdba

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

numpyro/contrib/einstein/steinvi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,7 @@ def __init__(
738738
):
739739
assert num_cycles > 0, f"The number of cycles must be >0. Got {num_cycles}."
740740
assert transition_speed > 0, (
741-
f"The transtion speed must be >0. Got {transition_speed}."
741+
f"The transition speed must be >0. Got {transition_speed}."
742742
)
743743

744744
self.num_cycles = num_cycles
@@ -785,7 +785,7 @@ def init(self, rng_key, num_steps, *args, **kwargs):
785785
786786
:param jax.random.PRNGKey rng_key: Random number generator seed.
787787
:param args: Positional arguments to the model and guide.
788-
:param num_steps: Totat number of steps in the optimization.
788+
:param num_steps: Total number of steps in the optimization.
789789
:param kwargs: Keyword arguments to the model and guide.
790790
:return: Initial :data:`ASVGDState`.
791791
"""

numpyro/distributions/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def log1mexp(x: ArrayLike) -> ArrayLike:
448448
:return: The value of :math:`\\log(1 - \\exp(x))`.
449449
"""
450450
return jnp.where(
451-
x > -0.6931472, # approx log(2)
451+
x > -0.6931472, # approx -log(2)
452452
jnp.log(-jnp.expm1(x)),
453453
jnp.log1p(-jnp.exp(x)),
454454
)
@@ -485,7 +485,7 @@ def logdiffexp(a: ArrayLike, b: ArrayLike) -> ArrayLike:
485485
return jnp.where(
486486
(a < jnp.inf) & (a > b),
487487
a + log1mexp(b - a),
488-
jnp.where(a == b, -jnp.inf, jnp.nan),
488+
jnp.where((a < jnp.inf) & (a == b), -jnp.inf, jnp.nan),
489489
)
490490

491491

test/test_distributions_util.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,20 +180,23 @@ def test_logdiffexp_grads(a, b):
180180
(0, 0, -np.inf),
181181
(-np.inf, -np.inf, -np.inf),
182182
(5.6, 5.6, -np.inf),
183+
(1e34, 1e34, -np.inf),
183184
(1e34, 1e34 / 0.9999, np.nan),
185+
(np.inf, np.inf, np.nan),
184186
],
185187
)
186188
def test_logdiffexp_bounds_handling(a, b, expected):
187189
"""
188190
Test bounds handling for logdiffexp.
189191
190192
logdiffexp(jnp.inf, anything) should be nan,
193+
including logdiffexp(jnp.inf, jnp.inf).
191194
192195
logdiffexp(a, b) for a < b should be nan, even if numbers
193196
are very close.
194197
195198
logdiffexp(a, b) for a == b should be -jnp.inf
196-
even if a == b == -jnp.inf (log(0 - 0))
199+
even if a == b == -jnp.inf (log(0 - 0)).
197200
"""
198201
a = jnp.asarray(a)
199202
b = jnp.asarray(b)

0 commit comments

Comments
 (0)