Replies: 1 comment
-
Hi, I tried to test your code, and `jet` does seem to accumulate a larger error there. Luckily, there is a workaround for your case: since your function has the form `f(x) = g(2π a·x)` with `g(y) = cos(y) + sin(y)`, the chain rule collapses its third-order Laplacian to a single sixth derivative of the scalar function `g`, namely Δ³f(x) = ‖2π a‖⁶ g⁽⁶⁾(2π a·x):

```python
import jax
import jax.numpy as jnp
from jax import hessian
from jax.experimental.jet import jet

# Setup from the question: a plane wave on a uniform 2D grid.
N = 10
xi = jnp.linspace(0, 1, N, endpoint=False)
grid = jnp.array(jnp.meshgrid(xi, xi)).T.reshape(-1, 2)
a = jnp.array([1, 1])
f = lambda x: jnp.cos(2 * jnp.pi * jnp.dot(a, x)) + jnp.sin(2 * jnp.pi * jnp.dot(a, x))

# Factor f as g(h) with a scalar intermediate variable.
g = lambda y: jnp.cos(y) + jnp.sin(y)
h = 2 * jnp.pi * grid @ a

# Sixth derivative of the 1D function g via a degree-6 jet: propagate the
# series y + t (higher input coefficients zero); the last output Taylor
# coefficient is g^(6)(y).
def d6_g(y):
    v = jnp.ones_like(y)
    zeros = jnp.zeros_like(y)
    return jet(g, (y,), ((v, zeros, zeros, zeros, zeros, zeros),))[1][-1]

# Chain rule: the third-order Laplacian of f is ||2*pi*a||^6 * g^(6)(2*pi*a.x).
lap3 = jax.vmap(d6_g)(h) * (2 * jnp.pi * jnp.linalg.norm(a)) ** 6

# Exact third-order Laplacian of this f, for reference.
lap3_exact = -((2 * jnp.pi * jnp.linalg.norm(a)) ** 6) * jax.vmap(f)(grid)
print(jnp.linalg.norm(lap3_exact - lap3) / jnp.linalg.norm(lap3_exact))
# 2.956196e-06
```
For comparison, the naive implementation with three nested `hessian` calls is even more accurate:

```python
# Three nested Hessians build the full sixth-derivative tensor; tracing out
# its three pairs of axes sums the pure and mixed sixth derivatives.
hess3 = jax.vmap(hessian(hessian(hessian(f))))(grid)
lap3_ad = jnp.trace(
    jnp.trace(jnp.trace(hess3, axis1=-2, axis2=-1), axis1=-2, axis2=-1),
    axis1=-2,
    axis2=-1,
)
print(jnp.linalg.norm(lap3_exact - lap3_ad) / jnp.linalg.norm(lap3_exact))
# 1.0007921e-07
```
-
I'm trying to find an efficient way to compute a high-order Laplace operator (to later use on neural networks), and it seems `experimental.jet` should be the solution, although on simple examples this mode of auto-diff appears to accumulate large errors. Precisely, I am looking at the third-order Laplacian in 2D, which would be: ∂x⁶ + 3 ∂x⁴∂y² + 3 ∂x²∂y⁴ + ∂y⁶.

Taking a simple function on a 2D grid:
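```python
import jax
import jax.numpy as jnp
from jax import hessian
from jax.experimental.jet import jet

# A plane wave on a uniform 2D grid.
N = 10
xi = jnp.linspace(0, 1, N, endpoint=False)
grid = jnp.array(jnp.meshgrid(xi, xi)).T.reshape(-1, 2)
a = jnp.array([1, 1])
f = lambda x: jnp.cos(2 * jnp.pi * jnp.dot(a, x)) + jnp.sin(2 * jnp.pi * jnp.dot(a, x))
```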
Then it seems the `experimental.jet` way of taking the third-order Laplacian is as follows, but when comparing to the exact operator the relative error is huge.
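A minimal sketch of what such a jet-based construction can look like (one plausible assembly, not a definitive recipe): pure sixth derivatives come from degree-6 jets along the axes, and the mixed terms can be picked up with jets along the diagonals via the identity Δ³ = (4/5)(D_{e1}⁶ + D_{e2}⁶) + (1/10)(D_{(1,1)}⁶ + D_{(1,-1)}⁶), where D_v⁶ denotes the sixth directional derivative along v:

```python
from functools import partial

def d6_along(fun, x, v):
    # Sixth directional derivative via jet: propagate the input series
    # x + t * v (higher coefficients zero); the last output Taylor
    # coefficient is d^6/dt^6 fun(x + t v) at t = 0.
    zeros = jnp.zeros_like(x)
    _, coeffs = jet(fun, (x,), ((v, zeros, zeros, zeros, zeros, zeros),))
    return coeffs[-1]

def laplacian3_jet(fun, x):
    # Sketch: assemble dx^6 + 3 dx^4 dy^2 + 3 dx^2 dy^4 + dy^6 from sixth
    # directional derivatives along the two axes and the two diagonals.
    pure = d6_along(fun, x, jnp.array([1.0, 0.0])) + d6_along(fun, x, jnp.array([0.0, 1.0]))
    diag = d6_along(fun, x, jnp.array([1.0, 1.0])) + d6_along(fun, x, jnp.array([1.0, -1.0]))
    return 0.8 * pure + 0.1 * diag

# Exact operator for this f: the third-order Laplacian is -(2*pi*|a|)^6 * f.
lap3_exact = -((2 * jnp.pi * jnp.linalg.norm(a)) ** 6) * jax.vmap(f)(grid)
lap3_jet = jax.vmap(partial(laplacian3_jet, f))(grid)
print(jnp.linalg.norm(lap3_exact - lap3_jet) / jnp.linalg.norm(lap3_exact))
```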
In contrast, a very naive implementation using standard JAX automatic differentiation gives precise results:
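```python
# Three nested Hessians give the full sixth-derivative tensor; tracing out
# its three pairs of axes sums exactly the pure and mixed terms above.
hess3 = jax.vmap(hessian(hessian(hessian(f))))(grid)
lap3_ad = jnp.trace(
    jnp.trace(jnp.trace(hess3, axis1=-2, axis2=-1), axis1=-2, axis2=-1),
    axis1=-2,
    axis2=-1,
)
print(jnp.linalg.norm(lap3_exact - lap3_ad) / jnp.linalg.norm(lap3_exact))
```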
Is there an error in the way I'm using `jet`? Maybe I'm not dealing with the mixed derivatives correctly?