Skip to content

Commit 8a9b546

Browse files
authored
NaN Handling (#727)
* bug fix; first part * bug fix; first part * further debug * remove print statements * handle logdensity nans. mask -> 1 - mask.
1 parent b02b60b commit 8a9b546

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

blackjax/adaptation/mclmc_adaptation.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
from jax.flatten_util import ravel_pytree
2121

2222
from blackjax.diagnostics import effective_sample_size
23-
from blackjax.util import incremental_value_update, pytree_size
23+
from blackjax.util import generate_unit_vector, incremental_value_update, pytree_size
2424

2525

2626
class MCLMCAdaptationState(NamedTuple):
@@ -147,6 +147,8 @@ def predictor(previous_state, params, adaptive_state, rng_key):
147147

148148
time, x_average, step_size_max = adaptive_state
149149

150+
rng_key, nan_key = jax.random.split(rng_key)
151+
150152
# dynamics
151153
next_state, info = kernel(params.sqrt_diag_cov)(
152154
rng_key=rng_key,
@@ -162,6 +164,7 @@ def predictor(previous_state, params, adaptive_state, rng_key):
162164
params.step_size,
163165
step_size_max,
164166
info.energy_change,
167+
nan_key,
165168
)
166169

167170
# Warning: var = 0 if there were nans, but we will give it a very small weight
@@ -203,7 +206,7 @@ def step(iteration_state, weight_and_key):
203206
streaming_avg = incremental_value_update(
204207
expectation=jnp.array([x, jnp.square(x)]),
205208
incremental_val=streaming_avg,
206-
weight=(1 - mask) * success * params.step_size,
209+
weight=mask * success * params.step_size,
207210
)
208211

209212
return (state, params, adaptive_state, streaming_avg), None
@@ -233,7 +236,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
233236
)
234237

235238
# we use the last num_steps2 to compute the diagonal preconditioner
236-
mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
239+
mask = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
237240

238241
# run the steps
239242
state, params, _, (_, average) = run_steps(
@@ -298,7 +301,9 @@ def step(state, key):
298301
return adaptation_L
299302

300303

301-
def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_change):
304+
def handle_nans(
305+
previous_state, next_state, step_size, step_size_max, kinetic_change, key
306+
):
302307
"""if there are nans, let's reduce the stepsize, and not update the state. The
303308
function returns the old state in this case."""
304309

@@ -311,4 +316,13 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch
311316
(next_state, step_size_max, kinetic_change),
312317
(previous_state, step_size * reduced_step_size, 0.0),
313318
)
319+
320+
state = jax.lax.cond(
321+
jnp.isnan(next_state.logdensity),
322+
lambda: state._replace(
323+
momentum=generate_unit_vector(key, previous_state.position)
324+
),
325+
lambda: state,
326+
)
327+
314328
return nonans, state, step_size, kinetic_change

0 commit comments

Comments (0)