jax.jacobian() gives nan and reports 'Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.'
#17538
Truly sorry that the code is a bit long and may require jaxopt and ott for reproduction.
I ran into an issue where jax.jacobian() returns nan and reports 'Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.'. I searched for the message but could not find an answer that applies to my case.
(Some posts say this happens when the result is too large, but if anything my result may be too small...)
In my code, computing jax.jacobian() of a matrix-valued function with respect to a parameter vector works fine. But when I compute jax.jacobian() of the column sums of that same matrix with respect to the parameter vector, jax.jacobian() returns nans. The code below is long, but the issue itself is not complicated.
I would appreciate any advice on how to get rid of these nan values.
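To make the comparison concrete before the full reproduction, here is a tiny self-contained sketch (a toy function with made-up names, not my actual model) of the two calls I am comparing. Mathematically, the Jacobian of the column sums should just be the sum over rows of the per-row Jacobians, which is why the nans in the second call surprise me:

```python
import jax
import jax.numpy as jnp

def toy_G(theta):
    # Toy stand-in for canmodel_G: theta (6,) -> matrix of shape (R, 3),
    # where each row plays the role of one simulated replicate.
    r = jnp.arange(1.0, 5.0)[:, None]            # R = 4 toy replicates
    return r * jnp.stack([jnp.sum(jnp.sin(theta)),
                          jnp.sum(jnp.cos(theta)),
                          jnp.sum(theta ** 2)])

def toy_G_sum(theta):
    # Column sums of the matrix, analogous to my G_sum below.
    return jnp.sum(toy_G(theta), axis=0)         # shape (3,)

theta = jnp.arange(6.0)
J_G = jax.jacobian(toy_G)(theta)                 # shape (4, 3, 6)
J_sum = jax.jacobian(toy_G_sum)(theta)           # shape (3, 6)
print(jnp.allclose(J_G.sum(axis=0), J_sum))      # True for this toy model
```

In my real model below, the analogue of the first call is fine but the analogue of the second produces nans.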
```python
from jax.config import config
from jax import lax, jit
import numpy as np
import jaxopt
import jax.random
import jax.scipy.optimize
import jax.scipy as jsp
import jax.numpy as jnp
import jax
from functools import partial
from IPython.core.interactiveshell import InteractiveShell
import ott

InteractiveShell.ast_node_interactivity = "all"
config.update("jax_enable_x64", True)
config.update("jax_debug_nans", False)

key = jax.random.PRNGKey(0)
key, *subkeys = jax.random.split(key, num=100)

theta0 = jnp.array([-7.5, 3.2, -22, -1, .15, jnp.log(0.5)])
uprbd = 2000000
N = 300
# the constant level
tau_c = 4500
k = 388440
t_c = 24 * 365.24
sigma_b_s = jnp.array([.2, .3, .2, .05])


@jax.jit
def gamma(z):
    return jnp.exp(jax.scipy.special.gammaln(z))


def sigmoid(x, a):
    return 0.5 * (jnp.tanh(x * a / 2) + 1)


def logit(x):
    return jnp.log(x / (1 - x))


def expit(x):
    return 1 / (1 + jnp.exp(-x))


def Teqn(Ts, pars, k):
    A = pars[0]
    b = pars[1]
    C = pars[2]
    n = pars[3]
    s0 = pars[4]
    mu = 1
    As = A * k * Ts
    Cs = C * k * Ts
    intfac = 1 / mu * jnp.power(Cs, n) * Ts / (n + 1) * jnp.power(1 - s0, n + 1)
    bnfr = (b + 1) / (n + 1)
    result = jnp.power(As, b) / jnp.power(Cs, n * bnfr)
    result = result * jnp.power(mu / Ts * (n + 1), (b - n) / (n + 1))
    result = result * jsp.special.gammainc(bnfr, intfac) * gamma(bnfr) - jnp.exp(-intfac)
    return result


def const_Tc(Ts, T0, pars, k, uprbd):
    A = pars[0]
    b = pars[1]
    C = pars[2]
    n = pars[3]
    s0 = pars[4]
    mu = 1
    As = A * k * Ts
    Cs = C * k * Ts
    indicator = 0.0000001
    extraload = jnp.where(T0 / Ts - s0 >= 0, T0 / Ts - s0, indicator)
    # extraload = jnp.where(T0/Ts-s0>=0, T0/Ts-s0, 0)
    intfac = jnp.where(extraload == indicator, 1,
                       1 / mu * jnp.power(Cs, n) * Ts / (n + 1) * jnp.power(extraload, n + 1))
    bnfr = (b + 1) / (n + 1)
    res = jnp.zeros(4)
    # res[0] good
    res = res.at[0].set(1 / mu * jnp.power(As * extraload, b))
    res = res.at[1].set(1 / mu * jnp.power(Cs * extraload, n))
    result = jnp.exp(-res[1] * T0 + intfac) / jnp.power(As, b)
    result = result / jnp.power(Cs, n * bnfr) * jnp.power(mu / Ts * (n + 1), (b - n) / (n + 1))
    result = result * jsp.special.gammainc(bnfr, intfac) * gamma(bnfr)
    res = res.at[2].set(result)
    res = res.at[3].set(jnp.exp(-res[1] * T0))
    Tc = -1 / res[1] * jnp.log((res[0] / res[1] * res[3] + res[2]) / (1 + res[0] / res[1]))
    return jnp.where(Tc > 0, Tc, uprbd)


def Teqnroot(pars, k):
    bisec = jaxopt.Bisection(
        optimality_fun=Teqn,
        lower=0.00001,
        upper=0.1,
        check_bracket=False)
    return bisec.run(pars=pars, k=k).params


def canmodel_trans(eps, theta, sigma_b_s, k, tau_c, uprbd):
    # generate the noise
    # theta = theta.at[5].set(jnp.exp(theta[5]))
    # generate the result object
    # eps_obs = jnp.array(jax.random.normal(key, shape=[N, 5]))
    # simulate the five random effects from the parameters
    A = jnp.exp(theta[0] + jnp.exp(theta[5]) * eps[0])
    b = jnp.exp(theta[1] + sigma_b_s[0] * eps[1])
    C = jnp.exp(theta[2] + sigma_b_s[1] * eps[2])
    n = jnp.exp(theta[3] + sigma_b_s[2] * eps[3])
    # use transformations so that s0 is bounded at [0, 1]
    eta = jnp.exp(theta[4] + sigma_b_s[3] * eps[4])
    s0 = eta / (1 + eta)
    pars = jnp.array([A, b, C, n, s0])
    # calculate T0, after which the load is held as a constant
    T0 = tau_c / k
    # calculate the two boundary values of Teqn, preparing to solve Teqn
    # Teqnupper = Teqn(Ts=0.1, pars=pars, k=k)
    # Teqnlower = Teqn(Ts=0.00001, pars=pars, k=k)
    z = Teqnroot(pars, k)
    res = jnp.select(
        [(z < 0.00001) | (z > 0.1), (T0 < 0) | (z < T0), T0 / z < pars[4]],
        [uprbd, z, uprbd], default=const_Tc(z, T0, pars, k, uprbd))
    # return res
    res = jnp.select(
        [res == uprbd, res != uprbd],
        [uprbd * (1 + jnp.mean(jax.vmap(expit)(theta))), res])
    return res


@jax.jit
def sumstat(y_obs, quant, t_c):
    prop_cen = jnp.mean(jax.vmap(lambda y_obs: sigmoid(x=(y_obs - t_c), a=1))(y_obs))
    sumstat_quant = ott.tools.soft_sort.quantile(
        jnp.log(jnp.where(y_obs < t_c, y_obs, t_c)), quant * (1 - prop_cen))
    return jnp.append(sumstat_quant, prop_cen)


def canmodel_gen_obs(theta, sigma_b_s, k, tau_c, uprbd, key):
    eps_obs = jnp.array(jax.random.normal(key, shape=[N, 5]))
    return jax.vmap(
        canmodel_trans,
        in_axes=(0, None, None, None, None, None),
        out_axes=0)(eps_obs, theta, sigma_b_s, k, tau_c, uprbd)


def Gmat_adjust(G):
    R = jnp.shape(G)[0]
    return jnp.row_stack([G, -jnp.sum(G, axis=0) / R * jnp.maximum(1, jnp.log(R) / 2)])


def canmodel_gen(eps_obs, theta, sigma_b_s, k, tau_c, uprbd):
    return jax.vmap(
        canmodel_trans,
        in_axes=(0, None, None, None, None, None),
        out_axes=0)(eps_obs, theta, sigma_b_s, k, tau_c, uprbd)


def eps_gen(key):
    return jax.random.normal(key, shape=(R, N, 5))


y_obs = canmodel_gen_obs(theta0, sigma_b_s, k, tau_c, uprbd, subkeys[1])
sigma_b_s = sigma_b_s
N = 300
R = 100
quant = jnp.array([.2, .8])
y_obs_sum = sumstat(y_obs, quant, t_c)
k = k
t_c = t_c
tau_c = tau_c
uprbd = uprbd
supp_adj = True
n_steps = 100
key = subkeys[2]


def canmodel_G(theta, sigma_b_s, quant, y_obs_sum, t_c, k, tau_c, uprbd, key):
    eps_obs = eps_gen(key)
    t_sim = jax.vmap(
        canmodel_gen,
        in_axes=(0, None, None, None, None, None),
        out_axes=0)(eps_obs, theta, sigma_b_s, k, tau_c, uprbd)
    t_sumstat = jax.vmap(
        sumstat,
        in_axes=(0, None, None),
        out_axes=0)(t_sim, quant, t_c)
    G = jax.vmap(
        jnp.subtract,
        in_axes=(0, None),
        out_axes=0)(t_sumstat, y_obs_sum)
    return G


def G_sum(theta):
    G = canmodel_G(theta, sigma_b_s, quant, y_obs_sum, t_c, k, tau_c, uprbd, key)
    return jnp.sum(G, axis=0)


theta_init = jnp.array([-7.61107873, 3.55245071, -20.19049176, -1.05644388, 0.15608539,
                        -0.61183416])

jax.jacobian(canmodel_G, 0)(theta_init, sigma_b_s, quant, y_obs_sum, t_c, k, tau_c, uprbd, key)
# This one works fine

jax.jacobian(G_sum)(theta_init)
# Array([[        nan,         nan,         nan,         nan,
#                 nan,         nan],
#        [        nan,         nan,         nan,         nan,
#                 nan,         nan],
#        [-10.7691014 ,  -0.64024475,  -0.48398892,   6.52979658,
#          9.39442437,   1.79699062]], dtype=float64)
```
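While searching, I also came across the note in the JAX FAQ that reverse-mode gradients through jnp.where can come out as nan when the branch that is not selected evaluates to something non-finite, even though the forward value is fine. I am not sure whether this is what happens inside const_Tc / canmodel_trans (they use jnp.where, jnp.select, jnp.power and jnp.log on quantities that can get extreme), but here is a minimal, self-contained illustration of that behaviour and of the usual "double where" workaround, in case it is related:

```python
import jax
import jax.numpy as jnp

def f_naive(x):
    # Forward pass is fine (the sqrt branch is never selected for x < 0),
    # but the derivative of the untaken sqrt(x) branch is nan there,
    # and it leaks through the where: 0 * nan = nan.
    return jnp.where(x >= 0.0, jnp.sqrt(x), 0.0)

def f_safe(x):
    # "Double where": feed the untaken branch a safe input as well,
    # so its derivative is finite even where that branch is unused.
    safe_x = jnp.where(x >= 0.0, x, 1.0)
    return jnp.where(x >= 0.0, jnp.sqrt(safe_x), 0.0)

print(jax.grad(f_naive)(-1.0))  # nan
print(jax.grad(f_safe)(-1.0))   # 0.0
```

If that were the mechanism, I suppose the fix would be to pass "safe" values into every branch of the jnp.where / jnp.select calls before taking powers, logs and exponentials, but I have not managed to confirm that this is actually the cause of the nans above, or why only the column-sum Jacobian is affected.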