Replies: 3 comments 8 replies
-
Could you provide the full traceback? IIUC, something along these lines:

    import jax
    import jax.numpy as jnp
    from tensorflow_probability.substrates import jax as tfp

    tfd = tfp.distributions

    def choose_action(self, obs):
        self.key, key1, key2 = jax.random.split(self.key, 3)
        subkeys = jax.random.split(key1, self.optimization_steps), jax.random.split(key2, self.optimization_steps)
        # you can extract this line outside jitted function
        obs = obs.astype('float').reshape(1, -1, 1)
        elite_size = 15
        var_len = self._actdim * self.horizon
        batch = self.batch_size
        batched_weighted_sample_fn = jax.vmap(lambda weight, sample: weight * sample, in_axes=(0, 0), out_axes=0)

        def cem(val, keys):
            means, stds = val
            subkey1, subkey2 = keys
            samples = tfd.MultivariateNormalDiag(means, stds).sample(sample_shape=[batch], seed=subkey1)
            samples = jnp.clip(samples, self.min_action, self.max_action)
            # fitness, _ = eval_fitness(samples, subkey2)
            fitness = jnp.ones([batch, 1])
            elite_values, elite_inds = jax.lax.top_k(jnp.squeeze(fitness), elite_size)
            elite_samples = samples[elite_inds]
            new_means = jnp.mean(elite_samples, axis=0)
            new_vars = jnp.var(elite_samples, axis=0)
            new_stds = jnp.sqrt(new_vars + 1e-6)
            return (new_means, new_stds), None

        means = jnp.zeros((self.horizon, self._actdim), dtype=self._float)
        stds = jnp.ones((self.horizon, self._actdim), dtype=self._float)
        init_val = (means, stds)
        # scan consumes one (subkey1, subkey2) pair per optimization step
        means, stds = jax.lax.scan(cem, init_val, subkeys)[0]
        return means[0]
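For experimenting without the class attributes (`self.horizon`, `eval_fitness`, and so on), here is a minimal standalone sketch of the same scan-based loop written as a pure function. It assumes a hypothetical quadratic `fitness` objective, and it swaps TFP's `MultivariateNormalDiag` for `means + stds * jax.random.normal(...)`, which draws from the same diagonal Gaussian. `cem_plan` and its parameters are illustrative names, not part of any library.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(
    jax.jit,
    static_argnames=("fitness_fn", "horizon", "act_dim", "batch", "elite_size", "steps"),
)
def cem_plan(key, fitness_fn, horizon, act_dim, batch=64, elite_size=15,
             steps=20, min_action=-1.0, max_action=1.0):
    """Pure CEM planner: all randomness comes from the `key` argument."""
    step_keys = jax.random.split(key, steps)  # one key per CEM iteration

    def cem_step(carry, step_key):
        means, stds = carry
        noise = jax.random.normal(step_key, (batch, horizon, act_dim))
        samples = jnp.clip(means + stds * noise, min_action, max_action)
        fitness = fitness_fn(samples)  # expected shape: (batch,)
        _, elite_inds = jax.lax.top_k(fitness, elite_size)
        elites = samples[elite_inds]
        new_means = jnp.mean(elites, axis=0)
        new_stds = jnp.sqrt(jnp.var(elites, axis=0) + 1e-6)
        return (new_means, new_stds), None

    init = (jnp.zeros((horizon, act_dim)), jnp.ones((horizon, act_dim)))
    (means, _), _ = jax.lax.scan(cem_step, init, step_keys)
    return means[0]  # first action of the optimized sequence


# Hypothetical objective: reward action sequences close to a fixed target.
target = jnp.array([0.5, -0.25])


def fitness(samples):
    return -jnp.sum((samples - target) ** 2, axis=(1, 2))


action = cem_plan(jax.random.PRNGKey(0), fitness, horizon=3, act_dim=2)
```

Because the key is an argument rather than an attribute, the whole planner can sit under `jax.jit`; the caller is responsible for supplying a fresh key on each call.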
-
What version of |
-
The problem is that your function is not pure. It contains the line `means, stds, self.key = jax.lax.fori_loop(0, self.optimization_steps, cem, init_val)`, which is mutating one of the input arguments,
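A sketch of the pure alternative, assuming a hypothetical objective that favors actions near zero: the key travels through the `fori_loop` carry, the jitted function returns it, and the attribute is reassigned only in ordinary Python code outside `jax.jit`. `cem_fori` and `Planner` are hypothetical names used for illustration.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("steps", "batch", "elite_size"))
def cem_fori(key, means, stds, steps=10, batch=64, elite_size=15):
    """CEM loop whose PRNG key lives in the fori_loop carry and is returned."""

    def body(i, carry):
        means, stds, key = carry
        key, sample_key = jax.random.split(key)  # fresh key each iteration
        noise = jax.random.normal(sample_key, (batch,) + means.shape)
        samples = means + stds * noise
        # Hypothetical objective: prefer samples close to zero.
        fitness = -jnp.sum(samples ** 2, axis=tuple(range(1, samples.ndim)))
        _, elite_inds = jax.lax.top_k(fitness, elite_size)
        elites = samples[elite_inds]
        new_stds = jnp.sqrt(jnp.var(elites, axis=0) + 1e-6)
        return jnp.mean(elites, axis=0), new_stds, key

    # No attribute is touched here: everything in and out is an explicit value.
    return jax.lax.fori_loop(0, steps, body, (means, stds, key))


class Planner:
    """Keeps the key as state, but only updates it outside the jitted function."""

    def __init__(self, seed=0, horizon=3, act_dim=2):
        self.key = jax.random.PRNGKey(seed)
        self.shape = (horizon, act_dim)

    def plan(self):
        means, stds = jnp.zeros(self.shape), jnp.ones(self.shape)
        # The assignment to self.key happens here, in plain Python.
        means, stds, self.key = cem_fori(self.key, means, stds)
        return means[0]


planner = Planner(seed=0)
first = planner.plan()
second = planner.plan()
```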
-
I was trying to write a CEM Controller in JAX (Ref code: https://github.com/zchuning/latco/blob/6aab525b66efb8c99e55d6e0587a7bd31a599809/planners/shooting_cem.py).
My `choose_action` function looks as follows (I have shown just the relevant bits to keep the example small):

Using this, I run into `UnexpectedTracerError`. My understanding is that this happens because I am splitting the key inside a jitted function. Since this function is a member of a class, one way to handle this might be to split the keys before the `cem` function is called and store them as class variables. In this scenario, I would have to jit the `cem` function rather than the parent `choose_action` function. I am currently doing this, but it is quite slow (which is surprising, since the rest of the function isn't doing much). What would be the best way to sample inside a jitted function?
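Splitting a key inside a jitted function is generally fine on its own; the likelier culprit, as the reply about purity points out, is storing the traced key on `self`. A minimal sketch of the usual pattern, with `sample_actions` as a hypothetical helper: pass the key in as an argument, split it inside, and return the new key to the caller.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sample_actions(key, means, stds):
    # Splitting inside jit is fine: the key is just a traced input array.
    key, subkey = jax.random.split(key)
    actions = means + stds * jax.random.normal(subkey, means.shape)
    # Return the new key instead of assigning it to an attribute.
    return actions, key


key = jax.random.PRNGKey(0)
means, stds = jnp.zeros(3), jnp.ones(3)
a1, key = sample_actions(key, means, stds)
a2, key = sample_actions(key, means, stds)
```

Pre-splitting keys into class variables also avoids the leak, but it pushes the jit boundary down to `cem`, which may account for some of the slowdown; keeping the key as an explicit input lets the whole `choose_action` computation be jitted.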