-
I'm not sure if this is what you've already done, since you said you tried using
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
data = jax.random.randint(key, (50,), 0, 5)

def while_cond(carry):
    # Keep halving until the value reaches zero.
    _, _, value = carry
    return value > 0

def while_body(carry):
    write_array, write_index, value = carry
    value = value // 2
    # "Append" by writing at the next free slot of the pre-allocated buffer.
    write_array = write_array.at[write_index].set(value)
    return write_array, write_index + 1, value

def f(carry, val):
    write_array, write_index = carry
    write_array, write_index, _ = jax.lax.while_loop(
        cond_fun=while_cond,
        body_fun=while_body,
        init_val=(write_array, write_index, val)
    )
    return (write_array, write_index), None

# Each input value (0 to 4) produces at most 3 writes, so 50 * 3 slots suffice.
max_len = 150
init_carry = (-1 * jnp.ones(max_len, dtype=jnp.int32)), 0
(out_array, n_valid), _ = jax.lax.scan(f, init_carry, data)
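If the number of appends per input has a small known upper bound, the inner while loop could possibly be replaced by a single masked write per scan step. The following is only a sketch under that assumption: `MAX_PER_STEP` and `MAX_LEN` are hypothetical names, and the halving rule simply mirrors the snippet above. Out-of-range indices combined with `mode="drop"` are used to skip the slots that should not be written.

```python
import jax.numpy as jnp
from jax import lax

MAX_PER_STEP = 3  # hypothetical bound on appends per input (holds for values < 8)
MAX_LEN = 12      # hypothetical capacity of the pre-allocated output buffer

def step(carry, val):
    buf, n = carry
    k = jnp.arange(MAX_PER_STEP)
    # The k-th halving is appended only while the value before it is still > 0.
    valid = (val >> k) > 0
    candidates = val >> (k + 1)
    # Invalid slots get an out-of-range index, which mode="drop" ignores.
    idx = jnp.where(valid, n + k, buf.shape[0])
    buf = buf.at[idx].set(candidates, mode="drop")
    return (buf, n + valid.sum()), None

data = jnp.array([4, 0, 3, 1], dtype=jnp.int32)
init = (jnp.full(MAX_LEN, -1, dtype=jnp.int32), jnp.int32(0))
(buf, n_valid), _ = lax.scan(step, init, data)

# Trim the padding outside of any traced code, where n_valid is concrete.
result = buf[: int(n_valid)]
```

The buffer still has a static size, so the padding has to be trimmed after the scan; only the write pattern changes.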
-
The fundamental difficulty here is that the output of your code is an array whose shape depends on the values of the inputs; this is incompatible with JAX's tracing and compilation model, which requires the shapes of input and output arrays to be known statically. Regarding
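To illustrate the static-shape constraint, here is a minimal sketch (the function name and the `-1` padding value are assumptions made for this example). Selecting the positive entries of an array has a value-dependent length, so under `jax.jit` the result is padded to a fixed length and the number of valid entries is returned alongside it.

```python
import jax
import jax.numpy as jnp

@jax.jit
def keep_positive_padded(x):
    mask = x > 0
    # `size` fixes the output length at trace time; unused slots use `fill_value`.
    (idx,) = jnp.nonzero(mask, size=x.shape[0], fill_value=0)
    n_valid = jnp.sum(mask)
    # Overwrite the padding slots with -1 so they are easy to recognise.
    padded = jnp.where(jnp.arange(x.shape[0]) < n_valid, x[idx], -1)
    return padded, n_valid

x = jnp.array([3, 0, -2, 5, 0, 1])
padded, n_valid = keep_positive_padded(x)

# The dynamically sized result can only be formed outside the compiled function.
result = padded[: int(n_valid)]
```

Calling `jnp.nonzero` without `size` inside the jitted function would fail, because the output length could not be determined during tracing.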
-
Has anyone made progress on this question? Even trying the ... I'm in the same situation as @mmuckley: the output is very sparse, with only about every 1000th step on average carrying relevant information. Even though I know how much relevant information I will have in the end, outputting at every step and then filtering this very sparse array consumes a lot of memory. A minimal example could look like this:
import jax
import jax.lax as lax
import jax.numpy as jnp

input_array = jnp.arange(10000)
every_1000th_result = jnp.zeros((input_array.shape[0] // 1000, 3))

@jax.jit
def scan_body(carry, input_val):
    every_1000th_result, next_idx_to_write, prev_result, iteration_number = carry
    # update the result
    current_result = prev_result + input_val
    # only if iteration_number is a multiple of 1000, write to every_1000th_result at index next_idx_to_write
    every_1000th_result = lax.cond(
        iteration_number % 1000 == 0,
        lambda x: x.at[next_idx_to_write, :].set(jnp.array([current_result, current_result + 1, current_result + 2])),
        lambda x: x,
        every_1000th_result
    )
    # increment next_idx_to_write
    next_idx_to_write = lax.cond(
        iteration_number % 1000 == 0,
        lambda x: x + 1,
        lambda x: x,
        next_idx_to_write
    )
    # increment iteration_number
    iteration_number = iteration_number + 1
    # return the updated carry and the current result
    return (every_1000th_result, next_idx_to_write, current_result, iteration_number), None

# Initialize the carry
initial_carry = (every_1000th_result, 0, 0, 0)
# Use lax.scan to iterate over the input array
final_carry, _ = lax.scan(scan_body, initial_carry, input_array)
# Extract the result from the final carry
print("Got every_1000th result:", final_carry[0])

Here, I want to make sure that [NB: I'm returning
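One possible alternative, which only applies if the relevant steps occur at a fixed period as in the minimal example (the real use case may not satisfy this), is a nested scan: the outer scan runs once per chunk and emits one row, while the inner scan advances the state through the chunk without emitting anything. This is a sketch under that assumption; `PERIOD`, `inner_step` and `outer_step` are names made up for illustration, and the result here is int32 rather than float32.

```python
import jax.numpy as jnp
from jax import lax

PERIOD = 1000  # assumed fixed spacing between relevant steps

def inner_step(running_sum, x):
    # Advance the state without producing any per-step output.
    return running_sum + x, None

def outer_step(running_sum, chunk):
    # The relevant iterations are 0, 1000, 2000, ..., i.e. the first element
    # of each chunk, so the row is built from the state right after it.
    first = running_sum + chunk[0]
    row = jnp.stack([first, first + 1, first + 2])
    running_sum, _ = lax.scan(inner_step, running_sum, chunk)
    return running_sum, row

input_array = jnp.arange(10000)
chunks = input_array.reshape(-1, PERIOD)  # shape (10, 1000)

# The stacked outputs have one row per chunk, not one per step.
_, every_1000th_result = lax.scan(outer_step, jnp.int32(0), chunks)
```

With this layout the stacked output has 10 rows instead of 10000, and no output buffer needs to be threaded through the carry.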
-
Hello. I have an application where I would like to scan over an array and, depending on the current value and state, conditionally append to an output list. A single value in the input array could result in zero to many append calls. I have an idea of the maximum size of the output array, but in practice it will usually be much smaller, and I can't predict in advance how small.
For my current implementation I pre-allocate the output array and insert values using jax.ops.index_update, but this copies the entire output array at every value, which is extremely slow.
I was wondering if there could be a way to do this using a pytree, but I don't really understand them. In pure Python, I would be able to call a.append(b), but since JAX uses pure functions, I can't do this as it mutates a.
Altering lax.scan would help if I could control the ys.append(y) line, but as I said earlier, a single value of x could result in zero or many appends.
Based on the recent update in sharp bits, I thought that it might be possible to get jax.ops.index_update to update in place by removing all references to input variables, but I didn't have any success with this. It could be that the index_update command is called deep in the stack and there is a dangling reference.
So I'm a bit at a loss for any way to speed up the code while using JAX and was curious if any others had any ideas. I don't think I can provide the exact code yet, but this might get at the idea: