You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I am struggling with this example for many days and humbly ask for any help...
from jax.config import config
from jax import lax, jit
import jax.random
import jax.numpy as jnp
import jax
import jax.scipy as jsp
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
config.update("jax_enable_x64", True)
Calculate f'' and g'', using jax.hessian()
def gammainc(a, x):
"""
Incomplete Gamma function by series expansion.
Args:
a: Shape parameter (scalar).
x: Quantile parameter (scalar).
n_terms: Number of terms in the expansion.
Returns:
An estimate of `jax.scipy.gammainc(a=a, x=x)`.
"""
n_terms = 100
C = jnp.cumprod(x / (a + jnp.arange(n_terms) + 1.))
S = 1. + jnp.sum(C)
f = jnp.power(x, a) * jnp.exp(-x) / jsp.special.gamma(a + 1.)
return f * S
def const_func(pars):
A = pars[0]
b = pars[1]
C = pars[2]
n = pars[3]
s0 = pars[4]
mu = 1
Ts = 0.02281313
k = 388440
As = A * k * Ts
Cs = C * k * Ts
# extraload = jnp.where(T0/Ts-s0>=0,T0/Ts-s0,0)
intfac = 1.
bnfr = (b + 1) / (n + 1)
f = jnp.exp(intfac) /jnp.power(As,b)
f = f * jnp.power(mu / Ts * (n + 1),(b - n) / (n + 1))
f = f * gammainc(bnfr, intfac) * jsp.special.gamma(bnfr)
g = jnp.power(Cs,n * bnfr)
return f,g
pars = jnp.array([2.76838370e-04, 5.42784687e+01, 2.41308812e-09, 6.76634129e-01,
5.22977021e-01])
f, g = const_func(pars)
f_jac, g_jac = jax.jacobian(const_func)(pars)
f_hess, g_hess = jax.hessian(const_func)(pars)
Calculate $h'' = (\frac{f}{g})''$ using hax.hessian:
def const_h(pars):
A = pars[0]
b = pars[1]
C = pars[2]
n = pars[3]
s0 = pars[4]
mu = 1
Ts = 0.02281313
k = 388440
As = A * k * Ts
Cs = C * k * Ts
# extraload = jnp.where(T0/Ts-s0>=0,T0/Ts-s0,0)
intfac = 1.
bnfr = (b + 1) / (n + 1)
f = jnp.exp(intfac) /jnp.power(As,b)
f = f * jnp.power(mu / Ts * (n + 1),(b - n) / (n + 1))
f = f* gammainc(bnfr, intfac) * jsp.special.gamma(bnfr)
g = jnp.power(Cs,n * bnfr)
return f/g
h = const_h(pars)
h_jac = jax.jacobian(const_h)(pars)
h_hess = jax.hessian(const_h)(pars)
When I evaluate the hessian w.r.t. pars[0], both jax.hessian() and analytical hessian are equal
reacted with thumbs up emoji reacted with thumbs down emoji reacted with laugh emoji reacted with hooray emoji reacted with confused emoji reacted with heart emoji reacted with rocket emoji reacted with eyes emoji
Uh oh!
There was an error while loading. Please reload this page.
-
Jax.hessian() can give valid results for$f''$ and $g''$ . However, jax.hessian gives nan for $(\frac{f}{g})''$ .
Further, when I assess the analytical hessian by Quotient rule,
I am struggling with this example for many days and humbly ask for any help...
Calculate f'' and g'', using jax.hessian()
Calculate$h'' = (\frac{f}{g})''$ using hax.hessian:
When I evaluate the hessian w.r.t. pars[0], both jax.hessian() and analytical hessian are equal
However, when I evaluate the hessian w.r.t. pars[1], and pars[3], jax.hessian() gives either inf or nan while analytical hessian works fine...
Beta Was this translation helpful? Give feedback.
All reactions