I'm implementing a simple importance sampling routine to estimate an integral. The idea is to repeatedly draw a fixed number of samples (e.g., 100 per iteration) and refine the estimate until the Monte Carlo error falls below a desired threshold. Since the number of steps needed to reach that threshold is not known in advance, I’ve used jax.lax.while_loop to express this adaptive behavior.
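For reference, the per-batch estimator in the MRE below works on the log importance weights $\ell_i$ of a batch of $N$ samples (with $-\infty$ log-weights masked out):

$$\hat{\mu} = \frac{1}{N}\sum_{i=1}^{N} e^{\ell_i}, \qquad \hat{\epsilon} = \sqrt{\frac{\tfrac{1}{N}\sum_{i=1}^{N} e^{2\ell_i} - \hat{\mu}^2}{N - 1}}.$$

Successive batches are then pooled with the sample-size-weighted combination rules spelled out in the docstrings of `_combine_monte_carlo_estimates` and `_combine_monte_carlo_errors` below.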
Each loop iteration updates the running estimate and error by combining new samples with previous ones. This logic works correctly on its own, but I run into an issue when trying to differentiate the outer function using jax.grad. The following error is raised:
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.
This becomes a problem because the loop is embedded within a larger differentiable computation. I would like to know:
Is there a way to restructure this logic to make it compatible with reverse-mode autodiff in JAX?
Are there alternative mathematical or numerical strategies that avoid while_loop but still allow for adaptive stopping based on error tolerance?
I've included an MRE below, along with a sketch of a possible restructuring I've been experimenting with. Thanks in advance.
import functools as ft
from typing import Tuple, TypeAlias

import jax
from jax import nn as jnn, numpy as jnp, random as jrd
from jaxtyping import Array, PRNGKeyArray

StateT: TypeAlias = Tuple[
    Array,  # old monte-carlo-estimate
    Array,  # old error
    Array,  # old size
    PRNGKeyArray,  # old key
]
"""State of the Monte Carlo estimation process."""


@ft.partial(jax.jit, static_argnames=("n_samples",))
def _mvn_samples(
    loc: Array, scale_tril: Array, n_samples: int, key: PRNGKeyArray
) -> Array:
    """Generate samples from a multivariate normal distribution using the method
    from `numpyro.distributions.MultivariateNormal`.

    Parameters
    ----------
    loc : Array
        Mean vector of the multivariate normal distribution.
    scale_tril : Array
        Lower triangular matrix of the covariance matrix (Cholesky decomposition).
    n_samples : int
        Number of samples to generate.
    key : PRNGKeyArray
        JAX random key for sampling.

    Returns
    -------
    Array
        Samples drawn from the multivariate normal distribution.
    """
    eps = jrd.normal(key, shape=(n_samples, *loc.shape))
    samples = loc + jnp.squeeze(jnp.matmul(scale_tril, eps[..., jnp.newaxis]), axis=-1)
    return samples


@jax.jit
def _monte_carlo_estimate_and_error(log_probs: Array, N: Array) -> Tuple[Array, Array]:
    """Compute the Monte Carlo estimate and error for the given log probabilities.

    Parameters
    ----------
    log_probs : Array
        Log probabilities of the samples.
    N : Array
        Number of samples used for the estimate.

    Returns
    -------
    Tuple[Array, Array]
        Monte Carlo estimate and error.
    """
    mask = ~jnp.isneginf(log_probs)
    moment_1 = jnp.exp(jnn.logsumexp(log_probs, where=mask, axis=-1)) / N
    moment_2 = jnp.exp(jnn.logsumexp(2.0 * log_probs, where=mask, axis=-1)) / N
    error = jnp.sqrt((moment_2 - jnp.square(moment_1)) / (N - 1.0))
    return moment_1, error


@jax.jit
def _combine_monte_carlo_estimates(
    estimates_1: Array, estimates_2: Array, N_1: int, N_2: int
) -> Array:
    r"""Combine two Monte Carlo estimates into a single estimate using the formula:

    .. math::
        \hat{\mu} = \frac{N_1 \hat{\mu}_1 + N_2 \hat{\mu}_2}{N_1 + N_2}

    Parameters
    ----------
    estimates_1 : Array
        First Monte Carlo estimate :math:`\hat{\mu}_1`.
    estimates_2 : Array
        Second Monte Carlo estimate :math:`\hat{\mu}_2`.
    N_1 : int
        Number of samples used for the first estimate :math:`N_1`.
    N_2 : int
        Number of samples used for the second estimate :math:`N_2`.

    Returns
    -------
    Array
        Combined Monte Carlo estimate :math:`\hat{\mu}`.
    """
    combined_estimate = (N_1 * estimates_1 + N_2 * estimates_2) / (N_1 + N_2)
    return combined_estimate


@jax.jit
def _combine_monte_carlo_errors(
    error_1: Array,
    error_2: Array,
    estimate_1: Array,
    estimate_2: Array,
    estimate_3: Array,
    N_1: int,
    N_2: int,
) -> Array:
    r"""Combine two Monte Carlo errors into a single error estimate using the formula:

    .. math::
        \hat{\epsilon}=\sqrt{\frac{1}{N_3(N_3-1)}\sum_{k=1}^{2}\left\{N_k(N_k-1)\hat{\epsilon}_k^2+N_k\hat{\mu}^2_k\right\}-\frac{1}{N_3-1}\hat{\mu}^2}

    where :math:`N_3 = N_1 + N_2` is the total number of samples.

    Parameters
    ----------
    error_1 : Array
        Error of the first Monte Carlo estimate :math:`\hat{\epsilon}_1`.
    error_2 : Array
        Error of the second Monte Carlo estimate :math:`\hat{\epsilon}_2`.
    estimate_1 : Array
        First Monte Carlo estimate :math:`\hat{\mu}_1`.
    estimate_2 : Array
        Second Monte Carlo estimate :math:`\hat{\mu}_2`.
    estimate_3 : Array
        Combined Monte Carlo estimate :math:`\hat{\mu}`.
    N_1 : int
        Number of samples used for the first estimate :math:`N_1`.
    N_2 : int
        Number of samples used for the second estimate :math:`N_2`.

    Returns
    -------
    Array
        Combined Monte Carlo error estimate :math:`\hat{\epsilon}`.
    """
    N_3 = N_1 + N_2
    sum_prob_sq_1 = N_1 * ((N_1 - 1.0) * jnp.square(error_1) + jnp.square(estimate_1))
    sum_prob_sq_2 = N_2 * ((N_2 - 1.0) * jnp.square(error_2) + jnp.square(estimate_2))
    combined_error_sq = -jnp.square(estimate_3) / (N_3 - 1.0)
    combined_error_sq += (sum_prob_sq_1 + sum_prob_sq_2) / N_3 / (N_3 - 1.0)
    combined_error = jnp.sqrt(combined_error_sq)
    return combined_error


@jax.jit
def _error_fn(state: StateT) -> Array:
    """Check whether the Monte Carlo error is still above the threshold, i.e.
    whether the while loop should keep refining the estimate.

    Parameters
    ----------
    state : StateT
        The state of the Monte Carlo estimation.

    Returns
    -------
    Array
        A boolean array that is True while the error exceeds the threshold (0.01).
    """
    _, error, _, _ = state
    return jnp.greater(error, 0.01)


def f(alpha: Array) -> Array:
    @jax.jit
    def scan_fn(carry: Array, rng_key: PRNGKeyArray) -> Tuple[Array, None]:
        @jax.jit
        def while_body_fn(state: StateT) -> StateT:
            estimate_1, error_1, N_1, rng_key = state
            N_2 = 100
            rng_key, subkey = jrd.split(rng_key)
            data = jax.random.uniform(subkey, (N_2,))
            log_prob = alpha * data
            # The new batch contains N_2 samples, so its estimate uses N_2.
            estimate_2, error_2 = _monte_carlo_estimate_and_error(log_prob, N_2)
            estimate_3 = _combine_monte_carlo_estimates(
                estimate_1, estimate_2, N_1, N_2
            )
            error_3 = _combine_monte_carlo_errors(
                error_1,
                error_2,
                estimate_1,
                estimate_2,
                estimate_3,
                N_1,
                N_2,
            )
            return estimate_3, error_3, N_1 + N_2, rng_key

        log_likelihood, _, _, _ = jax.lax.while_loop(
            _error_fn,
            while_body_fn,
            (jnp.zeros(()), jnp.ones(()), jnp.zeros(()), rng_key),
        )
        return carry + log_likelihood, None

    rng_key: PRNGKeyArray = jrd.PRNGKey(0)
    n_events: int = 10
    keys = jrd.split(rng_key, (n_events,))
    total_log_likelihood, _ = jax.lax.scan(
        scan_fn,  # type: ignore[arg-type]
        jnp.zeros(()),
        keys,
        length=n_events,
    )
    return total_log_likelihood


print(jax.grad(f)(2.0))
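For what it's worth, here is the kind of restructuring I've been experimenting with. It is only a sketch, and it assumes a static upper bound on the number of refinement steps is acceptable; the names `TOL`, `MAX_ITERS`, `N_PER_ITER`, and `estimate_with_fixed_budget` are mine, not from any library. It replaces the `while_loop` with a `lax.scan` over a fixed number of iterations, reuses the helpers from the MRE above, and freezes the carried state with `jnp.where` once the tolerance is reached, so the stopping rule becomes a masked update instead of a dynamic trip count:

```python
import jax
import jax.numpy as jnp
from jax import random as jrd

TOL = 0.01        # hypothetical error tolerance
MAX_ITERS = 500   # hypothetical static upper bound on refinement steps
N_PER_ITER = 100  # samples drawn per refinement step


def estimate_with_fixed_budget(alpha, rng_key):
    """Sketch: bounded lax.scan with masked updates instead of lax.while_loop."""

    def step(state, subkey):
        estimate, error, n, done = state
        # Draw a new batch and evaluate it with the same helpers as in the MRE.
        data = jrd.uniform(subkey, (N_PER_ITER,))
        log_prob = alpha * data
        est_new, err_new = _monte_carlo_estimate_and_error(log_prob, N_PER_ITER)
        est_comb = _combine_monte_carlo_estimates(estimate, est_new, n, N_PER_ITER)
        err_comb = _combine_monte_carlo_errors(
            error, err_new, estimate, est_new, est_comb, n, N_PER_ITER
        )
        # Once the tolerance has been reached, stop updating the carried state.
        state = (
            jnp.where(done, estimate, est_comb),
            jnp.where(done, error, err_comb),
            jnp.where(done, n, n + N_PER_ITER),
            done | (err_comb <= TOL),
        )
        return state, None

    init = (jnp.zeros(()), jnp.ones(()), jnp.zeros(()), jnp.asarray(False))
    keys = jrd.split(rng_key, MAX_ITERS)
    (estimate, _, _, _), _ = jax.lax.scan(step, init, keys)
    return estimate
```

This always pays for `MAX_ITERS` iterations and only matches the adaptive loop when convergence happens within the budget, so I'm not sure it is the right trade-off. The other option I can think of is to keep the convergence check outside the differentiated function entirely (choose the sample count with plain Python control flow, then differentiate a fixed-N estimator). Suggestions on either approach, or on something better, would be appreciated.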