I'm trying to recreate the PINN toy example from here. I'm using this for reference as well. I'm stuck on the part where I need to compute the first & second derivatives of my NN with respect to the input variable, in order to compute the "physics loss". I understand it should be possible with either `jax.vjp`/`jax.jvp` or some non-trivial use of `jax.grad`, but I haven't managed to get it right. Here's what I have so far:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState
key = jax.random.key(123)
# harmonic oscillator params
d, w0 = 2, 20
class FCN(nn.Module):
    output_dim: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.output_dim)(x)
fcn = FCN(output_dim=1)
params = fcn.init(key, jnp.ones((1, 1)))
opt = optax.adam(1e-4)
state = TrainState.create(apply_fn=fcn.apply, params=params, tx=opt)
@jax.jit
def physics_loss_fn(params):
    x_phy = jnp.linspace(0, 1, 30)
    x_phy = jnp.expand_dims(x_phy, axis=-1)
    y = fcn.apply(params, x_phy)
    # TODO: what's the best way to compute the first & second grads?
    _, vjp_fun = jax.vjp(fcn.apply, params, x_phy)
    _, dy = vjp_fun(y)
    _, dy2 = vjp_fun(dy)
    mu = 2 * d
    k = w0**2
    residuals = dy2 + mu * dy + k * y
    return jnp.mean(jnp.square(residuals))
@jax.jit
def mse(params, x, y):
    y_pred = fcn.apply(params, x)
    return jnp.mean(jnp.square(y_pred - y))
@jax.jit
def loss_fn(params, x, y):
    phy_loss = physics_loss_fn(params)
    mse_loss = mse(params, x, y)
    return mse_loss + 1e-4 * phy_loss
x = jnp.linspace(0, 1, 512)
x = jnp.expand_dims(x, axis=-1)
y = jnp.zeros_like(x) # should be harmonic oscillator data
jax.value_and_grad(loss_fn)(state.params, x, y)
```
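For reference, the residual in `physics_loss_fn` is meant to enforce the damped harmonic oscillator ODE, with $\mu = 2d$ and $k = \omega_0^2$ as in the code above:

$$\frac{d^2 u}{dx^2} + \mu \frac{du}{dx} + k\,u = 0$$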
---
I can't say too much about the specific convergence behavior of this example, but I'd recommend taking a look at this part of the JAX autodiff cookbook, which discusses various approaches to computing second derivatives. This will provide some useful context! In your case, I think the key point is that the derivatives you want are derivatives of a scalar-input, scalar-output function, evaluated at a batch of points, so you can compute them with `jax.grad` and `jax.jvp` and then `vmap` over the batch:

```python
import jax
import jax.numpy as jnp
def first_and_second_derivative(f):
    def wrapped(x):
        assert jnp.shape(x) == ()
        dx = jnp.ones_like(x)
        df, df2 = jax.jvp(jax.grad(f), (x,), (dx,))
        return df, df2
    return wrapped
def fun(x):
    return jnp.sin(x)
# For scalar inputs and outputs we use the function directly
x = 0.5
dfdx, df2dx2 = first_and_second_derivative(fun)(x)
print(dfdx - jnp.cos(x), df2dx2 + jnp.sin(x))
# For arrays of inputs we must vmap
x = jnp.linspace(-1, 1, 5)
dfdx, df2dx2 = jax.vmap(first_and_second_derivative(fun))(x)
print(dfdx - jnp.cos(x), df2dx2 + jnp.sin(x))
```
In this case you could define a batched version of that function by moving the `vmap` inside, roughly as follows (the helper name below is just illustrative):
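```python
def first_and_second_derivative_batched(f):
    # f: scalar -> scalar; this wrapper accepts a 1D array of inputs
    def wrapped(x):
        def scalar_derivs(xi):
            dxi = jnp.ones_like(xi)
            return jax.jvp(jax.grad(f), (xi,), (dxi,))
        # vmap over the batch dimension so the vjp/jvp only ever sees scalars
        return jax.vmap(scalar_derivs)(x)
    return wrapped

x = jnp.linspace(-1, 1, 5)
dfdx, df2dx2 = first_and_second_derivative_batched(fun)(x)
```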
Depending on the shape of the output of `f`, you might need to squeeze it so that `jax.grad` sees a scalar; for the model in the question, something along these lines should work:
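```python
# Hypothetical scalar wrapper around the FCN from the question: reshape the
# scalar input to (1, 1) for fcn.apply, then squeeze the (1, 1) output back to
# a scalar so that jax.grad can be applied.
def u(x_scalar):
    out = fcn.apply(params, jnp.reshape(x_scalar, (1, 1)))
    return jnp.squeeze(out)

x_phy = jnp.linspace(0, 1, 30)
y = jax.vmap(u)(x_phy)
dy, dy2 = jax.vmap(first_and_second_derivative(u))(x_phy)
# y, dy, and dy2 all have shape (30,), so they can be combined directly into
# the physics residual, e.g. dy2 + mu * dy + k * y.
```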