Hi Kirk, thanks for reaching out! I modified your code a bit, and the following works for me:

```python
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad

import jaxley as jx

cell = jx.Cell()
cell.stimulate(jx.step_current(1.0, 1.0, 0.1, 0.025, 5.0))
cell.record()
cell.make_trainable("radius")
params = cell.get_parameters()

# Pre-compute the locations of the grid on which the KDE will be
# evaluated.
observed_v = jx.integrate(cell, params=params)
observed_dv = (observed_v[:, 2:] - observed_v[:, :-2]) / 2
xmin = jnp.squeeze(jnp.min(observed_v, axis=1))
xmax = jnp.squeeze(jnp.max(observed_v, axis=1))
ymin = jnp.squeeze(jnp.min(observed_dv, axis=1))
ymax = jnp.squeeze(jnp.max(observed_dv, axis=1))
X, Y = jnp.mgrid[xmin:xmax:100j, ymin:ymax:100j]
positions = jnp.vstack([X.ravel(), Y.ravel()])


def summary_stats(v):
    dv = (v[:, 2:] - v[:, :-2]) / 2
    v_dv = jnp.vstack([v[:, 1:-1], dv])
    kernel = jax.scipy.stats.gaussian_kde(v_dv)
    return jnp.asarray(kernel(positions))


observed_ss = summary_stats(observed_v)

# Modify the parameters such that the loss is not 0.0 (which is boring).
params[0]["radius"] = params[0]["radius"].at[0].set(2.0)


def simulate(params):
    return jx.integrate(cell, params=params)


def loss_from_v(v):
    ss = summary_stats(v)
    return jnp.sum(jnp.sqrt(jnp.abs(ss - observed_ss)))


def loss_fn(params):
    v = simulate(params)
    return loss_from_v(v)


gradient_fn = jit(value_and_grad(loss_fn))
gradient_fn(params)
```

What did I change? The grid on which the KDE is evaluated (`X`, `Y` and `positions`) is now pre-computed once from the observed trace, outside of `summary_stats()`, instead of being rebuilt from the traced voltage inside the jitted function. Methods like …

I hope this helps! Let me know if you have more questions or if things are not behaving as expected; user feedback is super important for us!
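If you would rather keep the grid adaptive to each simulated trace, here is a minimal sketch, assuming a recent JAX version. As far as I know, `jnp.mgrid` requires concrete (non-traced) slice bounds, which is what triggers the error under `jit`, whereas `jnp.linspace` accepts traced `start`/`stop` as long as the number of points is static. The helper names `bounds`, `grid_mgrid` and `grid_linspace` are hypothetical and only for illustration.

```python
import jax
import jax.numpy as jnp


def bounds(v):
    # Phase-plane bounds of a 1-D trace: range of v and of its central difference.
    dv = (v[2:] - v[:-2]) / 2
    return jnp.min(v), jnp.max(v), jnp.min(dv), jnp.max(dv)


def grid_mgrid(v):
    # Works eagerly; under jit the bounds are tracers, which jnp.mgrid rejects.
    xmin, xmax, ymin, ymax = bounds(v)
    X, Y = jnp.mgrid[xmin:xmax:5j, ymin:ymax:5j]
    return jnp.vstack([X.ravel(), Y.ravel()])


def grid_linspace(v, n=5):
    # jnp.linspace takes traced start/stop; only the point count n is static.
    xmin, xmax, ymin, ymax = bounds(v)
    X, Y = jnp.meshgrid(jnp.linspace(xmin, xmax, n),
                        jnp.linspace(ymin, ymax, n),
                        indexing="ij")  # "ij" matches the layout of jnp.mgrid
    return jnp.vstack([X.ravel(), Y.ravel()])


v = jnp.sin(jnp.linspace(0.0, 6.0, 50))

# Eagerly, both constructions give the same grid.
print(jnp.allclose(grid_mgrid(v), grid_linspace(v)))

# Under jit, the linspace version traces (n marked static).
fast = jax.jit(grid_linspace, static_argnums=1)
print(fast(v, 5).shape)

# The mgrid version is expected to fail under jit.
try:
    jax.jit(grid_mgrid)(v)
    print("jnp.mgrid traced under jit")
except jax.errors.ConcretizationTypeError as err:
    print("jnp.mgrid under jit raised", type(err).__name__)
```

Keep in mind that with an adaptive grid the observed and simulated densities are sampled at different points, so their pointwise difference no longer compares like with like; for a loss, the fixed grid from the observed trace used above is usually the better choice.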
I'm trying to test jaxley with a phase plane density loss function but am getting concretisation errors on JAX tracer arrays. I thought I had followed JAX's requirements, and everything uses `jnp`-specific functions. The initial run that produces the target trace works fine, and subsequent (looped) runs work if I remove the `jit` wrapper from this function, but then it's very slow.

PHASE PLANE DENSITY LOSS FUNCTION: replaces `summary_stats()` and `loss_from_v()` in 01 synthetic (see below for the full code to reproduce):
```python
def summary_stats(v):
    dv = (v[:, 2:] - v[:, :-2]) / 2
    xmin = jnp.squeeze(jnp.min(v, axis=1))
    xmax = jnp.squeeze(jnp.max(v, axis=1))
    ymin = jnp.squeeze(jnp.min(dv, axis=1))
    ymax = jnp.squeeze(jnp.max(dv, axis=1))
    X, Y = jnp.mgrid[xmin:xmax:100j, ymin:ymax:100j]  # DEBUG: this line causes the concretisation error
    positions = jnp.vstack([X.ravel(), Y.ravel()])
    v_dv = jnp.vstack([jnp.asarray(v[:, 1:-1]), jnp.asarray(dv)])
    kernel = jax.scipy.stats.gaussian_kde(v_dv)
    return jnp.asarray(kernel(positions))


def loss_from_v(v):
    ss = summary_stats(v)
    return jnp.sum(jnp.sqrt(jnp.abs(ss - x_o_ss)), axis=0)
```
FULL CODE TO REPRODUCE (just 01 synthetic from the original jaxley paper, with `summary_stats()` and `loss_from_v()` replaced by this code). Unrelated code is commented out but retained for clarity:
01_synthetic(testing with phase plane density loss).txt
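The fixed-grid approach from the reply above can also be checked without jaxley. Below is a minimal, jaxley-free sketch of the same phase plane density loss, assuming a toy sine wave in place of `jx.integrate`; the names `fake_trace`, `phase_plane` and the scalar parameter `amp` are hypothetical stand-ins for the cell and its trainable parameters. It evaluates the loss at `amp=1.5` rather than at the target value 1.0, since the derivative of `sqrt(|x|)` is unbounded at `x = 0` and gradients at an exact match can come out as `nan`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import gaussian_kde

# Hypothetical stand-in for a recording: shape (1, T), controlled by `amp`.
t = jnp.linspace(0.0, 20.0, 400)


def fake_trace(amp):
    return (amp * jnp.sin(t))[None, :]


def phase_plane(v):
    # Central difference in units of samples, as in the thread; pairs each
    # interior sample v[t] with its derivative estimate dv[t].
    dv = (v[:, 2:] - v[:, :-2]) / 2
    return jnp.vstack([v[:, 1:-1], dv])  # shape (2, T - 2)


# The grid is built eagerly from the reference trace, so no slice bound or
# array shape depends on a tracer once the loss is jitted.
obs_pp = phase_plane(fake_trace(1.0))
xs = jnp.linspace(obs_pp[0].min(), obs_pp[0].max(), 30)
ys = jnp.linspace(obs_pp[1].min(), obs_pp[1].max(), 30)
X, Y = jnp.meshgrid(xs, ys, indexing="ij")
positions = jnp.vstack([X.ravel(), Y.ravel()])  # shape (2, 900)
observed_ss = gaussian_kde(obs_pp)(positions)   # shape (900,)


def loss_fn(amp):
    # Only the KDE of the simulated trace depends on `amp`.
    ss = gaussian_kde(phase_plane(fake_trace(amp)))(positions)
    return jnp.sum(jnp.sqrt(jnp.abs(ss - observed_ss)))


gradient_fn = jax.jit(jax.value_and_grad(loss_fn))
value, grad = gradient_fn(1.5)
print(value, grad)
```

As in the reply, the only change that matters for `jit` is that the grid is computed outside the traced function; building and evaluating `jax.scipy.stats.gaussian_kde` inside it traces and differentiates fine.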