20
20
from jax .flatten_util import ravel_pytree
21
21
22
22
from blackjax .diagnostics import effective_sample_size
23
- from blackjax .util import incremental_value_update , pytree_size
23
+ from blackjax .util import generate_unit_vector , incremental_value_update , pytree_size
24
24
25
25
26
26
class MCLMCAdaptationState (NamedTuple ):
@@ -147,6 +147,8 @@ def predictor(previous_state, params, adaptive_state, rng_key):
147
147
148
148
time , x_average , step_size_max = adaptive_state
149
149
150
+ rng_key , nan_key = jax .random .split (rng_key )
151
+
150
152
# dynamics
151
153
next_state , info = kernel (params .sqrt_diag_cov )(
152
154
rng_key = rng_key ,
@@ -162,6 +164,7 @@ def predictor(previous_state, params, adaptive_state, rng_key):
162
164
params .step_size ,
163
165
step_size_max ,
164
166
info .energy_change ,
167
+ nan_key ,
165
168
)
166
169
167
170
# Warning: var = 0 if there were nans, but we will give it a very small weight
@@ -203,7 +206,7 @@ def step(iteration_state, weight_and_key):
203
206
streaming_avg = incremental_value_update (
204
207
expectation = jnp .array ([x , jnp .square (x )]),
205
208
incremental_val = streaming_avg ,
206
- weight = ( 1 - mask ) * success * params .step_size ,
209
+ weight = mask * success * params .step_size ,
207
210
)
208
211
209
212
return (state , params , adaptive_state , streaming_avg ), None
@@ -233,7 +236,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
233
236
)
234
237
235
238
# we use the last num_steps2 to compute the diagonal preconditioner
236
- mask = 1 - jnp .concatenate ((jnp .zeros (num_steps1 ), jnp .ones (num_steps2 )))
239
+ mask = jnp .concatenate ((jnp .zeros (num_steps1 ), jnp .ones (num_steps2 )))
237
240
238
241
# run the steps
239
242
state , params , _ , (_ , average ) = run_steps (
@@ -298,7 +301,9 @@ def step(state, key):
298
301
return adaptation_L
299
302
300
303
301
- def handle_nans (previous_state , next_state , step_size , step_size_max , kinetic_change ):
304
+ def handle_nans (
305
+ previous_state , next_state , step_size , step_size_max , kinetic_change , key
306
+ ):
302
307
"""if there are nans, let's reduce the stepsize, and not update the state. The
303
308
function returns the old state in this case."""
304
309
@@ -311,4 +316,13 @@ def handle_nans(previous_state, next_state, step_size, step_size_max, kinetic_ch
311
316
(next_state , step_size_max , kinetic_change ),
312
317
(previous_state , step_size * reduced_step_size , 0.0 ),
313
318
)
319
+
320
+ state = jax .lax .cond (
321
+ jnp .isnan (next_state .logdensity ),
322
+ lambda : state ._replace (
323
+ momentum = generate_unit_vector (key , previous_state .position )
324
+ ),
325
+ lambda : state ,
326
+ )
327
+
314
328
return nonans , state , step_size , kinetic_change
0 commit comments