Hi all,

I have a `traj = {'action': ..., 'goal': ..., etc}` and a for loop that manipulates it:

```python
for k, v in traj.items():
    v = ...
    if k != 'reward':
        v = ...
    traj[k] = v
```

Obviously, I want to parallelize this for loop, so I tried `jax.lax.scan`:

```python
import jax
import jax.numpy as jnp

def op(carry, k_v):
    k, v = k_v
    v = ...
    if k != 0:  # I changed k to 0 if it is 'reward' and to 1 otherwise, to avoid a JAX error. Not sure if necessary
        v = ...
    return carry + [(k, v)], None

inputs = jnp.array([[0, v] if k == 'reward' else [1, v] for k, v in traj.items()])
results, _ = jax.lax.scan(op, [], inputs)
traj = dict(results)
```

I currently get an error on the `inputs = ...` line.

Can I get some help to overcome this?
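For anyone hitting the same error, a hedged explanation and an alternative sketch. `jnp.array` can only build a numeric array of uniform shape, so a list of `[flag, v]` pairs will not convert when each `v` is itself an array (or the values differ in shape across keys). `lax.scan` also needs a carry whose shape and structure stay fixed across steps, which rules out growing a Python list, and a Python `if` on the traced `k` would fail inside `scan` anyway. Because the keys are ordinary Python strings known ahead of time, one option (not the approach from the reply) is to keep the Python loop and wrap it in `jax.jit`: the loop is unrolled during tracing, so all per-key computations end up in a single compiled program instead of being dispatched one by one. The trajectory contents and the `* 2.0` / `+ 1.0` transforms are hypothetical placeholders for the elided `...` steps.

```python
import jax
import jax.numpy as jnp

# Hypothetical trajectory; the real one holds 'action', 'goal', 'reward', etc.
traj = {
    "action": jnp.arange(4.0),
    "goal": jnp.ones(4),
    "reward": jnp.array([0.0, 1.0, 0.0, 1.0]),
}

@jax.jit
def process(traj):
    # The dict is a pytree: jit traces through its array leaves, while the
    # keys stay static Python strings, so `if k != "reward"` is a plain
    # Python branch resolved at trace time.
    out = {}
    for k, v in traj.items():
        v = v * 2.0          # hypothetical stand-in for the first `v = ...`
        if k != "reward":
            v = v + 1.0      # hypothetical stand-in for the non-reward `v = ...`
        out[k] = v
    return out

new_traj = process(traj)
print(new_traj["action"])  # [1. 3. 5. 7.]
print(new_traj["reward"])  # [0. 2. 0. 2.]
```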
The way I solved this is by using `jax.tree_util.tree_map(op, traj, keys)`, where `op` is the logic I applied to each value, and `keys` holds the keys of `traj`, which I also needed inside that logic.
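A minimal sketch of that `tree_map` approach, assuming `keys` is a dict with the same structure as `traj` that maps each key to itself: `tree_map` expects every extra tree argument to match the structure of the first one, so a plain list of keys would not line up with the dict. The trajectory contents and the transforms inside `op` are hypothetical placeholders. Note that `tree_map` on its own just calls `op` leaf by leaf in Python; the sketch wraps it in `jax.jit` so the per-leaf work is compiled together, closing over `keys` because string leaves are not valid `jit` arguments.

```python
import jax
import jax.numpy as jnp

# Hypothetical trajectory contents.
traj = {
    "action": jnp.arange(4.0),
    "goal": jnp.ones(4),
    "reward": jnp.array([0.0, 1.0, 0.0, 1.0]),
}

# Same dict structure as `traj`, with each leaf being its own key.
# Strings are not pytree nodes, so each key string is treated as a leaf.
keys = {k: k for k in traj}

def op(v, k):
    # tree_map calls op(leaf_of_traj, matching_leaf_of_keys).
    v = v * 2.0          # hypothetical stand-in for the shared logic
    if k != "reward":
        v = v + 1.0      # hypothetical stand-in for the non-reward logic
    return v

@jax.jit
def process(traj):
    # `keys` is closed over rather than passed in: as Python constants the
    # strings are fine during tracing, but they cannot be jit arguments.
    return jax.tree_util.tree_map(op, traj, keys)

new_traj = process(traj)
```

On recent JAX versions, `jax.tree_util.tree_map_with_path` hands each leaf's key path to the function directly, which avoids building `keys` at all.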